├── .gitignore ├── LICENSE ├── README.md ├── requirements.txt └── src ├── model ├── README.md ├── data_loader.py ├── eval_utils.py ├── generate_stories.py ├── generate_utils.py ├── logger.py ├── loss.py ├── model.py ├── parallel.py └── train.py └── preprocessing ├── README.md ├── extract_outlines.py ├── nyt_splits.txt └── wikiplots_splits.txt /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | 3 | 4 | __pycache__/ 5 | checkpoints_local/ 6 | generate.sh 7 | out*/ 8 | savedir/ 9 | train.sh 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 hrashkin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # plotmachines 2 | Cleaned up version of the PlotMachines code 3 | 4 | 5 | 6 | ### Preprocessing code: 7 | 8 | code located in src/preprocessing (follow instructions in the README in that directory) 9 | 10 | ### PlotMachines model: 11 | 12 | code located in src/model (follow instructions in the README in that directory) 13 | 14 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | pyrouge 3 | rouge 4 | nltk 5 | ftfy 6 | spacy 7 | pytorch-transformers 8 | transformers == 2.0.0 9 | -------------------------------------------------------------------------------- /src/model/README.md: -------------------------------------------------------------------------------- 1 | # Instructions for training and running models: 2 | ## Prerequisites 3 | 4 | pytorch 5 | 6 | transformers (v. 
2.0.0) 7 | 8 | spacy 9 | 10 | nltk 11 | 12 | tqdm 13 | 14 | rouge 15 | 16 | shutil 17 | 18 | 19 | ## Training PlotMachines models 20 | Using train.py, 21 | E.g.: 22 | ```python train.py --data_dir datadir --output_dir savedir --experiment_name pmfull --accum_iter 4 --n_batch 64 --p 90 --num_epochs 10 --use_model plotmachines --use_neighbor_feat --use_discourse``` 23 | 24 | 25 | ### Important command line arguments: 26 | - ```use_neighbor_feat``` : use representation of previous paragraph in input (i.e. preceding context) 27 | - ```use_discourse``` : whether to use discourse type tags (`_i_`,`_b_`,`_c_`) or not 28 | - ```use_model={base/plotmachines}```: either the base gpt model without memory, or PlotMachines with memory 29 | - ```memstatesize={int}```: number of additional memory slots aside from the ones initialized from the outline (default: 100) 30 | - ```n_batch={int}```: must be a multiple of the number of gpus 31 | - ```output_dir```: a directory to save outputs to 32 | - ```data_dir```: location of all of the train/dev input files, each of which must be named {train/val/test}\_encoded.jsonl; should also contain {train/val/test}\_gpt.pkl files where the encoding of the previous paragraph is stored offline 33 | - ```p={int}```: the % to use in nucleus sampling 34 | - ```repeattheta={float}```: how much to penalize repetitions. Should be a float >= 1. (1=no penalty) 35 | 36 | ### Output files 37 | At the end of training, the outputs are stored in: 38 | - `{output_dir}/{experiment_name}/checkpoints/checkpoint_best.pt`: best checkpoint from training 39 | - `{output_dir}/{experiment_name}/logs/output_losses.tsv` : log of loss scores on train/val examples throughout training 40 | - `{output_dir}/{experiment_name}/logs/output_rouge.tsv` : log of rouge scores on five val examples throughout training 41 | 42 | 43 | ## Generating stories 44 | Using generate_stories.py, 45 | E.g.: 46 | ```python generate_stories.py --data_dir datadir --save_dir outputdir --n_batch 64 --p 90 --load_dir savedir/pmfull/checkpoints --use_model plotmachines --use_neighbor_feat --use_discourse``` 47 | 48 | ### Important command line arguments: 49 | - ```bodynum={int}```: number of body paragraphs to generate (default=3, for the 5-paragraph format) 50 | - ```testset```: use the test set instead of the validation set 51 | - ```use_neighbor_feat``` : use representation of previous paragraph in input (i.e. preceding context) 52 | - ```use_discourse``` : whether to use discourse type tags (`_i_`,`_b_`,`_c_`) or not 53 | - ```use_model={base/plotmachines}```: either the base gpt model without memory, or PlotMachines with memory 54 | - ```memstatesize={int}```: number of additional memory slots aside from the ones initialized from the outline (default: 100) 55 | - ```n_batch={int}```: must be a multiple of the number of gpus 56 | - ```save_dir```: a directory to save generations to 57 | - ```data_dir```: location of all of the train/dev input files, each of which must be named {train/val/test}\_encoded.jsonl (this script doesn't use the pkl files) 58 | - ```p={int}```: the % to use in nucleus sampling 59 | - ```repeattheta={float}```: how much to penalize repetitions. Should be a float >= 1. 
(1=no penalty) 60 | - ```load_dir={str}```: the location of checkpoint_best.pt saved from training 61 | 62 | ### Output format 63 | At the end of running the generated story outputs are stored in `output_dir`: 64 | - `{val/test}eval.tsv`: generated stories 65 | 66 | Note, each row is a single story paragraph and the paragraphs of each story might not be in contiguous order. 67 | 68 | Each row contains: 69 | 70 | ```story-idx \t story-name \t plot-outline \t paragraph-idx \t paragraph-text``` 71 | 72 | 73 | 74 | ## Additional Acknowledgements 75 | Thanks to other codebases that were used in writing this code: 76 | - Huggingface's original gpt repo 77 | - Huggingface's current transformers repo 78 | - Transformer for abstractive summarization (used for parallel model classes): https://github.com/Andrew03/transformer-abstractive-summarization 79 | -------------------------------------------------------------------------------- /src/model/data_loader.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import Dataset, DataLoader 5 | import json 6 | from pytorch_transformers import * 7 | import pickle 8 | 9 | 10 | 11 | '''Paragraph Dataset: get a single paragraph from a story 12 | optionally include the previous paragraph encoding (inckude_neigh) 13 | optionally include discourse tag delimiter (include_discourse_type) 14 | ''' 15 | class ParagraphDataset(Dataset): 16 | def __init__(self, data_file, encoder, max_size=None, n_ctx=102, n_gen=401, include_neigh=False, 17 | include_discourse_type=True, include_kw=True, dim=0 ,debug_mode=False): 18 | with open(data_file, "rb") as f: 19 | self.data = f.readlines() 20 | 21 | if include_neigh: 22 | self.prev = [] 23 | fn = ".".join(data_file.split(".")[:-1]) + "_gpt2.pkl" 24 | if debug_mode: 25 | fn = ".".join(data_file.split(".")[:-1]) + "_gpt.pkl" 26 | with open(fn, 'rb') as fp: 27 | for k in range(len(self.data)): 28 | temp = pickle.load(fp) 29 | assert temp[0] == k and temp[1] == self.data[k].decode('utf-8', 'ignore').split("\t")[-1].replace( 30 | "", "").strip() 31 | self.prev.append(temp[2]) 32 | else: 33 | self.prev = None 34 | 35 | self.dids = [] 36 | for d in range(1, len(self.data)): 37 | t = self.data[d].decode("utf-8", "ignore").strip().split('\t') 38 | if len(t) == 7 and t[5].replace("", "").strip() != "": 39 | try: 40 | x, y = int(t[0].split("_")[-1]), int(t[4]) 41 | self.dids.append(d) 42 | except: 43 | pass 44 | 45 | if max_size is not None: 46 | self.dids = self.dids[:max_size] 47 | self.encoder = encoder 48 | self.ctx = n_ctx - 2 49 | self.gen = n_gen - 1 50 | self.dim = dim 51 | self.len = len(self.data) 52 | self.include_neigh = include_neigh 53 | self.include_discourse_type = include_discourse_type 54 | self.include_kw = include_kw 55 | 56 | 57 | def __getitem__(self, index): 58 | idx = self.dids[index] 59 | csv_data = self.data[idx].decode("utf-8", "ignore").strip().split('\t') 60 | kws = csv_data[2].split("[SEP]") 61 | # print(self.encoder.encode(csv_data[5])) 62 | tgt_phrase = self.encoder.encode(csv_data[5].replace("", ""), add_prefix_space=True, add_special_tokens=False)[:self.gen] 63 | start = torch.LongTensor([self.encoder.bos_token_id]) 64 | clstok = torch.LongTensor([self.encoder.cls_token_id]) 65 | end = torch.LongTensor([self.encoder.eos_token_id]) 66 | tstart = torch.LongTensor([self.encoder.convert_tokens_to_ids('_t_')]) 67 | istart = torch.LongTensor([self.encoder.convert_tokens_to_ids('_i_')]) 68 | 
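# note: _i_, _b_, _c_ are the discourse delimiter tokens marking intro, body, and conclusion paragraphs; _kw_ separates outline keywords and _endkw_ closes the keyword list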
bstart = torch.LongTensor([self.encoder.convert_tokens_to_ids('_b_')]) 69 | cstart = torch.LongTensor([self.encoder.convert_tokens_to_ids('_c_')]) 70 | keytok = torch.LongTensor([self.encoder.convert_tokens_to_ids('_kw_')]) 71 | endkeytok = torch.LongTensor([self.encoder.convert_tokens_to_ids('_endkw_')]) 72 | if self.include_discourse_type: 73 | starttyptok = bstart 74 | if int(csv_data[0].split("_")[-1]) == 0: 75 | starttyptok = istart 76 | elif int(csv_data[0].split("_")[-1]) == int(csv_data[4]) - 1: 77 | starttyptok = cstart 78 | else: 79 | starttyptok = clstok 80 | 81 | pad_output = torch.zeros(self.ctx + self.gen + 3).long() 82 | mask_output = torch.zeros(self.ctx + self.gen + 3).long() 83 | 84 | pad_output[0] = start 85 | if self.include_kw: 86 | i = 1 87 | for k in kws: 88 | if i - 1 >= self.ctx: 89 | break 90 | enck = self.encoder.encode(k.strip(), add_prefix_space=True, add_special_tokens=False)[:self.ctx - i] 91 | # print(enck, i) 92 | pad_output[i:i + len(enck)] = torch.LongTensor(enck) 93 | pad_output[i + len(enck)] = keytok 94 | i += len(enck) + 1 95 | pad_output[i - 1] = endkeytok 96 | mask_output[0:i] = torch.ones(i).long() 97 | 98 | pad_output[self.ctx + 1] = starttyptok if self.include_discourse_type else clstok # [101] -> discourse tag 99 | pad_output[self.ctx + 1 + 1:self.ctx + 1 + 1 + len(tgt_phrase)] = torch.LongTensor(tgt_phrase) 100 | pad_output[self.ctx + 1 + 1 + len(tgt_phrase)] = end 101 | 102 | # Mask 103 | mask_output[self.ctx + 1:self.ctx + 1 + len(tgt_phrase) + 2] = torch.ones(len(tgt_phrase) + 2).long() 104 | 105 | if self.include_neigh: 106 | n = torch.FloatTensor(self.prev[idx].flatten()) 107 | else: 108 | n = torch.zeros(self.dim, dtype=torch.float64) 109 | return pad_output, mask_output, n 110 | 111 | def __len__(self): 112 | return len(self.dids) 113 | 114 | ''' 115 | get_paragraph_input_loader: Get data loader for plot machines 116 | 117 | @params- 118 | data_file: the file location for the data 119 | encoder: tokenizer 120 | max_size: truncate to # examples (or None to use full dataset) 121 | n_ctx: number of plotoutline tokens + 2 for delimiters 122 | gen_len: number of paragraph tokens + 1 for end token 123 | include_neigh: whether to return the neighboring (previous) paragraph encoding 124 | include_discourse_type: whether to use special discouse tokens 125 | include_kw: unused, if I want to ignore the context 126 | dim: the dimension of the neighboring (previous) paragraph vectors (should be same as the PlotMachines embeddings) 127 | debug_mode: make a toy dataset for debugging 128 | ''' 129 | def get_paragraph_input_loader(data_file, batch_size, encoder, shuffle=True, num_workers=0, max_size=None, n_ctx=102, 130 | gen_len=401, include_neigh=False, include_discourse_type=True, include_kw=True, 131 | dim=768, debug_mode=False): 132 | dataset = ParagraphDataset(data_file, encoder, max_size=max_size, n_ctx=n_ctx, n_gen=gen_len, 133 | include_neigh=include_neigh,include_discourse_type=include_discourse_type, 134 | include_kw=include_kw, dim=dim,debug_mode=debug_mode) 135 | return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, drop_last=True) 136 | 137 | 138 | 139 | 140 | ''' 141 | FullStory Dataset: get a full multi-paragraph story 142 | ''' 143 | class FullStoryDataset(Dataset): 144 | def __init__(self, data_file, encoder, max_size=None, n_ctx=102, n_gen=401, include_kw=True): 145 | self.data = [] 146 | 147 | with open(data_file, "rb") as f: 148 | data = f.readlines() 149 | 150 | self.encoder = encoder 151 | 
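# note: n_ctx counts the outline tokens plus 2 delimiter positions (start token and the cls/discourse token), and n_gen counts paragraph tokens plus 1 for the end token, so the reserved slots are trimmed off here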
self.ctx = n_ctx - 2 152 | self.gen = n_gen - 1 153 | self.dids = [] 154 | 155 | for d in range(1, len(data)): 156 | t = data[d].decode("utf-8", "ignore").strip().split('\t') 157 | newinput = t[0].split("_")[0] + "\t" + t[2] 158 | if not (newinput in self.dids) and len(t) == 7: 159 | self.dids.append(newinput) 160 | 161 | if max_size is not None: 162 | self.dids = self.dids[:max_size] 163 | self.len = len(self.dids) 164 | self.include_kw = include_kw 165 | 166 | def __getitem__(self, index): 167 | csv_data = self.dids[index].split('\t') 168 | kws = csv_data[1].split("[SEP]") 169 | 170 | tgt_phrase = [] # this is only used at generation time, so ignore the gold 171 | start = torch.LongTensor([self.encoder.bos_token_id]) 172 | clstok = torch.LongTensor([self.encoder.cls_token_id]) 173 | end = torch.LongTensor([self.encoder.eos_token_id]) 174 | tstart = torch.LongTensor([self.encoder.convert_tokens_to_ids('_t_')]) 175 | istart = torch.LongTensor([self.encoder.convert_tokens_to_ids('_i_')]) 176 | bstart = torch.LongTensor([self.encoder.convert_tokens_to_ids('_b_')]) 177 | cstart = torch.LongTensor([self.encoder.convert_tokens_to_ids('_c_')]) 178 | keytok = torch.LongTensor([self.encoder.convert_tokens_to_ids('_kw_')]) 179 | endkeytok = torch.LongTensor([self.encoder.convert_tokens_to_ids('_endkw_')]) 180 | 181 | pad_output = torch.zeros(self.ctx + self.gen + 3).long() 182 | mask_output = torch.zeros(self.ctx + self.gen + 3).long() 183 | 184 | # print(tgt_phrase) 185 | # Tokens 186 | pad_output[0] = start 187 | 188 | if self.include_kw: 189 | i = 1 190 | for k in kws: 191 | if i - 1 >= self.ctx: 192 | break 193 | enck = self.encoder.encode(k.strip(), add_prefix_space=True, add_special_tokens=False)[:self.ctx - i] 194 | # print(enck, i) 195 | pad_output[i:i + len(enck)] = torch.LongTensor(enck) 196 | pad_output[i + len(enck)] = keytok 197 | i += len(enck) + 1 198 | pad_output[i - 1] = endkeytok 199 | mask_output[0:i] = torch.ones(i).long() 200 | 201 | pad_output[self.ctx + 1] = clstok 202 | 203 | # Mask (this will get written over by the generation code anyways) 204 | mask_output[self.ctx + 1:self.ctx + 1 + len(tgt_phrase) + 2] = torch.ones(len(tgt_phrase) + 2).long() 205 | 206 | ids = csv_data + [index] 207 | return pad_output, mask_output, ids 208 | 209 | def __len__(self): 210 | return len(self.dids) 211 | 212 | ''' 213 | get_fullstory_loader: Get single story context 214 | 215 | @params- 216 | data_file: the file location for the data 217 | encoder: tokenizer 218 | max_size: truncate to # examples (or None to use full dataset) 219 | n_ctx: number of plotoutline tokens + 2 for delimiters 220 | gen_len: number of paragraph tokens + 1 for end token 221 | include_kw: unused, if I want to ignore the context 222 | ''' 223 | def get_fullstory_loader(data_file, batch_size, encoder, shuffle=True, num_workers=0, max_size=None, n_ctx=102, 224 | gen_len=401, include_kw=True): 225 | dataset = FullStoryDataset(data_file, encoder, max_size=max_size, n_ctx=n_ctx, n_gen=gen_len, include_kw=include_kw) 226 | return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, drop_last=True) 227 | 228 | 229 | 230 | '''Paragraph Dataset: get a single paragraph from a story with memory for PlotMachines model 231 | optionally include discourse tag delimiter (include_discourse_type) 232 | ''' 233 | class ParagraphWithMemoryDataset(Dataset): 234 | def __init__(self, data_file, encoder, max_size=None, n_ctx=102, n_gen=401, include_discourse_type=True, 235 | include_kw=True, memsize=10, 
dim=768, use_kwmem=False, debug_mode=False): 236 | 237 | def isClean(line): 238 | chunks = line.decode('utf-8', 'ignore').strip().split('\t') 239 | if len(chunks) < 7: 240 | return False 241 | 242 | keys = chunks[2] 243 | par = chunks[5] 244 | par_prev = chunks[6].strip() 245 | 246 | if len(keys) < 5 or len(par) < 5 or (par_prev != 'NA' and len(par_prev) < 5): # or len(par_prev): 247 | return False 248 | return True 249 | 250 | with open(data_file, "rb") as f: 251 | self.data = f.readlines() 252 | temp_data = [] 253 | 254 | self.prevmat = [] 255 | 256 | fn = ".".join(data_file.split(".")[:-1]) + "_gpt2.pkl" 257 | if debug_mode: 258 | fn = ".".join(data_file.split(".")[:-1]) + "_gpt.pkl" 259 | with open(fn, 'rb') as fp: 260 | for k in range(len(self.data)): 261 | temp = pickle.load(fp) 262 | if k == 0: 263 | continue 264 | 265 | if not isClean(self.data[k]): 266 | continue 267 | 268 | if temp[0] != k or temp[1] != self.data[k].decode('utf-8', 'ignore').split("\t")[-1].replace("","").strip(): 269 | print(str(temp[0])) 270 | print(str(k)) 271 | print(temp[1]) 272 | print(self.data[k].decode('utf-8', 'ignore').split("\t")[-1].replace("","").strip()) 273 | continue 274 | 275 | temp_data.append(self.data[k]) 276 | self.prevmat.append(temp[2]) 277 | 278 | self.data = temp_data 279 | temp_data = [] 280 | 281 | print('i read so many of ... ' + str(len(self.data))) 282 | assert len(self.prevmat) == len(self.data) 283 | 284 | self.dids = [] 285 | self.history = dict() 286 | self.h = 10 287 | for d in range(1, len(self.data)): 288 | t = self.data[d].decode("utf-8", "ignore").strip().split('\t') 289 | docid = t[0].split("_")[0] 290 | if len(t) == 7 and t[5].replace("", "").strip() != "": 291 | try: 292 | x, y = int(t[0].split("_")[-1]), int(t[4]) 293 | except: 294 | continue 295 | self.dids.append(d) 296 | if docid not in self.history: 297 | self.history[docid] = dict() 298 | self.history[docid][x] = self.prevmat[d] ##t[5].replace("","") 299 | # print(len(self.history)) 300 | 301 | if max_size is not None: 302 | self.dids = self.dids[:max_size] 303 | self.encoder = encoder 304 | self.ctx = n_ctx - 2 305 | self.gen = n_gen - 1 306 | self.memsize = memsize 307 | self.len = len(self.data) 308 | self.include_discourse_type = include_discourse_type 309 | self.include_kw = include_kw 310 | self.h = 10 311 | self.dim = dim 312 | # asli 313 | self.use_kwmem = use_kwmem 314 | 315 | def __getitem__(self, index): 316 | idx = self.dids[index] 317 | csv_data = self.data[idx].decode("utf-8", "ignore").strip().split('\t') 318 | kws = csv_data[2].split("[SEP]") 319 | 320 | tgt_phrase = self.encoder.encode(csv_data[5].replace("", ""), add_prefix_space=True, add_special_tokens=False)[:self.gen] 321 | start = torch.LongTensor([self.encoder.bos_token_id]) 322 | clstok = torch.LongTensor([self.encoder.cls_token_id]) 323 | end = torch.LongTensor([self.encoder.eos_token_id]) 324 | tstart = torch.LongTensor([self.encoder.convert_tokens_to_ids('_t_')]) 325 | istart = torch.LongTensor([self.encoder.convert_tokens_to_ids('_i_')]) 326 | bstart = torch.LongTensor([self.encoder.convert_tokens_to_ids('_b_')]) 327 | cstart = torch.LongTensor([self.encoder.convert_tokens_to_ids('_c_')]) 328 | keytok = torch.LongTensor([self.encoder.convert_tokens_to_ids('_kw_')]) 329 | endkeytok = torch.LongTensor([self.encoder.convert_tokens_to_ids('_endkw_')]) 330 | 331 | if self.include_discourse_type: 332 | starttyptok = bstart 333 | if int(csv_data[0].split("_")[-1]) == 0: 334 | starttyptok = istart 335 | elif int(csv_data[0].split("_")[-1]) == 
int(csv_data[4]) - 1: 336 | starttyptok = cstart 337 | else: 338 | starttyptok = clstok 339 | 340 | pad_output = torch.zeros(self.ctx + self.gen + 3).long() 341 | mask_output = torch.zeros(self.ctx + self.gen + 3).long() 342 | 343 | if self.use_kwmem: 344 | mem = torch.torch.empty(self.ctx + self.memsize, self.dim).normal_(std=.02) 345 | mmask = torch.zeros(self.ctx + self.memsize).long() 346 | else: 347 | mem = torch.torch.empty(self.memsize, self.dim).normal_(std=.02) 348 | mmask = torch.zeros(self.memsize).long() 349 | 350 | pad_output[0] = start 351 | 352 | if self.include_kw: 353 | ##if self.use_kwmem: 354 | i = 1 355 | for k in kws: 356 | if i - 1 >= self.ctx: 357 | break 358 | enck = self.encoder.encode(k.strip(), add_prefix_space=True, add_special_tokens=False)[:self.ctx - i] 359 | # print(enck, i) 360 | pad_output[i:i + len(enck)] = torch.LongTensor(enck) 361 | pad_output[i + len(enck)] = keytok 362 | i += len(enck) + 1 363 | pad_output[i - 1] = endkeytok 364 | mask_output[0:i] = torch.ones(i).long() 365 | 366 | # mem[0:i-1,0] = pad_output[1:i, 0] 367 | if self.use_kwmem: 368 | mmask[0:i - 1] = torch.ones(i - 1).long() 369 | mmask[-self.memsize:] = torch.ones(self.memsize).long() 370 | 371 | pad_output[self.ctx + 1] = starttyptok if self.include_discourse_type else clstok 372 | pad_output[self.ctx + 1 + 1:self.ctx + 1 + 1 + len(tgt_phrase)] = torch.LongTensor(tgt_phrase) 373 | pad_output[self.ctx + 1 + 1 + len(tgt_phrase)] = end 374 | 375 | # Mask 376 | mask_output[self.ctx + 1:self.ctx + 1 + len(tgt_phrase) + 2] = torch.ones(len(tgt_phrase) + 2).long() 377 | 378 | prev = torch.zeros(self.h, 1, self.dim).float() # .long() 379 | pmask = torch.zeros(self.h, 1).long() 380 | docid = csv_data[0].split("_")[0] 381 | pid = int(csv_data[0].split("_")[-1]) 382 | 383 | for p in range(1, min(pid + 1, self.h + 1)): 384 | # p = 1 --> pid+1 385 | if self.history[docid].get(p) is None: 386 | continue 387 | try: 388 | prev[p - 1, 0, :] = torch.LongTensor(self.history[docid][p]) 389 | pmask[p - 1, 0] = torch.ones(1).long() 390 | except: 391 | continue 392 | return pad_output, mask_output, mem, mmask, prev, pmask, torch.FloatTensor(self.prevmat[idx].flatten()) 393 | 394 | def __len__(self): 395 | return len(self.dids) 396 | 397 | ''' 398 | get_paragraph_history_input_loader: Get data loader for plot machines 399 | 400 | @params- 401 | data_file: the file location for the data 402 | encoder: tokenizer 403 | max_size: truncate to # examples (or None to use full dataset) 404 | n_ctx: number of plotoutline tokens + 2 for delimiters 405 | gen_len: number of paragraph tokens + 1 for end token 406 | include_discourse_type: whether to use special discouse tokens 407 | include_kw: unused, if I want to ignore the context 408 | memsize: the total number of memory slots (aside from the keyword slots) 409 | dim: the dimension of the memory slot vectors (should be same as the PlotMachines embeddings) 410 | use_kwmem: use keyword-based memory slots in the memory 411 | debug_mode: make a toy dataset for debugging 412 | ''' 413 | def get_paragraph_memory_input_loader(data_file, batch_size, encoder, shuffle=True, num_workers=0, max_size=None, 414 | n_ctx=102, gen_len=401, include_neigh=False, include_discourse_type=True, 415 | include_kw=True, memsize=10, dim=768, use_kwmem=False, debug_mode=False): 416 | dataset = ParagraphWithMemoryDataset(data_file, encoder, max_size=max_size, n_ctx=n_ctx, n_gen=gen_len, 417 | include_discourse_type=include_discourse_type, include_kw=include_kw, 418 | memsize=memsize, dim=dim, 
use_kwmem=use_kwmem, debug_mode=debug_mode) 419 | return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, drop_last=True) 420 | -------------------------------------------------------------------------------- /src/model/eval_utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import glob 4 | import json 5 | import os 6 | import random 7 | import re 8 | 9 | from nltk.tokenize import sent_tokenize 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | import rouge 14 | from transformers import * 15 | 16 | from tqdm import tqdm 17 | from generate_utils import generate_paragraph 18 | 19 | def clear_dirs(gen_dir, tgt_dir): 20 | for f in glob.glob("{}/*".format(tgt_dir)): 21 | os.remove(f) 22 | for f in glob.glob("{}/*".format(gen_dir)): 23 | os.remove(f) 24 | os.makedirs(tgt_dir, exist_ok=True) 25 | os.makedirs(gen_dir, exist_ok=True) 26 | 27 | def format_text(text, max_len, stop_words=[]): 28 | text = "\n".join(sent_tokenize(text)).replace("<", "<").replace(">", ">") 29 | for stop_word in stop_words: 30 | text = text.replace(" {} ".format(stop_word), " ") 31 | if max_len is not None: 32 | text = " ".join(text.split(" ")[:max_len]) 33 | return text.encode('ascii','ignore').decode("ascii","ignore") 34 | 35 | def get_average_scores(jsonfile, srcs,hyps, refs,maxlen=110, stop_words=[]): 36 | rouge_scorer = rouge.Rouge() 37 | averaged_scores = {'rouge-1': {'f': 0, 'p': 0, 'r': 0}, 38 | 'rouge-2': {'f': 0, 'p': 0, 'r': 0}, 39 | 'rouge-l': {'f': 0, 'p': 0, 'r': 0}} 40 | 41 | scores = rouge_scorer.get_scores(hyps, refs) 42 | for metric in averaged_scores.keys(): 43 | for values in scores: 44 | for sub_metric in averaged_scores[metric]: 45 | averaged_scores[metric][sub_metric] += values[metric][sub_metric] 46 | for key in averaged_scores.keys(): 47 | for sub_key in averaged_scores[key].keys(): 48 | averaged_scores[key][sub_key] /= len(hyps) 49 | for i in range(len(srcs)): 50 | jsonfile.write(json.dumps({'r1': scores[i]['rouge-1'], 'r2': scores[i]['rouge-2'], 'rl': scores[i]['rouge-l'],'hyp':hyps[i], 'ref':refs[i],'src':srcs[i]})+"\n") 51 | return averaged_scores 52 | 53 | 54 | 55 | 56 | def evaluate_doc_model(model, val_loader, text_encoder, device, beam, gen_len, k, p, decoding_strategy, save_file, gen_dir="gen", tgt_dir="tgt", max_len=110, stop_words=[], args=None): 57 | data = {"src": [], "gen": [], "tgt": []} 58 | srcs, hyps, refs = [], [], [] 59 | model.eval() 60 | for batchargs in tqdm(val_loader): 61 | with torch.no_grad(): 62 | # Generating outputs for evaluation 63 | src_strs, tgt_strs, gen_strs = generate_paragraph(model, batchargs, text_encoder, device, beam, gen_len, k, p, decoding_strategy, min_len=args.min_len) 64 | data["src"].extend(src_strs) 65 | data["gen"].extend(gen_strs) 66 | data["tgt"].extend(tgt_strs) 67 | 68 | jsf = open(save_file+".output.json","w") 69 | for i in range(min(len(data['src']),50)): 70 | print("*" * 50) 71 | try: 72 | print("Source: {}".format(data['src'][i])) 73 | print('Hypothesis: {}'.format(data['gen'][i])) 74 | print("Reference: {}".format(data['tgt'][i])) 75 | except: 76 | pass 77 | 78 | with open(save_file, "w") as f: 79 | json.dump( 80 | #get_rouge_scores(gen_dir, tgt_dir), 81 | get_average_scores(jsf,data['src'],data['gen'],data['tgt'],max_len,stop_words), 82 | f, 83 | indent=4, 84 | sort_keys=True 85 | ) 86 | 87 | 88 | 89 | -------------------------------------------------------------------------------- 
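For reference, a minimal standalone sketch of the corpus-level ROUGE averaging performed by `get_average_scores` in eval_utils.py above; the hypothesis/reference strings are invented for illustration, and the only assumed dependency is the `rouge` package already listed in requirements.txt.

```python
# Sketch of the averaging in eval_utils.get_average_scores (illustrative strings only)
import rouge

hyps = ["the knight finds the hidden castle .", "the dragon is defeated at dawn ."]
refs = ["a knight discovers a hidden castle .", "at dawn the dragon was defeated ."]

scores = rouge.Rouge().get_scores(hyps, refs)  # one score dict per (hyp, ref) pair
avg = {m: {k: 0.0 for k in ("f", "p", "r")} for m in ("rouge-1", "rouge-2", "rouge-l")}
for s in scores:
    for m in avg:
        for k in avg[m]:
            avg[m][k] += s[m][k] / len(hyps)   # running average over all scored paragraphs
print(avg["rouge-l"])
```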
/src/model/generate_stories.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import numpy as np 5 | import rouge 6 | import torch 7 | from torch import nn 8 | from tqdm import tqdm 9 | 10 | from eval_utils import format_text 11 | from data_loader import get_paragraph_input_loader, get_fullstory_loader 12 | from model import GPT2BaseModel, PlotMachinesModel 13 | from generate_utils import toks_to_str 14 | from parallel import DataParallelModel, DataParallelCriterion 15 | from transformers import * 16 | 17 | def tfmclassifier(textlines, model, tokenizer, gen_len): 18 | '''Create encoding of the previous paragraph (textlines) using the model and tokenizer''' 19 | clf = [] 20 | nb = len(textlines) 21 | #if nb < 8: 22 | wds = torch.zeros(nb, gen_len, dtype=torch.long).cuda() 23 | mask = torch.zeros(nb, gen_len, dtype=torch.long).cuda() 24 | for j in range(nb): 25 | 26 | temp = torch.tensor(tokenizer.encode(textlines[j], add_special_tokens=False)[:gen_len]) 27 | wds[j,:len(temp)] = temp.cuda() 28 | mask[j,:len(temp)] = torch.ones(len(temp), dtype=torch.long).cuda() 29 | model.eval() 30 | outputs = model(wds) 31 | total = (mask.unsqueeze(2).type_as(outputs[0]) * outputs[0]).sum(dim=1) / mask.type_as(outputs[0]).sum(dim=1).unsqueeze(1) 32 | return total 33 | 34 | '''Generate a single paragraph''' 35 | def generate_paragraph(model, args, text_encoder, device, beam, gen_len, k, p, decoding_strategy, ids, tagnum, min_len=None, returnnewmem=False): 36 | src_strs, tgt_strs, gen_strs, genraw, gentok = [], [], [], [], [] 37 | n_gpu = torch.cuda.device_count() 38 | 39 | outputs = model(*args, text_encoder=text_encoder, device=device, beam=beam, gen_len=gen_len, k=k, p=p, decoding_strategy=decoding_strategy, generate=True, min_len=min_len) 40 | if n_gpu == 1: 41 | outputs = [outputs] 42 | i = 0 43 | 44 | seenout = [] 45 | if len(outputs[0]) > 3: 46 | for generated_toks, input_toks, target_toks, seenuni in outputs: ##outputs[0][i],outputs[1][i],outputs[2][i] 47 | for idx in range(generated_toks.size(0)): 48 | gentok.append(generated_toks[idx].view(1,-1)) 49 | seenout.append(seenuni[idx]) 50 | #print(toks_to_str(input_toks[idx], text_encoder, is_input=True)) 51 | gen_str = toks_to_str(generated_toks[idx], text_encoder).replace("\n", " ") 52 | genraw.append(gen_str) 53 | gen_strs.append(str(ids[2][i].item()) + "\t"+str(ids[0][i])+"\t"+ str(ids[1][i]) +"\t"+str(tagnum)+"\t"+gen_str) 54 | i+=1 55 | return gen_strs, genraw, torch.cat(gentok, dim = 0), torch.stack(seenout, dim=0) 56 | 57 | else: 58 | for generated_toks, input_toks, target_toks in outputs: ##outputs[0][i],outputs[1][i],outputs[2][i] 59 | for idx in range(generated_toks.size(0)): 60 | gentok.append(generated_toks[idx].view(1,-1)) 61 | #print(toks_to_str(input_toks[idx], text_encoder, is_input=True)) 62 | gen_str = toks_to_str(generated_toks[idx], text_encoder).replace("\n", " ") 63 | genraw.append(gen_str) 64 | gen_strs.append(str(ids[2][i].item()) + "\t"+str(ids[0][i])+"\t"+ str(ids[1][i]) +"\t"+str(tagnum)+"\t"+gen_str) 65 | i+=1 66 | return gen_strs, genraw, torch.cat(gentok, dim = 0) 67 | 68 | 69 | '''Generate full stories''' 70 | def generatedocs(model, gptmodel, gpttok, val_loader, text_encoder, device, beam, gen_len, k, p, decoding_strategy, save_file, gen_dir="gen", tgt_dir="tgt", max_len=110, stop_words=[], args=None, tags=['_i_','_b_','_b_','_b_','_c_'], dim=768, localfile=None, save_dir=None): 71 | def dump_to_file(jsf, data): 72 | for i in 
range(len(data)): 73 | try: 74 | jsf.write(data[i] + "\n") 75 | except: 76 | jsf.write('error on line ' + str(i) + "\n") 77 | pass 78 | 79 | data = {'gen':[]} 80 | srcs, hyps, refs = [], [], [] 81 | model.eval() 82 | gptmodel.eval() 83 | iter = 0 84 | 85 | try: 86 | if os._exists(save_file): 87 | os.remove(save_file) 88 | except: 89 | print("Error while deleting file ", save_file) 90 | jsf = open(localfile, "w") 91 | 92 | for pad_seq, mask_seq, docids in tqdm(val_loader): 93 | with torch.no_grad(): 94 | # Generating outputs for evaluation 95 | prev= ['NA']*pad_seq.size(0) 96 | 97 | kwsize = args.n_ctx-2 98 | mem = torch.torch.empty(pad_seq.size(0), kwsize + args.memstatesize, args.n_embd).normal_(std=0.02) 99 | mmask = torch.zeros(mask_seq.size(0), kwsize + args.memstatesize)#.long() 100 | mmask[:, :kwsize] = mask_seq[:, 1:args.n_ctx-1] 101 | mmask[:, -args.memstatesize:] = torch.ones(mmask.size(0), args.memstatesize)#.long() 102 | 103 | ph = torch.zeros(pad_seq.size(0), 10, 1, dim)#.long() 104 | pmask = torch.zeros(pad_seq.size(0), 10, 1)#.long() 105 | seenunigrams = torch.ones(pad_seq.size(0), len(text_encoder)) #[{} for _ in range(pad_seq.size(0))] 106 | idces = torch.arange(pad_seq.size(0)) 107 | 108 | 109 | for tnum in range(len(tags)): 110 | tag= tags[tnum] 111 | if args.use_model =="plotmachines": 112 | if args.use_neighbor_feat: 113 | prevprc = tfmclassifier(prev, gptmodel, gpttok, gen_len) 114 | if args.use_discourse: 115 | pad_seq [:,args.n_ctx-1] = text_encoder.added_tokens_encoder[tag] #add discourse marker 116 | modelargs = (pad_seq, mask_seq, mem, mmask, ph, pmask, prevprc, seenunigrams, idces) 117 | gen_strs, genraw, gentok, seenunigrams = generate_paragraph(model, modelargs, text_encoder, device, beam, gen_len, k, p, decoding_strategy, docids, tnum, min_len=args.min_len) 118 | prevprc = tfmclassifier(genraw, gptmodel, gpttok,gen_len) 119 | ph[:, tnum, 0, :] = prevprc 120 | pmask[:, tnum, 0] = 1 121 | 122 | else: 123 | prevprc = None 124 | if args.use_neighbor_feat: 125 | prevprc = tfmclassifier(prev, gptmodel, gpttok, gen_len) 126 | if args.use_discourse: 127 | pad_seq[:,args.n_ctx-1] = text_encoder.added_tokens_encoder[tag] # add discourse marker 128 | modelargs = (pad_seq, mask_seq, prevprc, seenunigrams, idces) 129 | gen_strs, genraw, gentok, seenunigrams = generate_paragraph(model, modelargs, text_encoder, device, beam, gen_len, k, p, decoding_strategy, docids, tnum, min_len=args.min_len) 130 | data["gen"].extend(gen_strs) 131 | prev = genraw 132 | 133 | if iter %100 == 0: 134 | dump_to_file(jsf, data["gen"]) 135 | data = {'gen': []} 136 | iter+=1 137 | 138 | dump_to_file(jsf, data["gen"]) 139 | 140 | import shutil 141 | trial = 0 142 | while trial < 10: 143 | try: 144 | print('Copying the generated file from ' + localfile + ' to ' + save_file) 145 | shutil.move(localfile, save_file) 146 | trial = 100 147 | except Exception as e: 148 | print(e) 149 | os.makedirs(save_dir, exist_ok=True) 150 | trial += 1 151 | 152 | def init(args): 153 | 154 | random.seed(args.seed) 155 | np.random.seed(args.seed) 156 | torch.manual_seed(args.seed) 157 | torch.cuda.manual_seed_all(args.seed) 158 | 159 | def main(args): 160 | init(args) 161 | #Args setup: 162 | 163 | beam = args.beam 164 | p = args.p 165 | n_ctx = args.n_ctx 166 | gen_len = args.gen_len 167 | k = args.k 168 | decoding_strategy = args.decoding_strategy 169 | accum_iter = args.accum_iter 170 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 171 | n_gpu = torch.cuda.device_count() 172 | print("device", 
device, "n_gpu", n_gpu) 173 | data_dir = args.data_dir 174 | #Text Encoder 175 | 176 | if args.debug_mode: 177 | text_encoder = GPT2Tokenizer.from_pretrained('gpt2') 178 | else: 179 | text_encoder = GPT2Tokenizer.from_pretrained('gpt2-medium') 180 | text_encoder.add_special_tokens({'bos_token':'_start_', 181 | 'cls_token':'_classify_', 182 | 'eos_token':'_end_', 183 | 'additional_special_tokens': ['_kw_','_endkw_', '_t_', '_i_', '_b_', '_c_'] 184 | }) 185 | 186 | 187 | vocab = len(text_encoder) 188 | 189 | datafile = os.path.join(data_dir, "test_encoded.csv") if args.testset else os.path.join(data_dir, "val_encoded.csv") 190 | print("Loading dataset...") 191 | val_loader = get_fullstory_loader(datafile, args.n_batch, text_encoder, num_workers=0, shuffle=False, gen_len=gen_len, n_ctx=n_ctx, include_kw = not args.exclude_kw, max_size=args.max_ex) 192 | print(len(val_loader)) 193 | 194 | if args.use_model == "plotmachines": 195 | doc_model = PlotMachinesModel(args, vocab=vocab, n_ctx=n_ctx, gen_len=gen_len, lastidx=text_encoder.eos_token_id, includeprev=args.use_neighbor_feat) 196 | else: 197 | doc_model = GPT2BaseModel(args, vocab=vocab, n_ctx=n_ctx, gen_len=gen_len, lastidx=text_encoder.eos_token_id, includeprev=args.use_neighbor_feat) 198 | 199 | doc_model.to(device) 200 | if n_gpu > 1: 201 | doc_model = DataParallelModel(doc_model) 202 | 203 | 204 | if args.debug_mode: 205 | gptclf = GPT2Model.from_pretrained('gpt2') 206 | gptclf.eval() 207 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 208 | gptclf.to(device) 209 | #gpttok = gptTokenizer.from_pretrained('openai-gpt') 210 | gpttok = GPT2Tokenizer.from_pretrained('gpt2') 211 | 212 | else: 213 | gptclf = GPT2Model.from_pretrained('gpt2-medium') 214 | gptclf.eval() 215 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 216 | gptclf.to(device) 217 | #gpttok = gptTokenizer.from_pretrained('openai-gpt') 218 | gpttok = GPT2Tokenizer.from_pretrained('gpt2-medium') 219 | 220 | prevloss = [] 221 | upd = [] 222 | start_iter, running_loss = 1,0 223 | load_dir = args.load_dir 224 | bestcheck = os.path.join(load_dir,"checkpoint_best.pt") 225 | checkpoint = torch.load(bestcheck, map_location='cpu') 226 | state_dict = checkpoint["state_dict"] 227 | if n_gpu ==1: 228 | if state_dict.get('module.pos_emb_mask') is None and doc_model.state_dict().get('pos_emb_mask') is not None: 229 | state_dict['module.pos_emb_mask'] = doc_model.state_dict().get('pos_emb_mask') 230 | for k in list(state_dict.keys()): 231 | state_dict[k[7:]] = state_dict[k] 232 | del state_dict[k] 233 | else: 234 | if state_dict.get('module.pos_emb_mask') is None and doc_model.state_dict().get('module.pos_emb_mask') is not None: 235 | state_dict['module.pos_emb_mask'] = doc_model.state_dict().get('module.pos_emb_mask') 236 | doc_model.load_state_dict(state_dict) 237 | 238 | print("Parallelized") 239 | tagset = ['_i_'] + args.bodynum* ['_b_'] + ['_c_'] 240 | vort = 'test' if args.testset else 'val' 241 | generatedocs(doc_model, gptclf, gpttok, val_loader, text_encoder, device, beam, gen_len, k, p, args.decoding_strategy, os.path.join(args.save_dir,vort+'.gens.tsv'), 242 | 'gen','tgt', gen_len, [], args, tags = tagset, dim=args.n_embd, save_dir=args.save_dir, localfile=os.path.join('/tmp',vort+'.gens.tsv')) 243 | 244 | print('done decoding....') 245 | 246 | 247 | if __name__ == "__main__": 248 | parser = argparse.ArgumentParser() 249 | parser.add_argument('--desc', type=str, help="Description") 250 | parser.add_argument('--seed', type=int, default=42) 251 | 
parser.add_argument('--output_hidden_states', action='store_true') 252 | parser.add_argument('--output_attentions', action='store_true') 253 | parser.add_argument('--output_past', action='store_true') 254 | parser.add_argument('--num_epochs', type=int, default=10) 255 | parser.add_argument('--n_batch', type=int, default=2) 256 | parser.add_argument('--max_grad_norm', type=int, default=1) 257 | parser.add_argument('--lr', type=float, default=6.25e-5) 258 | parser.add_argument('--lr_warmup', type=float, default=0.002) 259 | parser.add_argument('--n_embd', type=int, default=1024) 260 | parser.add_argument('--n_head', type=int, default=12) 261 | parser.add_argument('--n_layer', type=int, default=12) 262 | parser.add_argument('--embd_pdrop', type=float, default=0.1) 263 | parser.add_argument('--attn_pdrop', type=float, default=0.1) 264 | parser.add_argument('--resid_pdrop', type=float, default=0.1) 265 | parser.add_argument('--clf_pdrop', type=float, default=0.1) 266 | parser.add_argument('--l2', type=float, default=0.01) 267 | parser.add_argument('--vector_l2', action='store_true') 268 | parser.add_argument('--opt', type=str, default='adam') 269 | parser.add_argument('--afn', type=str, default='gelu') 270 | parser.add_argument('--lr_schedule', type=str, default='warmup_linear') 271 | parser.add_argument('--n_transfer', type=int, default=12) 272 | parser.add_argument('--lm_coef', type=float, default=0.5) 273 | parser.add_argument('--b1', type=float, default=0.9) 274 | parser.add_argument('--b2', type=float, default=0.999) 275 | parser.add_argument('--e', type=float, default=1e-8) 276 | parser.add_argument('--load_dir', type=str, default="output", help='directory containing checkpoint_best.pt') 277 | parser.add_argument('--save_dir', type=str, default="output", help='directory to save generations to') 278 | parser.add_argument('--data_dir', type=str, default='data', help='directory containing dev/test inputs') 279 | parser.add_argument('--max_ex', type=int, default=None, help='maximum number of inputs to use, or None for using whole dataset') 280 | parser.add_argument('--beam', type=int, default=0) 281 | parser.add_argument('--k', type=int, default=0) 282 | parser.add_argument('--p', type=int, default=0) 283 | parser.add_argument('--decoding_strategy', type=int, default=0) 284 | parser.add_argument('--accum_iter', type=int, default=2) 285 | parser.add_argument('--gen_len', type=int, default=922) 286 | parser.add_argument('--n_ctx', type=int, default=102) 287 | parser.add_argument('--min_len', type=int, default=100) 288 | parser.add_argument('--repeattheta', type=float, default=1.5) 289 | parser.add_argument('--show_progress', action='store_true') 290 | parser.add_argument('--exclude_kw', action='store_true') 291 | parser.add_argument('--testset', action='store_true', help='if true will generate from test set, if false will generate from dev set') 292 | parser.add_argument('--memstatesize', type=int, default=100) 293 | parser.add_argument('--use_model', type=str, choices=['base', 'plotmachines']) 294 | parser.add_argument('--use_neighbor_feat', action='store_true') 295 | parser.add_argument('--use_discourse', action='store_true') 296 | parser.add_argument('--debug_mode', action='store_true') 297 | #--bodynum determines format of discourse template for output 298 | #(for five paragraph format, use 3, because intro and conclusion will be added automatically) 299 | parser.add_argument('--bodynum', type=int, default=3, help='The number of body pargraphs to use in generation') 300 | 301 | 302 | args = 
parser.parse_args() 303 | print(torch.__version__) 304 | print(args) 305 | main(args) 306 | -------------------------------------------------------------------------------- /src/model/generate_utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | 5 | import numpy as np 6 | import torch 7 | from tqdm import tqdm 8 | from pytorch_transformers import * 9 | 10 | 11 | def generate_paragraph(model, args, text_encoder, device, beam, gen_len, k, p, decoding_strategy, min_len=None): 12 | src_strs, tgt_strs, gen_strs = [], [], [] 13 | mask = args[1] 14 | n_gpu = torch.cuda.device_count() 15 | outputs = model(*args, text_encoder=text_encoder, device=device, beam=beam, gen_len=gen_len, k=k, p=p, decoding_strategy=decoding_strategy, generate=True, min_len=min_len) 16 | #print(len(outputs[0])) 17 | # for i in range(len(outputs[0])): 18 | if n_gpu == 1: 19 | outputs = [outputs] 20 | for generated_toks, input_toks, target_toks, _ in outputs: ##outputs[0][i],outputs[1][i],outputs[2][i] 21 | for idx in range(generated_toks.size(0)): 22 | src_str = toks_to_str(input_toks[idx], text_encoder, is_input=True, mask=mask[idx], ctx=102) 23 | src_strs.append(src_str) 24 | tgt_str = toks_to_str(target_toks[idx], text_encoder) 25 | tgt_strs.append(tgt_str) 26 | gen_str = toks_to_str(generated_toks[idx], text_encoder) 27 | gen_strs.append(gen_str) 28 | return src_strs, tgt_strs, gen_strs 29 | 30 | 31 | 32 | def toks_to_str(toks, text_encoder, is_input=False, mask=None, ctx=102): 33 | str_rep = [] 34 | end_tok = text_encoder.convert_tokens_to_ids('_endkw_') if is_input else text_encoder.convert_tokens_to_ids('_end_') 35 | 36 | for token in toks: 37 | if token.item() == end_tok : #or token.item() == 0:# or x.item() == end_idx: 38 | break 39 | str_rep.append( text_encoder.convert_ids_to_tokens(token.item())) 40 | 41 | if is_input: 42 | str_rep.append(text_encoder.convert_ids_to_tokens(toks[ctx-1].item())) 43 | str_rep = text_encoder.convert_tokens_to_string(str_rep) 44 | 45 | # This makes sure rouge scorers doesn't complain about no sentences 46 | if not str_rep: 47 | str_rep = "unk." 48 | elif "." not in str_rep: 49 | str_rep += "." 
50 | 51 | return str_rep 52 | -------------------------------------------------------------------------------- /src/model/logger.py: -------------------------------------------------------------------------------- 1 | # Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514 2 | #import tensorflow as tf 3 | import os 4 | import time 5 | class Logger(): 6 | 7 | def __init__(self, log_dir): 8 | """Create a summary writer logging to log_dir.""" 9 | #self.writer = tf.summary.FileWriter(log_dir) 10 | self.rlog = os.path.join(log_dir,"output_rouge.tsv") 11 | fio = open(self.rlog, "w") 12 | fio.write(str(time.ctime())+"\tsummary start\n") 13 | 14 | self.log = os.path.join(log_dir,"output_losses.tsv") 15 | fio = open(self.log, "w") 16 | fio.write(str(time.ctime())+"\tupdates\tData\tLce\n") 17 | 18 | def rouge_summary(self, tag, value, step): 19 | """Log a scalar variable.""" 20 | fio = open(self.rlog, "a") 21 | fio.write(str(time.ctime())+"\t"+str(step)+"\t"+str(tag)+"\t"+str(value)+"\n") 22 | 23 | def scalar_summary(self,tag,num, denom, step): 24 | """Log a scalar variable.""" 25 | fio = open(self.log, "a") 26 | value = num/denom 27 | fio.write(str(time.ctime())+"\t"+str(step)+"\t"+str(tag)+"\t"+str(value)+"\t\t\t\n") -------------------------------------------------------------------------------- /src/model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | class LMLoss(nn.Module): 6 | ''' Classic LM Loss ''' 7 | def __init__(self, lm_criterion, opt=None): 8 | super(LMLoss, self).__init__() 9 | self.lm_criterion = lm_criterion 10 | self.opt = opt 11 | 12 | def forward(self, lm_logits, X, mask): 13 | x_shifted = X[:, 1:, 0].contiguous().view(-1) 14 | mask = mask[:, 1:].view(-1, mask.size(-1) - 1).float() 15 | lm_logits = lm_logits[:, :-1, :].contiguous().view(-1, lm_logits.size(-1)) 16 | lm_losses = self.lm_criterion(lm_logits, x_shifted) 17 | lm_losses = lm_losses.view(X.size(0), X.size(1) - 1) 18 | lm_losses = lm_losses * mask 19 | lm_losses = lm_losses.sum(1) / torch.sum(mask, 1) 20 | return lm_losses 21 | 22 | 23 | 24 | class ParagraphLoss(nn.Module): 25 | ''' LM Loss but ignoring the first n_ctx tokens ''' 26 | def __init__(self, lm_criterion, opt=None, n_ctx=102, gen_len=401): 27 | super(ParagraphLoss, self).__init__() 28 | self.lm_criterion = lm_criterion 29 | self.opt = opt 30 | self.ctx = n_ctx 31 | self.tgt = gen_len 32 | 33 | def forward(self, lm_logits, X, mask): 34 | ## LM Loss, but ignoring the ctx tokens 35 | x_shifted = X[:, self.ctx:].contiguous().view(-1) #[102:] (text only) 36 | mask = mask[:, self.ctx:].view(-1, mask.size(-1) - (self.ctx)).float() #[102:] 37 | lm_logits = lm_logits[:, self.ctx-1:-1, :].contiguous().view(-1, lm_logits.size(-1)) #shifted over predictions 38 | lm_losses = self.lm_criterion(lm_logits, x_shifted) 39 | lm_losses = lm_losses.view(X.size(0), -1) 40 | lm_losses = lm_losses * mask 41 | lm_losses = lm_losses.sum(1) / torch.sum(mask, 1) 42 | return lm_losses 43 | -------------------------------------------------------------------------------- /src/model/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import copy 6 | 7 | from transformers.modeling_gpt2 import * 8 | 9 | class GPT2NeighborModel(GPT2Model): 10 | '''GPT2 model but with slightly altered foward function 
to include previous paragraph encoding as an additional input''' 11 | def __init__(self, config): 12 | super(GPT2NeighborModel, self).__init__(config) 13 | 14 | def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, labels=None, includeprev=False, x_prev=None): 15 | if includeprev: 16 | #if need to add previous paragraph 17 | input_shape = input_ids.size() 18 | input_ids = input_ids.view(-1, input_shape[-1]) 19 | if token_type_ids is not None: 20 | token_type_ids = token_type_ids.view(-1, input_shape[-1]) 21 | if position_ids is not None: 22 | position_ids = position_ids.view(-1, input_shape[-1]) 23 | 24 | if past is None: 25 | past_length = 0 26 | past = [None] * len(self.h) 27 | else: 28 | past_length = past[0][0].size(-2) 29 | if position_ids is None: 30 | position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device) 31 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 32 | 33 | # Attention mask. 34 | if attention_mask is not None: 35 | attention_mask = attention_mask.view(-1, input_shape[-1]) 36 | attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 37 | attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 38 | attention_mask = (1.0 - attention_mask) * -10000.0 39 | 40 | if head_mask is not None: 41 | if head_mask.dim() == 1: 42 | head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) 43 | head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1) 44 | elif head_mask.dim() == 2: 45 | head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer 46 | head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility 47 | else: 48 | head_mask = [None] * self.config.n_layer 49 | 50 | inputs_embeds = self.wte(input_ids) 51 | position_embeds = self.wpe(position_ids) 52 | if token_type_ids is not None: 53 | token_type_embeds = self.wte(token_type_ids) 54 | else: 55 | token_type_embeds = 0 56 | 57 | ####### THIS IS THE PART that needs to be changed from inherited function: 58 | x_prev = x_prev.unsqueeze(1) #[b,1,d] + [d] = [b,1,d] 59 | inputs_embeds = torch.cat([x_prev, inputs_embeds [:,1:,:]], dim = 1) #x_prev: [b, 1, d], h : [b, s, d]-->[B, s+1, D] 60 | ########### END HERE ######################### 61 | 62 | hidden_states = inputs_embeds + position_embeds + token_type_embeds 63 | hidden_states = self.drop(hidden_states) 64 | 65 | output_shape = input_shape + (hidden_states.size(-1),) 66 | 67 | presents = () 68 | all_attentions = [] 69 | all_hidden_states = () 70 | for i, (block, layer_past) in enumerate(zip(self.h, past)): 71 | if self.output_hidden_states: 72 | all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),) 73 | 74 | outputs = block(hidden_states, 75 | layer_past=layer_past, 76 | attention_mask=attention_mask, 77 | head_mask=head_mask[i]) 78 | 79 | hidden_states, present = outputs[:2] 80 | if self.output_past: 81 | presents = presents + (present,) 82 | 83 | if self.output_attentions: 84 | all_attentions.append(outputs[2]) 85 | 86 | hidden_states = self.ln_f(hidden_states) 87 | 88 | hidden_states = hidden_states.view(*output_shape) 89 | # Add last hidden state 90 | if self.output_hidden_states: 91 | all_hidden_states = all_hidden_states + (hidden_states,) 92 | 93 | outputs = (hidden_states,) 94 | if self.output_past: 95 | outputs = outputs + (presents,) 96 | if 
self.output_hidden_states: 97 | outputs = outputs + (all_hidden_states,) 98 | if self.output_attentions: 99 | # let the number of heads free (-1) so we can extract attention even after head pruning 100 | attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:] 101 | all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions) 102 | outputs = outputs + (all_attentions,) 103 | 104 | return outputs 105 | 106 | else: 107 | return super().forward(input_ids, past=past, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask) 108 | 109 | 110 | class GPT2NeighborLMHeadModel(GPT2LMHeadModel): 111 | '''GPT2 LM Head model but with the GPT2NeighborModel Class''' 112 | def __init__(self, config): 113 | super(GPT2NeighborLMHeadModel, self).__init__(config) 114 | self.transformer = GPT2NeighborModel(config) 115 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 116 | self.init_weights() 117 | self.tie_weights() 118 | 119 | def tie_weights(self): 120 | """ Make sure we are sharing the input and output embeddings. 121 | Export to TorchScript can't handle parameter sharing so we are cloning them instead. 122 | """ 123 | self._tie_or_clone_weights(self.lm_head, 124 | self.transformer.wte) 125 | 126 | def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, labels=None, includeprev=False, x_prev=None): 127 | transformer_outputs = self.transformer(input_ids, 128 | past=past, 129 | attention_mask=attention_mask, 130 | token_type_ids=token_type_ids, 131 | position_ids=position_ids, 132 | head_mask=head_mask, 133 | includeprev=includeprev, 134 | x_prev= x_prev) 135 | hidden_states = transformer_outputs[0] 136 | 137 | lm_logits = self.lm_head(hidden_states) 138 | 139 | outputs = (lm_logits,) + transformer_outputs[1:] 140 | if labels is not None: 141 | # Shift so that tokens < n predict n 142 | shift_logits = lm_logits[..., :-1, :].contiguous() 143 | shift_labels = labels[..., 1:].contiguous() 144 | # Flatten the tokens 145 | loss_fct = CrossEntropyLoss(ignore_index=-1) 146 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), 147 | shift_labels.view(-1)) 148 | outputs = (loss,) + outputs 149 | 150 | outputs = (hidden_states,) + outputs 151 | return outputs # (loss), lm_logits, presents, (all hidden_states), (attentions) 152 | 153 | 154 | ''' Base GPT2 LM Head model 155 | Uses slightly altered GPT2NeighborLMHeadModel and my decoding methods that have nucleus sampling 156 | ''' 157 | class GPT2BaseModel(nn.Module): 158 | ''' base GPT2 model (no memory): 159 | init params: 160 | cfg: command line argument settings 161 | vocab: total vocab size including special tokens 162 | n_ctx: total context including delimiters 163 | gen_len: total generation length including end tokens 164 | includeprev: use the neighboring (previous) paragraph in input 165 | lastidx: eos index in tokenizer 166 | use_offline_gpt2: true if we've already downloaded from huggingface to server 167 | ''' 168 | def __init__(self, cfg, vocab=40990, n_ctx=102, gen_len=401, return_probs=False, includeprev=False, lastidx=0,use_offline_gpt2=False ): 169 | ###ctx: [/ kw<=100 _ ] gen<=400 == 503 170 | #LM mask:[0x101][1x401] 0 - padded 171 | super(GPT2BaseModel,self).__init__() 172 | 173 | if use_offline_gpt2: 174 | self.lmmodel = GPT2NeighborLMHeadModel.from_pretrained('./gpt2model', n_ctx=n_ctx+gen_len, n_positions=n_ctx+gen_len) 175 | elif cfg.debug_mode: 176 | 
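# in debug mode, fall back to the small 'gpt2' checkpoint instead of 'gpt2-medium' so toy debugging runs are lighter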
self.lmmodel = GPT2NeighborLMHeadModel.from_pretrained('gpt2', n_ctx=n_ctx + gen_len, 177 | n_positions=n_ctx + gen_len) 178 | else: 179 | self.lmmodel = GPT2NeighborLMHeadModel.from_pretrained('gpt2-medium', n_ctx=n_ctx + gen_len, 180 | n_positions=n_ctx + gen_len) 181 | self.lmmodel.resize_token_embeddings(vocab) 182 | self.includeprev = includeprev 183 | self.n_ctx = n_ctx 184 | self.gen_len = gen_len 185 | self.epsilon = 1e-8 186 | self.lastidx = lastidx 187 | self.cfg = cfg 188 | self.repeatfactor = self.cfg.repeattheta 189 | pos_emb_mask = torch.zeros(1, 1, vocab) #+n_ctx+gen_len) 190 | self.register_buffer('pos_emb_mask', pos_emb_mask) 191 | 192 | 193 | def _forward(self, x,mask_output,prev, log=False, return_probs=False, returnlast=False, returnnewmem=False, past=None, returnpasts=False): 194 | lmout = self.lmmodel(x, past=past, attention_mask=mask_output, includeprev=self.includeprev, x_prev=prev) 195 | h_dec = lmout[0] 196 | lm_logits = lmout[1] 197 | presents = lmout[2] 198 | if returnpasts: 199 | return lm_logits,presents 200 | if returnlast: 201 | lasttoken = torch.where(x[:,:,0] == self.lastidx, torch.ones_like(x[:,:,0]), torch.zeros_like(x[:,:,0])).unsqueeze(-1) #[B,503,1] 202 | lasttoken = lasttoken.type_as(h_dec)*h_dec 203 | hdecmasked = lasttoken.sum(dim=1) #[B,768] 204 | return lm_logits, hdecmasked 205 | return lm_logits 206 | 207 | ''' 208 | Forward function: 209 | Either performs decoding, training step- default is to just do training step 210 | @param: 211 | *args: tuple of model inputs 212 | generate: if True, then generate new tokens using decoding method 213 | 214 | text_encoder: tokenizer 215 | device: cpu, cuda 216 | beam, decoding_strategy, log: old params for compatability that are not in use 217 | k: if using top k sampling 218 | p: if using nucleus sampling 219 | gen_len: maximum length for decoding 220 | min_len: minimum length for decoding 221 | returnlast: training parameter - return the last token hidden state (this is not in use in the latest codebase) 222 | ''' 223 | def forward(self, *args, text_encoder=None, device=None, beam=0, gen_len=401, k=0, p=0, decoding_strategy=0, log=False, generate=False, min_len=None, returnlast=False, returnnewmem=False): 224 | if generate: 225 | return self.generate(*args,text_encoder=text_encoder, device=device, beam=beam, gen_len=gen_len, k=k, p=p, decoding_strategy=decoding_strategy, min_len=min_len) 226 | return self._forward(*args, log=log, returnlast=returnlast) 227 | 228 | def sample(self, *args, classify_idx=0, text_encoder=None, gen_len=401, k=0, p=0, decoding_strategy=0, min_len=None, eos_idx=None): 229 | XMB, mask, prev, seen_unigrams, idxes = args 230 | pasts = None 231 | for _ in range(gen_len): 232 | lm_logits = self._forward(XMB, mask[:, :XMB.size(-1)], prev) 233 | pem = copy.deepcopy(self.pos_emb_mask) 234 | if _ < min_len: 235 | pem[:,:,eos_idx] = -1e12 #don't let it stop decoding early 236 | 237 | # penalize seen unigrams 238 | lm_logits[:,-1, :] = lm_logits[:,-1,:] / seen_unigrams 239 | lm_probs = F.softmax((lm_logits + pem), dim=-1) 240 | dist = lm_probs[:, -1, :].squeeze(1) 241 | if k == 0 and p == 0: 242 | # Pure Sampling 243 | next_idx = torch.multinomial(lm_probs[:, -1, :], 1) 244 | else: 245 | if p ==0: 246 | # Top K Sampling 247 | values, indices = dist.topk(k) 248 | next_idx = indices.gather(-1, torch.multinomial(values, 1)) 249 | else: 250 | # Nucleus Sampling 251 | indices = torch.argsort(dist,dim=1,descending=True) 252 | values = dist.gather(-1,indices) 253 | probsum = 
torch.cumsum(values,dim=1) 254 | include = ~ ((probsum.gt(p*.01)) & ((probsum-values).gt(p*.01))) 255 | newdist = torch.where(include, values, torch.zeros_like(values) + 1e-10) 256 | next_idx = indices.gather(-1, torch.multinomial(newdist, 1)) 257 | for i in range(XMB.size(0)): 258 | seen_unigrams[i, next_idx[i]] = self.repeatfactor #add a new seen unigram 259 | XMB = self.append_batch(XMB, next_idx) 260 | return XMB[:, -gen_len:], seen_unigrams 261 | 262 | def append_batch(self, X, next_idx): 263 | return torch.cat((X, next_idx), 1) 264 | 265 | def generate(self, *args, text_encoder=None, device=None, beam=0, gen_len=401, k=0, p=0, decoding_strategy=0, min_len=None): 266 | ##print(len(args)) 267 | if len(args) == 5: 268 | pad_output, mask, prev, seen_trigrams, idxes = args 269 | else: 270 | pad_output, mask, prev = args 271 | seen_trigrams = torch.ones(pad_output.size(0), len(text_encoder)).to(pad_output.device) 272 | idxes = None 273 | classify_idx = None # don't use this in the code anymore 274 | eos_idx = text_encoder.eos_token_id 275 | input_toks = pad_output[:, :self.n_ctx] # includes delimiter 276 | target_toks = pad_output[:, -gen_len:] 277 | mask_pad = torch.ones(mask.size()).type_as(mask) 278 | mask_pad[:, :self.n_ctx] = mask[:, :self.n_ctx] 279 | mask = mask_pad 280 | pad_output = pad_output.to(device) 281 | XMB = pad_output[:, :self.n_ctx] 282 | if beam == 0: 283 | generated_toks, seen = self.sample(XMB, mask, prev, seen_trigrams, idxes, classify_idx=classify_idx, text_encoder=text_encoder, gen_len=gen_len, k=k, p=p, decoding_strategy=decoding_strategy, min_len=min_len, eos_idx=eos_idx) 284 | else: 285 | raise NotImplementedError 286 | output = generated_toks.type_as(XMB), input_toks.type_as(XMB), target_toks.type_as(XMB), seen 287 | return output 288 | 289 | ############################################# 290 | # PlotMachines classes below: 291 | ############################################# 292 | 293 | class MemoryAttention(nn.Module): 294 | '''An Attention Block for attending over the memory slots with word tokens as queries''' 295 | def __init__(self, nx, n_ctx, config, scale=False): 296 | super(MemoryAttention, self).__init__() 297 | self.output_attentions = config.output_attentions 298 | 299 | n_state = nx # in Attention: n_state=768 (nx=n_embd) 300 | # [switch nx => n_state from Block to Attention to keep identical to TF implem] 301 | assert n_state % config.n_head == 0 302 | self.n_head = config.n_head 303 | self.split_size = n_state 304 | self.scale = scale 305 | self.c_attn = Conv1D(n_state, nx) 306 | self.c_memory = Conv1D(n_state * 2, nx) 307 | self.c_proj = Conv1D(n_state, nx) 308 | self.attn_dropout = nn.Dropout(config.attn_pdrop) 309 | self.resid_dropout = nn.Dropout(config.resid_pdrop) 310 | self.pruned_heads = set() 311 | 312 | def prune_heads(self, heads): 313 | if len(heads) == 0: 314 | return 315 | mask = torch.ones(self.n_head, self.split_size // self.n_head) 316 | heads = set(heads) - self.pruned_heads # Convert to set and emove already pruned heads 317 | for head in heads: 318 | # Compute how many pruned heads are before the head and move the index accordingly 319 | head = head - sum(1 if h < head else 0 for h in self.pruned_heads) 320 | mask[head] = 0 321 | mask = mask.view(-1).contiguous().eq(1) 322 | index = torch.arange(len(mask))[mask].long() 323 | index_attn = torch.cat([index, index + self.split_size, index + (2*self.split_size)]) 324 | 325 | # Prune conv1d layers 326 | self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) 327 | 
self.c_memory = prune_conv1d_layer(self.c_memory, index_attn, dim=1) 328 | self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) 329 | 330 | # Update hyper params 331 | self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads)) 332 | self.n_head = self.n_head - len(heads) 333 | self.pruned_heads = self.pruned_heads.union(heads) 334 | 335 | def _attn(self, q, k, v, attention_mask=None, head_mask=None, M=None, Mmask=None): 336 | w = torch.matmul(q, k) ## w = b,h,s,p 337 | if self.scale: 338 | w = w / math.sqrt(v.size(-1)) 339 | nd, ns = w.size(-2), w.size(-1) 340 | 341 | if M is not None: # if there is indeed memory being used (should always be true) 342 | #There may be some memory slots might need to be ignored (if a key point list was padded) 343 | p = M.size(1) 344 | temp = Mmask.unsqueeze(1).unsqueeze(2).float() #temp = b,1,1,p 345 | attention_mask = (1.0 - temp) * -10000.0 #b,1,1,p 346 | w = w + attention_mask #b,h,s,p 347 | 348 | w = nn.Softmax(dim=-1)(w) 349 | w = self.attn_dropout(w) 350 | # Mask heads if we want to 351 | if head_mask is not None: 352 | w = w * head_mask 353 | 354 | outputs = [torch.matmul(w, v)] #b,h,s,p * b,h,p,d 355 | if self.output_attentions: 356 | outputs.append(w) 357 | return outputs 358 | 359 | def merge_heads(self, x): 360 | x = x.permute(0, 2, 1, 3).contiguous() 361 | new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),) 362 | return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states 363 | 364 | def split_heads(self, x, k=False): 365 | new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head) 366 | x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states 367 | if k: 368 | return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length) 369 | else: 370 | return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) 371 | 372 | def forward(self, x, layer_past=None, attention_mask=None, head_mask=None, M=None, Mmask=None): 373 | x = self.c_attn(x) 374 | query = x 375 | key, value = self.c_memory(M).split(self.split_size,dim=2) 376 | 377 | query = self.split_heads(query) 378 | key = self.split_heads(key, k=True) 379 | value = self.split_heads(value) 380 | if layer_past is not None: 381 | past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below 382 | key = torch.cat((past_key, key), dim=-1) 383 | value = torch.cat((past_value, value), dim=-2) 384 | present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking 385 | 386 | attn_outputs = self._attn(query, key, value, attention_mask=attention_mask, head_mask=head_mask, M=M, Mmask=Mmask) 387 | a = attn_outputs[0] 388 | 389 | a = self.merge_heads(a) 390 | a = self.c_proj(a) 391 | a = self.resid_dropout(a) 392 | 393 | outputs = [a, present] + attn_outputs[1:] 394 | return outputs # a, present, (attentions) 395 | 396 | 397 | 398 | 399 | class GPT2MemoryBlock(nn.Module): 400 | def __init__(self, n_ctx, config, scale=False): 401 | super(GPT2MemoryBlock, self).__init__() 402 | nx = config.n_embd 403 | self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon) 404 | self.attnextra = MemoryAttention(nx, n_ctx, config, scale) 405 | self.attn = Attention(nx, n_ctx, config, scale) 406 | self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon) 407 | self.mlp = MLP(4 * nx, config) 408 | 409 | def forward(self, x, layer_past=None, attention_mask=None, head_mask=None, M=None, Mmask=None): 410 | lnx = self.ln_1(x) 411 | output_attnR = self.attn(lnx, 412 | 
layer_past=layer_past, 413 | attention_mask=attention_mask, 414 | head_mask=head_mask) 415 | aR = output_attnR[0] # output_attn: a, present, (attentions) 416 | output_attnL = self.attnextra(lnx, 417 | layer_past=layer_past, 418 | head_mask=head_mask, 419 | M= F.normalize(1e-7 + M, dim=-1), 420 | Mmask=Mmask) 421 | aL = output_attnL[0] # output_attn: a, present, (attentions) 422 | 423 | a = (aL + aR) / 2.0 424 | x = x + a 425 | m = self.mlp(self.ln_2(x)) 426 | x = x + m 427 | outputs = [x] + output_attnR[1:] 428 | return outputs # x, present, (attentions) 429 | 430 | 431 | class GPT2MemModel(GPT2Model): 432 | def __init__(self, config, use_dual_att=False): 433 | super(GPT2MemModel, self).__init__(config) 434 | del self.h 435 | self.h = nn.ModuleList([GPT2MemoryBlock(config.n_ctx, config, scale=True) for _ in range(config.n_layer)]) 436 | 437 | def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, M=None, Mmask=None, includeprev=False, x_prev=None): 438 | input_shape = input_ids.size() 439 | input_ids = input_ids.view(-1, input_shape[-1]) 440 | if token_type_ids is not None: 441 | token_type_ids = token_type_ids.view(-1, input_shape[-1]) 442 | if position_ids is not None: 443 | position_ids = position_ids.view(-1, input_shape[-1]) 444 | 445 | if past is None: 446 | past_length = 0 447 | past = [None] * len(self.h) 448 | else: 449 | past_length = past[0][0].size(-2) 450 | if position_ids is None: 451 | position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device) 452 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 453 | 454 | # Attention mask. 455 | if attention_mask is not None: 456 | attention_mask = attention_mask.view(-1, input_shape[-1]) 457 | attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 458 | attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 459 | attention_mask = (1.0 - attention_mask) * -10000.0 460 | 461 | 462 | if head_mask is not None: 463 | if head_mask.dim() == 1: 464 | head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) 465 | head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1) 466 | elif head_mask.dim() == 2: 467 | head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer 468 | head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility 469 | else: 470 | head_mask = [None] * self.config.n_layer 471 | 472 | inputs_embeds = self.wte(input_ids) 473 | position_embeds = self.wpe(position_ids) 474 | if token_type_ids is not None: 475 | token_type_embeds = self.wte(token_type_ids) 476 | else: 477 | token_type_embeds = 0 478 | 479 | 480 | ####### changed from inherited function: ##### 481 | if includeprev: 482 | x_prev = x_prev.unsqueeze(1) #[b,1,d] + [d] = [b,1,d] 483 | ## asli : input_embeds is not even used, commenting it out right now. 
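## note: below, the previous-paragraph vector x_prev ([B,1,D]) replaces the embedding of the first input token, so the transformer consumes it as if it were the first token of the sequence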
484 | inputs_embeds = torch.cat([x_prev, inputs_embeds[:,1:,:]], dim = 1) #x_prev: [b, 1, d], h : [b, s, d]-->[B, s+1, D] 485 | ########### END HERE ######################### 486 | 487 | hidden_states = inputs_embeds + position_embeds + token_type_embeds 488 | hidden_states = self.drop(hidden_states) 489 | 490 | output_shape = input_shape + (hidden_states.size(-1),) 491 | 492 | presents = () 493 | all_attentions = [] 494 | all_hidden_states = () 495 | for i, (block, layer_past) in enumerate(zip(self.h, past)): 496 | if self.output_hidden_states: 497 | all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),) 498 | 499 | outputs = block(hidden_states, 500 | layer_past=None, 501 | attention_mask=attention_mask, 502 | head_mask=head_mask[i], 503 | M= M, #changed from inherited function 504 | Mmask=Mmask) #changed from inherited function 505 | 506 | hidden_states, present = outputs[:2] 507 | if self.output_past: 508 | presents = presents + (present,) 509 | 510 | if self.output_attentions: 511 | all_attentions.append(outputs[2]) 512 | 513 | hidden_states = self.ln_f(hidden_states) 514 | 515 | hidden_states = hidden_states.view(*output_shape) 516 | # Add last hidden state 517 | if self.output_hidden_states: 518 | all_hidden_states = all_hidden_states + (hidden_states,) 519 | 520 | outputs = (hidden_states,) 521 | if self.output_past: 522 | outputs = outputs + (presents,) 523 | if self.output_hidden_states: 524 | outputs = outputs + (all_hidden_states,) 525 | if self.output_attentions: 526 | # let the number of heads free (-1) so we can extract attention even after head pruning 527 | attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:] 528 | all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions) 529 | outputs = outputs + (all_attentions,) 530 | return outputs # last hidden state, (presents), (all hidden_states), (attentions) 531 | 532 | class GPT2MemLMHeadModel(GPT2LMHeadModel): 533 | 534 | def __init__(self, config): 535 | super(GPT2MemLMHeadModel, self).__init__(config) 536 | self.transformer = GPT2MemModel(config) 537 | self.init_weights() 538 | self.tie_weights() 539 | 540 | def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, 541 | labels=None, M=None, Mmask=None, includeprev=False, x_prev=None): 542 | transformer_outputs = self.transformer(input_ids, 543 | past=past, 544 | attention_mask=attention_mask, 545 | token_type_ids=token_type_ids, 546 | position_ids=position_ids, 547 | head_mask=head_mask, 548 | M = M, 549 | Mmask = Mmask, includeprev=includeprev, x_prev=x_prev) 550 | hidden_states = transformer_outputs[0] 551 | 552 | lm_logits = self.lm_head(hidden_states) 553 | 554 | outputs = (lm_logits,) + transformer_outputs[1:] 555 | if labels is not None: 556 | # Shift so that tokens < n predict n 557 | shift_logits = lm_logits[..., :-1, :].contiguous() 558 | shift_labels = labels[..., 1:].contiguous() 559 | # Flatten the tokens 560 | loss_fct = CrossEntropyLoss(ignore_index=-1) 561 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), 562 | shift_labels.view(-1)) 563 | outputs = (loss,) + outputs 564 | 565 | outputs = (hidden_states,) + outputs 566 | return outputs # (loss), hidden_states, lm_logits, presents, (all hidden_states), (attentions) 567 | 568 | 569 | 570 | class GatedMemoryUpdate(nn.Module): 571 | """ Transformer model """ 572 | def __init__(self, cfg, n_ctx): 573 | super(GatedMemoryUpdate, self).__init__() 574 | 575 | self.W1 = 
torch.nn.Linear(cfg.n_embd, cfg.n_embd, bias=False) 576 | self.W2 = torch.nn.Linear(cfg.n_embd, cfg.n_embd, bias=False) 577 | self.W3 = torch.nn.Linear(cfg.n_embd, cfg.n_embd, bias=False) 578 | self.W4 = torch.nn.Linear(cfg.n_embd, cfg.n_embd, bias=False) 579 | 580 | def forward(self, Y, Mi, Ymask, Mmask): 581 | #Y=[B,1xd] Ymask=[B,1x1] Mi=[B,mxd] Mmask=[B,mx1] 582 | Mhat = torch.tanh(self.W1(Y.expand(-1, Mi.size(1), -1))+self.W2(Mi)) 583 | g = torch.sigmoid(self.W3(Y.expand(-1, Mi.size(1), -1))+self.W4(Mi)) 584 | Mnext = torch.mul(1-g, Mi) + torch.mul(g, Mhat) 585 | return F.normalize(1e-7+Mnext, dim = -1) 586 | 587 | 588 | class PlotMachinesModel(nn.Module): 589 | ''' full PlotMachines model: 590 | init params: 591 | cfg: command line argument settings 592 | vocab: total vocab size including special tokens 593 | n_ctx: total context including delimiters 594 | gen_len: total generation length including end tokens 595 | includeprev: use the neighboring (previous) paragraph in input 596 | lastidx: eos index in tokenizer 597 | use_offline_gpt2: true if we've already downloaded from huggingface to server 598 | ''' 599 | def __init__(self, cfg, vocab=40990, n_ctx=102, gen_len=401, return_probs=False, includeprev=False, lastidx=0, use_offline_gpt2=False): 600 | ###ctx: [/ kw<=100 _ ] gen<=400 == 503 601 | #LM mask:[0x101][1x401] 0 - padded 602 | super(PlotMachinesModel,self).__init__() 603 | self.n_ctx = n_ctx 604 | self.gen_len = gen_len 605 | self.lastidx = lastidx 606 | 607 | self.memupd = GatedMemoryUpdate(cfg, n_ctx-2+cfg.memstatesize) 608 | if use_offline_gpt2: 609 | self.lmmodel = GPT2MemLMHeadModel.from_pretrained('./gpt2model', n_positions=n_ctx + gen_len) 610 | elif cfg.debug_mode: 611 | self.lmmodel = GPT2MemLMHeadModel.from_pretrained('gpt2', n_positions=n_ctx + gen_len) 612 | else: 613 | self.lmmodel = GPT2MemLMHeadModel.from_pretrained('gpt2-medium', n_positions=n_ctx + gen_len) 614 | 615 | self.lmmodel.resize_token_embeddings(vocab) 616 | self.epsilon = 1e-8 617 | self.cfg = cfg 618 | pos_emb_mask = torch.zeros(1, 1, vocab) #+n_ctx+gen_len) 619 | self.includeprev = includeprev 620 | self.repeatfactor = cfg.repeattheta 621 | self.register_buffer('pos_emb_mask', pos_emb_mask) 622 | 623 | 624 | ''' Training step 625 | *args are expected to be in this format: 626 | x: [B, S] - batch of paragraphs encoded as token ids 627 | mask_output: [B, S] - masks over padding 628 | mem: [B, Memsize, D] - the initial memory 629 | mmask: [B, Memsize] - masks over any padded memory cells 630 | prev: [B, 10, D] - up to 10 previous paragraph encodings with which to update the memory 631 | pmask: [B, 10] - mask over previous paragraphs that aren't there 632 | pvect: [B, D] - previous paragraph encoding to use as neighboring input vector 633 | ''' 634 | def _forward(self, *args, log=False, return_probs=False, returnnewmem=False, returnlast=False, past=None, returnpasts=False): 635 | 636 | x, mask_output, mem, mmask, prev, pmask, pvect = args 637 | 638 | n_ctx = self.n_ctx 639 | #print(mem) 640 | if prev is not None: 641 | mem,mmask= self.updatememory(x,mem,mmask,prev,pmask) 642 | 643 | lmout = self.lmmodel(x, past=past, attention_mask=mask_output, M=mem, Mmask=mmask, includeprev=self.includeprev, x_prev=pvect) 644 | h_dec = lmout[0] 645 | lm_logits = lmout[1] 646 | presents = lmout[2] 647 | if returnpasts: 648 | return lm_logits,presents 649 | if returnlast: 650 | lasttoken = torch.where(x[:,:] == self.lastidx, torch.ones_like(x[:,:]), torch.zeros_like(x[:,:])).unsqueeze(-1) #[B,503,1] 651 | lasttoken = 
lasttoken.type_as(h_dec)*h_dec 652 | hdecmasked = lasttoken.sum(dim=1) #[B,768] 653 | return lm_logits, hdecmasked 654 | return lm_logits 655 | 656 | def updatememory(self, *args): 657 | x, mem, mmask, prev, pmask = args #xraw = [B,T] 658 | mem[:,: self.n_ctx-2, :] = self.lmmodel.transformer.wte(x[:,1:self.n_ctx-1]) 659 | if prev is not None: 660 | for p in range(prev.size(1)): 661 | U = prev[:,p,:] 662 | Umask = pmask[:,p,:]#.squeeze(1) 663 | update = (Umask.sum(dim=-1) != 0).view(-1,1) 664 | oldmem = mem 665 | if update.any() > 0: 666 | mem = (1-update.view(-1,1,1).float())* mem + (update.view(-1,1,1).float()) * self.memupd(U, mem, Umask, mmask) 667 | return mem, mmask 668 | 669 | 670 | ''' 671 | Forward function: 672 | Either performs decoding, training step, or updates the memory depending on parameters, default is to just do training step 673 | @param: 674 | *args: tuple of model inputs 675 | returnnewmem: if True, then update the memory 676 | generate: if True, then generate new tokens using decoding method 677 | 678 | text_encoder: tokenizer 679 | device: cpu, cuda 680 | beam, decoding_strategy, log: old params for compatability that are not in use 681 | k: if using top k sampling 682 | p: if using nucleus sampling 683 | gen_len: maximum length for decoding 684 | min_len: minimum length for decoding 685 | returnlast: training parameter - return the last token hidden state (this is not in use in the latest codebase) 686 | ''' 687 | def forward(self, *args, text_encoder=None, device=None, beam=0, gen_len=401, k=0, p=0, decoding_strategy=0, log=False, generate=False, min_len=None, returnlast=False, returnnewmem=False): 688 | if returnnewmem: 689 | return self.updatememory(*args) 690 | elif generate: 691 | return self.generate(*args, text_encoder=text_encoder, device=device, beam=beam, gen_len=gen_len, k=k, p=p, decoding_strategy=decoding_strategy, min_len=min_len, returnnewmem=returnnewmem) 692 | return self._forward(*args, log=log, returnlast=returnlast) 693 | 694 | def sample(self, *args, classify_idx=None, text_encoder=None, gen_len=401, k=0, p=0, decoding_strategy=0, min_len=None, eos_idx=None, returnnewmem = False): 695 | XMB, mask, mem, mmask, prev, pmask,pvect, seen_unigrams, idxes = args 696 | mem,mmask = self.updatememory(XMB, mem, mmask, prev, pmask) 697 | 698 | pasts = None 699 | for _ in range(gen_len): 700 | 701 | fargs =(XMB, mask[:, :XMB.size(-1)], mem, mmask, None, None, pvect) 702 | lm_logits = self._forward(*fargs) # past=pasts, returnpasts=True) 703 | lm_logits[:,-1, :] = lm_logits[:,-1,:] / seen_unigrams 704 | pem = copy.deepcopy(self.pos_emb_mask) 705 | if _ < min_len: 706 | pem[:,:,eos_idx] = -1e12 707 | 708 | 709 | lm_probs = F.softmax((lm_logits + pem), dim=-1) 710 | dist = lm_probs[:, -1, :].squeeze(1) 711 | if k == 0 and p == 0: 712 | next_idx = torch.multinomial(lm_probs[:, -1, :], 1) 713 | else: 714 | if p ==0: 715 | # Sample from top k 716 | 717 | values, indices = dist.topk(k) 718 | next_idx = indices.gather(-1, torch.multinomial(values, 1)) 719 | else: 720 | indices = torch.argsort(dist,dim=1,descending=True) 721 | values = dist.gather(-1,indices) 722 | probsum = torch.cumsum(values,dim=1) 723 | include = ~ ((probsum.gt(p*.01)) & ((probsum-values).gt(p*.01))) 724 | newdist = torch.where(include, values, torch.zeros_like(values) + 1e-10) 725 | next_idx = indices.gather(-1, torch.multinomial(newdist, 1)) 726 | for i in range(XMB.size(0)): 727 | seen_unigrams[i, next_idx[i]] = self.repeatfactor 728 | XMB = self.append_batch(XMB, next_idx) 729 | return XMB[:, 
-gen_len:], seen_unigrams 730 | 731 | 732 | 733 | ''' Generate: 734 | *args are expected to be in this format: 735 | 736 | pad_output: [B, S] - batch of plot outline contexts encoded as token ids 737 | mask: [B, S] - masks over padding in outline contexts 738 | mem: [B, Memsize, D] - the initial memory 739 | mmask: [B, Memsize] - masks over any padded memory cells 740 | prev: [B, 10, D] - up to 10 previous paragraph encodings with which to update the memory 741 | pmask: [B, 10] - mask over previous paragraphs that aren't there 742 | xprev: [B, D] - previous paragraph encoding to use as neighboring input vector 743 | seen_unigrams [B, V] - previously used tokens in previous paragraphs 744 | idxes: [B] - the doc ids 745 | note: S= ctx + gen_len even though the generation will be blank tokens before decoding 746 | ''' 747 | def generate(self, *args, text_encoder=None, device=None, beam=0, gen_len=401, k=0, p=0, decoding_strategy=0, min_len=None, returnnewmem=False): 748 | 749 | if len(args) == 9: 750 | pad_output, mask, mem, mmask, prev, pmask, xprev, seen_unigrams, idxes = args 751 | else: 752 | pad_output, mask, mem, mmask, prev, pmask, xprev = args 753 | seen_unigrams = torch.ones(pad_output.size(0), len(text_encoder)).to(pad_output.device) 754 | idxes = None 755 | classify_idx = None #not in use by generation code anymore 756 | eos_idx = text_encoder.eos_token_id 757 | input_toks = pad_output[:, :self.n_ctx] # includes delimiter 758 | target_toks = pad_output[:, -gen_len:] 759 | mask_pad = torch.ones(mask.size()).type_as(mask) 760 | mask_pad[:, :self.n_ctx] = mask[:, :self.n_ctx] 761 | mask = mask_pad 762 | pad_output = pad_output.to(device) 763 | XMB = pad_output[:, :self.n_ctx] 764 | if beam == 0: 765 | generated_toks, seen_unigrams = self.sample(XMB, mask, mem, mmask, prev, pmask, xprev, seen_unigrams, idxes, classify_idx=classify_idx, text_encoder=text_encoder, gen_len=gen_len, k=k, p=p, decoding_strategy=decoding_strategy, min_len=min_len, eos_idx=eos_idx, returnnewmem=returnnewmem) 766 | return generated_toks.type_as(XMB), input_toks.type_as(XMB), target_toks.type_as(XMB), seen_unigrams 767 | else: 768 | raise NotImplementedError 769 | 770 | 771 | 772 | def append_batch(self, X, next_idx): 773 | return torch.cat((X, next_idx), 1) 774 | 775 | -------------------------------------------------------------------------------- /src/model/parallel.py: -------------------------------------------------------------------------------- 1 | # Code referenced from https://gist.github.com/thomwolf/7e2407fbd5945f07821adae3d9fd1312 2 | 3 | """Encoding Data Parallel""" 4 | import threading 5 | import functools 6 | import torch 7 | from torch.autograd import Variable, Function 8 | import torch.cuda.comm as comm 9 | from torch.nn.parallel.data_parallel import DataParallel 10 | from torch.nn.parallel.parallel_apply import get_a_var 11 | from torch.nn.parallel.scatter_gather import gather 12 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 13 | 14 | torch_ver = torch.__version__[:3] 15 | 16 | 17 | __all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion', 18 | 'patch_replication_callback'] 19 | 20 | def allreduce(*inputs): 21 | """Cross GPU all reduce autograd operation for calculate mean and 22 | variance in SyncBN. 
23 | """ 24 | return AllReduce.apply(*inputs) 25 | 26 | class AllReduce(Function): 27 | @staticmethod 28 | def forward(ctx, num_inputs, *inputs): 29 | ctx.num_inputs = num_inputs 30 | ctx.target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)] 31 | inputs = [inputs[i:i + num_inputs] 32 | for i in range(0, len(inputs), num_inputs)] 33 | # sort before reduce sum 34 | inputs = sorted(inputs, key=lambda i: i[0].get_device()) 35 | results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0]) 36 | outputs = comm.broadcast_coalesced(results, ctx.target_gpus) 37 | return tuple([t for tensors in outputs for t in tensors]) 38 | 39 | @staticmethod 40 | def backward(ctx, *inputs): 41 | inputs = [i.data for i in inputs] 42 | inputs = [inputs[i:i + ctx.num_inputs] 43 | for i in range(0, len(inputs), ctx.num_inputs)] 44 | results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0]) 45 | outputs = comm.broadcast_coalesced(results, ctx.target_gpus) 46 | return (None,) + tuple([Variable(t) for tensors in outputs for t in tensors]) 47 | 48 | 49 | class Reduce(Function): 50 | @staticmethod 51 | def forward(ctx, *inputs): 52 | ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))] 53 | inputs = sorted(inputs, key=lambda i: i.get_device()) 54 | return comm.reduce_add(inputs) 55 | 56 | @staticmethod 57 | def backward(ctx, gradOutput): 58 | return Broadcast.apply(ctx.target_gpus, gradOutput) 59 | 60 | class DistributedDataParallelModel(torch.nn.parallel.distributed.DistributedDataParallel): 61 | """Implements data parallelism at the module level for the DistributedDataParallel module. 62 | This container parallelizes the application of the given module by 63 | splitting the input across the specified devices by chunking in the 64 | batch dimension. 65 | In the forward pass, the module is replicated on each device, 66 | and each replica handles a portion of the input. During the backwards pass, 67 | gradients from each replica are summed into the original module. 68 | Note that the outputs are not gathered, please use compatible 69 | :class:`encoding.parallel.DataParallelCriterion`. 70 | The batch size should be larger than the number of GPUs used. It should 71 | also be an integer multiple of the number of GPUs so that each chunk is 72 | the same size (so that each GPU processes the same number of samples). 73 | Args: 74 | module: module to be parallelized 75 | device_ids: CUDA devices (default: all devices) 76 | Reference: 77 | Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, 78 | Amit Agrawal. “Context Encoding for Semantic Segmentation. 79 | *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018* 80 | Example:: 81 | >>> net = encoding.nn.DistributedDataParallelModel(model, device_ids=[0, 1, 2]) 82 | >>> y = net(x) 83 | """ 84 | def gather(self, outputs, output_device): 85 | return outputs 86 | 87 | class DataParallelModel(DataParallel): 88 | """Implements data parallelism at the module level. 89 | 90 | This container parallelizes the application of the given module by 91 | splitting the input across the specified devices by chunking in the 92 | batch dimension. 93 | In the forward pass, the module is replicated on each device, 94 | and each replica handles a portion of the input. During the backwards pass, 95 | gradients from each replica are summed into the original module. 96 | Note that the outputs are not gathered, please use compatible 97 | :class:`encoding.parallel.DataParallelCriterion`. 
98 | 99 | The batch size should be larger than the number of GPUs used. It should 100 | also be an integer multiple of the number of GPUs so that each chunk is 101 | the same size (so that each GPU processes the same number of samples). 102 | 103 | Args: 104 | module: module to be parallelized 105 | device_ids: CUDA devices (default: all devices) 106 | 107 | Reference: 108 | Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, 109 | Amit Agrawal. “Context Encoding for Semantic Segmentation. 110 | *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018* 111 | 112 | Example:: 113 | 114 | >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2]) 115 | >>> y = net(x) 116 | """ 117 | def gather(self, outputs, output_device): 118 | return outputs 119 | 120 | def replicate(self, module, device_ids): 121 | modules = super(DataParallelModel, self).replicate(module, device_ids) 122 | execute_replication_callbacks(modules) 123 | return modules 124 | 125 | 126 | class DataParallelCriterion(DataParallel): 127 | """ 128 | Calculate loss in multiple-GPUs, which balance the memory usage. 129 | The targets are splitted across the specified devices by chunking in 130 | the batch dimension. Please use together with :class:`encoding.parallel.DataParallelModel`. 131 | 132 | Reference: 133 | Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, 134 | Amit Agrawal. “Context Encoding for Semantic Segmentation. 135 | *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018* 136 | 137 | Example:: 138 | 139 | >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2]) 140 | >>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2]) 141 | >>> y = net(x) 142 | >>> loss = criterion(y, target) 143 | """ 144 | #def forward(self, inputs, *targets, **kwargs): 145 | def forward(self, inputs, *targets, only_return_losses=False, **kwargs): 146 | # input should be already scatterd 147 | # scattering the targets instead 148 | if not self.device_ids: 149 | return self.module(inputs, *targets, only_return_losses=only_return_losses, **kwargs) 150 | targets, kwargs = self.scatter(targets, kwargs, self.device_ids) 151 | if len(self.device_ids) == 1: 152 | return self.module(inputs, *targets[0], **kwargs[0]) 153 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) 154 | outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs) 155 | return self.gather(outputs, self.output_device) 156 | 157 | 158 | def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None): 159 | assert len(modules) == len(inputs) 160 | assert len(targets) == len(inputs) 161 | if kwargs_tup: 162 | assert len(modules) == len(kwargs_tup) 163 | else: 164 | kwargs_tup = ({},) * len(modules) 165 | if devices is not None: 166 | assert len(modules) == len(devices) 167 | else: 168 | devices = [None] * len(modules) 169 | 170 | lock = threading.Lock() 171 | results = {} 172 | if torch_ver != "0.3": 173 | grad_enabled = torch.is_grad_enabled() 174 | 175 | def _worker(i, module, input, target, kwargs, device=None): 176 | if torch_ver != "0.3": 177 | torch.set_grad_enabled(grad_enabled) 178 | if device is None: 179 | device = get_a_var(input).get_device() 180 | try: 181 | with torch.cuda.device(device): 182 | # this also avoids accidental slicing of `input` if it is a Tensor 183 | if not isinstance(input, (list, tuple)): 184 | input = (input,) 185 | if not isinstance(target, 
(list, tuple)): 186 | target = (target,) 187 | output = module(*(input + target), **kwargs) 188 | with lock: 189 | results[i] = output 190 | except Exception as e: 191 | with lock: 192 | results[i] = e 193 | 194 | if len(modules) > 1: 195 | threads = [threading.Thread(target=_worker, 196 | args=(i, module, input, target, 197 | kwargs, device),) 198 | for i, (module, input, target, kwargs, device) in 199 | enumerate(zip(modules, inputs, targets, kwargs_tup, devices))] 200 | 201 | for thread in threads: 202 | thread.start() 203 | for thread in threads: 204 | thread.join() 205 | else: 206 | _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0]) 207 | 208 | outputs = [] 209 | for i in range(len(inputs)): 210 | output = results[i] 211 | if isinstance(output, Exception): 212 | raise output 213 | outputs.append(output) 214 | return outputs 215 | 216 | 217 | ########################################################################### 218 | # Adapted from Synchronized-BatchNorm-PyTorch. 219 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 220 | # 221 | class CallbackContext(object): 222 | pass 223 | 224 | 225 | def execute_replication_callbacks(modules): 226 | """ 227 | Execute a replication callback `__data_parallel_replicate__` on each module created 228 | by original replication. 229 | 230 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 231 | 232 | Note that, as all modules are isomorphic, we assign each sub-module with a context 233 | (shared among multiple copies of this module on different devices). 234 | Through this context, different copies can share some information. 235 | 236 | We guarantee that the callback on the master copy (the first copy) will be called ahead 237 | of calling the callback of any slave copies. 238 | """ 239 | master_copy = modules[0] 240 | nr_modules = len(list(master_copy.modules())) 241 | ctxs = [CallbackContext() for _ in range(nr_modules)] 242 | 243 | for i, module in enumerate(modules): 244 | for j, m in enumerate(module.modules()): 245 | if hasattr(m, '__data_parallel_replicate__'): 246 | m.__data_parallel_replicate__(ctxs[j], i) 247 | 248 | 249 | def patch_replication_callback(data_parallel): 250 | """ 251 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 252 | Useful when you have customized `DataParallel` implementation. 
253 | 254 | Examples: 255 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 256 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 257 | > patch_replication_callback(sync_bn) 258 | # this is equivalent to 259 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 260 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 261 | """ 262 | 263 | assert isinstance(data_parallel, DataParallel) 264 | 265 | old_replicate = data_parallel.replicate 266 | 267 | @functools.wraps(old_replicate) 268 | def new_replicate(module, device_ids): 269 | modules = old_replicate(module, device_ids) 270 | execute_replication_callbacks(modules) 271 | return modules 272 | 273 | data_parallel.replicate = new_replicate 274 | -------------------------------------------------------------------------------- /src/model/train.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import argparse 3 | import os 4 | import random 5 | import numpy as np 6 | import rouge 7 | import torch 8 | from torch import nn 9 | from tqdm import tqdm 10 | import math 11 | 12 | from data_loader import get_paragraph_input_loader, get_paragraph_memory_input_loader 13 | from eval_utils import format_text, evaluate_doc_model 14 | from generate_utils import generate_paragraph 15 | from model import GPT2BaseModel, PlotMachinesModel 16 | from logger import Logger 17 | from loss import ParagraphLoss 18 | from parallel import DataParallelModel, DataParallelCriterion 19 | from transformers import * 20 | 21 | def get_average_scores(hyps, refs, maxlen=400, stop_words=[]): 22 | rouge_scorer = rouge.Rouge() 23 | averaged_scores = {'rouge-1': {'f': 0, 'p': 0, 'r': 0}, 24 | 'rouge-2': {'f': 0, 'p': 0, 'r': 0}, 25 | 'rouge-l': {'f': 0, 'p': 0, 'r': 0}} 26 | 27 | scores = rouge_scorer.get_scores(hyps, refs) 28 | for metric in averaged_scores.keys(): 29 | for values in scores: 30 | for sub_metric in averaged_scores[metric]: 31 | averaged_scores[metric][sub_metric] += values[metric][sub_metric] 32 | for key in averaged_scores.keys(): 33 | for sub_key in averaged_scores[key].keys(): 34 | averaged_scores[key][sub_key] /= len(hyps) 35 | return averaged_scores 36 | 37 | def run_batch(model, args, device, compute_loss_fct): 38 | for arg in args: 39 | if arg is not None: 40 | arg = arg.to(device) 41 | 42 | output = model(*args) 43 | allloss = compute_loss_fct(output, args[0], args[1]) 44 | 45 | return allloss.mean() 46 | 47 | def save_checkpoint(iter_num, running_loss, model_state_dict, optimizer_state_dict, save_dir,my_local_dir): 48 | print('Saving a checkpoint...' 
+ my_local_dir) 49 | torch.save({ 50 | "iter": iter_num, 51 | "running_loss": running_loss, 52 | "state_dict": model_state_dict, 53 | "optimizer": optimizer_state_dict 54 | }, os.path.join(my_local_dir, "checkpoint_best.pt")) 55 | 56 | trial = 0 57 | while trial < 10: 58 | try: 59 | print('Copying a checkpoint...from ' + my_local_dir + ' to ' + save_dir) 60 | shutil.copy(os.path.join(my_local_dir, "checkpoint_best.pt"), os.path.join(save_dir, "checkpoint_best.pt")) 61 | trial = 100 62 | except Exception as e: 63 | print(e) 64 | os.makedirs(save_dir, exist_ok=True) 65 | trial += 1 66 | 67 | def load_checkpoint(checkpoint_file, model, model_opt): 68 | """ 69 | Loads a checkpoint including model state and running loss for continued training 70 | """ 71 | if checkpoint_file is not None: 72 | checkpoint = torch.load(checkpoint_file) 73 | state_dict = checkpoint["state_dict"] 74 | start_iter = checkpoint['iter'] 75 | running_loss = checkpoint['running_loss'] 76 | opt_state_dict = checkpoint['optimizer'] 77 | model_opt.load_state_dict(opt_state_dict) 78 | for state in model_opt.state.values(): 79 | for key, value in state.items(): 80 | if isinstance(value, torch.Tensor): 81 | state[key] = value.cuda() 82 | model.load_state_dict(state_dict) 83 | else: 84 | start_iter = 1 85 | running_loss = 0 86 | return start_iter, running_loss 87 | 88 | 89 | def evaluate(val_loader, train_log_interval, model, text_encoder, device, beam, gen_len, k,p, decoding_strategy, compute_loss_fct, min_len=10): 90 | hyps, refs = [], [] 91 | val_loss = 0 92 | for j, args in enumerate(val_loader): 93 | with torch.no_grad(): 94 | if j <= 5: 95 | #evaluate Rouge on a very small subset of dev examples just to double check that training is working 96 | model.eval() 97 | # Generating outputs for evaluation 98 | src_strs, new_refs, new_hyps = generate_paragraph(model, args, text_encoder, device, beam, gen_len, k,p, decoding_strategy, min_len=min_len) 99 | hyps.extend(new_hyps) 100 | refs.extend(new_refs) 101 | # Calculating loss 102 | l = run_batch(model, args, device, compute_loss_fct) 103 | val_loss += float(l.item()) 104 | try: 105 | print('Hypothesis: {}'.format(hyps[0])) 106 | print("Reference: {}".format(refs[0])) 107 | except: 108 | pass 109 | scores = get_average_scores(hyps, refs, maxlen=gen_len) 110 | # scores = None 111 | return val_loss, scores 112 | 113 | def get_loss_value(num, denom): 114 | """Log a scalar variable.""" 115 | value = num/denom 116 | return value 117 | 118 | 119 | '''Run a single training epoch: 120 | @params- 121 | bestloss: the best loss over any evaluation on the dev set 122 | start_iter: the batch in the epoch to start with 123 | running_loss: the total loss since the last checkpoint update 124 | model: the model being trained 125 | compute_loss_fct: a loss function (from loss.py) 126 | model_opt: the argparse options 127 | train_loader, val_loader: training and validation data loaders 128 | train_log_interval,val_log_interval: how often to log training and validation losses 129 | device: cuda or cpu 130 | beam, gen_len, k, p, decoding_strategy: decoding parameters 131 | accum_iter: how often to run backprop 132 | desc_str: string for showing progress, 133 | save_dir: where to save checkpoints, 134 | logger: class for logging progress (mostly for debugging), 135 | text_encoder: the tokenizer, 136 | show_progress=False: whether to log progress to the command line 137 | summary_loss=None: not in use anymore 138 | my_local_dir='checkpoints_local': a local checkpoint storage if running on servers 139 | 
''' 140 | def run_epoch(bestloss, start_iter, running_loss, model, compute_loss_fct, model_opt, train_loader, val_loader, train_log_interval, val_log_interval, device, beam, gen_len, k,p, decoding_strategy, accum_iter, desc_str, save_dir, logger, text_encoder, show_progress=False, summary_loss=None, my_local_dir='checkpoints_local'): 141 | ''' 142 | Run a single epoch, log results, and save best checkpoint 143 | ''' 144 | if show_progress: 145 | train_bar = tqdm(iterable=train_loader, desc=desc_str) 146 | else: 147 | train_bar = train_loader 148 | 149 | for i, batchargs in enumerate(train_bar, start_iter): 150 | num_updates = i // accum_iter 151 | model.train() 152 | loss = run_batch(model, batchargs, device, compute_loss_fct) 153 | loss.backward() 154 | 155 | running_loss += float(loss.item()) 156 | if show_progress: 157 | train_bar.set_postfix(loss=running_loss / ((train_log_interval * accum_iter) if num_updates % train_log_interval == 0 and num_updates != 0 else i % (train_log_interval * accum_iter))) 158 | 159 | if i % accum_iter == 0: 160 | model_opt.step() 161 | model_opt.zero_grad() 162 | torch.cuda.empty_cache() 163 | if num_updates % train_log_interval == 0 and i % accum_iter == 0: 164 | logger.scalar_summary("Training", num=running_loss, denom=(train_log_interval * accum_iter), step=num_updates) 165 | print("training loss %.2f" % (running_loss/float(train_log_interval * accum_iter))) 166 | running_loss = 0 167 | 168 | if num_updates % 1000 == 0 and i % accum_iter == 0: 169 | val_loss, scores = evaluate(val_loader, train_log_interval, model, text_encoder, device, beam, gen_len, k, p, decoding_strategy, compute_loss_fct, min_len=args.min_len) 170 | 171 | logger.scalar_summary("Validation", num=val_loss, denom=len(val_loader), step=num_updates) 172 | # if sum(val_loss) < bestloss or bestloss == -1: 173 | lv = get_loss_value(val_loss, len(val_loader)) 174 | if (not math.isnan(lv)) and (bestloss == -1 or lv < bestloss): 175 | bestloss = lv 176 | save_checkpoint(i + 1, running_loss, model.state_dict(), model_opt.state_dict(), save_dir, my_local_dir) 177 | 178 | 179 | val_loss, scores = evaluate(val_loader, train_log_interval, model, text_encoder, device, beam, gen_len, k, p, decoding_strategy, compute_loss_fct, min_len=args.min_len) 180 | for key, value in scores.items(): 181 | for key2, value2 in value.items(): 182 | logger.rouge_summary("{}/{}".format(key, key2), value2, num_updates) 183 | print("Validation rouge: " + str(scores.items())) 184 | logger.scalar_summary("Validation", num=val_loss, denom=len(val_loader), step=num_updates) 185 | lv = get_loss_value(val_loss, len(val_loader)) 186 | if (not math.isnan(lv)) and (bestloss == -1 or lv < bestloss): 187 | bestloss = lv 188 | save_checkpoint(i + 1, running_loss, model.state_dict(), model_opt.state_dict(), save_dir, my_local_dir) 189 | 190 | 191 | torch.cuda.empty_cache() 192 | return i + 1, running_loss, bestloss, num_updates, lv 193 | 194 | def print_model_params(log_dir, doc_model): 195 | fm = open(log_dir+"/modeldescr.txt","w") 196 | fm.write(str(doc_model)) 197 | fm.close() 198 | print(sum(p.numel() for p in doc_model.parameters() if p.requires_grad)) 199 | 200 | 201 | def init(args): 202 | print("Creating directories") 203 | os.makedirs(args.output_dir, exist_ok=True) 204 | os.makedirs(os.path.join(args.output_dir, args.experiment_name), exist_ok=True) 205 | os.makedirs(os.path.join(args.output_dir, args.experiment_name), exist_ok=True) 206 | 207 | random.seed(args.seed) 208 | np.random.seed(args.seed) 209 | 
torch.manual_seed(args.seed) 210 | torch.cuda.manual_seed_all(args.seed) 211 | 212 | def main(args): 213 | init(args) 214 | #Args setup: 215 | save_dir = os.path.join(args.output_dir, args.experiment_name, "checkpoints") 216 | save_dir_local = "checkpoints_local" 217 | desc = args.desc 218 | data_dir = args.data_dir 219 | log_dir = os.path.join(args.output_dir, args.experiment_name, "logs") 220 | os.makedirs(log_dir, exist_ok=True) 221 | os.makedirs(save_dir, exist_ok=True) 222 | os.makedirs(save_dir_local, exist_ok=True) 223 | 224 | train_log_interval = args.train_log_interval 225 | val_log_interval = args.val_log_interval 226 | beam = args.beam 227 | p = args.p 228 | n_ctx = args.n_ctx 229 | gen_len = args.gen_len 230 | k = args.k 231 | decoding_strategy = args.decoding_strategy 232 | accum_iter = args.accum_iter 233 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 234 | n_gpu = torch.cuda.device_count() 235 | print("device", device, "n_gpu", n_gpu) 236 | logger = Logger(log_dir) 237 | 238 | #Text Encoder 239 | if args.use_offline_gpt2: 240 | text_encoder = GPT2Tokenizer.from_pretrained('./gpt2model') 241 | elif args.debug_mode: 242 | text_encoder = GPT2Tokenizer.from_pretrained('gpt2') 243 | else: 244 | text_encoder = GPT2Tokenizer.from_pretrained('gpt2-medium') 245 | 246 | text_encoder.add_special_tokens({'bos_token':'_start_', 247 | 'cls_token':'_classify_', 248 | 'eos_token':'_end_', 249 | 'additional_special_tokens': ['_kw_','_endkw_', '_t_', '_i_', '_b_', '_c_'] 250 | }) 251 | 252 | vocab = len(text_encoder) 253 | 254 | print("Loading dataset...") 255 | if args.use_model == "base": 256 | train_loader = get_paragraph_input_loader(os.path.join(data_dir, "train_encoded.csv"), args.n_batch, text_encoder, 257 | num_workers=3, shuffle=True, gen_len=gen_len, n_ctx=n_ctx, include_discourse_type=args.use_discourse, 258 | include_neigh= args.use_neighbor_feat, max_size=args.max_ex, 259 | include_kw = not args.exclude_kw, dim = args.n_embd, debug_mode=args.debug_mode) 260 | 261 | val_loader = get_paragraph_input_loader(os.path.join(data_dir, "val_encoded.csv"), n_gpu, text_encoder, 262 | num_workers=0, shuffle=False, gen_len=gen_len, n_ctx=n_ctx, include_discourse_type=args.use_discourse, 263 | include_neigh= args.use_neighbor_feat, max_size=args.num_val_examples, 264 | include_kw = not args.exclude_kw, dim = args.n_embd, debug_mode=args.debug_mode) 265 | 266 | print("Train length: {}, Validation length: {}".format(len(train_loader), len(val_loader))) 267 | doc_model = GPT2BaseModel(args, vocab=vocab, n_ctx=n_ctx, gen_len=gen_len, lastidx=text_encoder.eos_token_id, includeprev=args.use_neighbor_feat, use_offline_gpt2 = args.use_offline_gpt2) 268 | 269 | elif args.use_model == "plotmachines": 270 | #asli 271 | train_loader = get_paragraph_memory_input_loader(os.path.join(data_dir, "train_encoded.csv"), args.n_batch, text_encoder, 272 | num_workers=3, shuffle=True, gen_len=gen_len, n_ctx=n_ctx, include_discourse_type=args.use_discourse, 273 | include_neigh= args.use_neighbor_feat, max_size = args.max_ex, 274 | include_kw = not args.exclude_kw, memsize=args.memstatesize, dim = args.n_embd, use_kwmem=True, debug_mode=args.debug_mode) 275 | 276 | val_loader = get_paragraph_memory_input_loader(os.path.join(data_dir, "val_encoded.csv"), n_gpu, text_encoder, 277 | num_workers=0, shuffle=False, gen_len=gen_len, n_ctx=n_ctx, include_discourse_type=args.use_discourse, 278 | include_neigh= args.use_neighbor_feat, max_size = args.num_val_examples, 279 | include_kw = not 
args.exclude_kw, memsize=args.memstatesize, dim = args.n_embd, use_kwmem=True, debug_mode=args.debug_mode) 280 | 281 | print("Train length: {}, Validation length: {}".format(len(train_loader), len(val_loader))) 282 | doc_model = PlotMachinesModel(args, vocab=vocab, n_ctx=n_ctx, gen_len=gen_len, lastidx=text_encoder.eos_token_id, includeprev=args.use_neighbor_feat, use_offline_gpt2 = args.use_offline_gpt2) 283 | 284 | 285 | 286 | n_updates_total = (len(train_loader) // args.accum_iter) * (args.num_epochs) 287 | 288 | if args.debug_mode: 289 | print_model_params(log_dir, doc_model) 290 | 291 | criterion = nn.CrossEntropyLoss(reduction="none") 292 | 293 | model_opt = AdamW(filter(lambda p : p.requires_grad, doc_model.parameters()), 294 | lr=args.lr, 295 | betas=(args.b1,args.b2), 296 | eps=args.e) 297 | 298 | lm_loss = ParagraphLoss(criterion, n_ctx=n_ctx, gen_len=gen_len) 299 | 300 | print("Loading Model") 301 | doc_model.to(device) 302 | if n_gpu > 1: 303 | doc_model = DataParallelModel(doc_model) 304 | lm_loss = DataParallelCriterion(lm_loss) 305 | print("Parallelized") 306 | 307 | bestloss = -1 308 | start_iter, running_loss = 1,0 309 | prevloss = 1000 310 | 311 | start_iter, running_loss = load_checkpoint(args.checkpoint, doc_model, model_opt) 312 | for i in range(args.num_epochs): 313 | start_iter, running_loss, bestloss, updates, val_loss1 = run_epoch(bestloss, start_iter, running_loss, doc_model, lm_loss, model_opt, train_loader, val_loader, train_log_interval, val_log_interval, device, beam, gen_len, k, p, decoding_strategy, accum_iter, "FT Training Epoch [{}/{}]".format(i + 1, args.num_epochs), save_dir, logger, text_encoder, show_progress=args.show_progress, my_local_dir=save_dir_local) 314 | print("VAL LOSS: ", str(val_loss1)) 315 | if val_loss1 > prevloss or math.isnan(val_loss1): 316 | break 317 | prevloss = val_loss1 318 | 319 | 320 | print('Done training...') 321 | print('Evaluating on validation with best checkpoint...') 322 | 323 | bestcheck = os.path.join(save_dir,"checkpoint_best.pt") 324 | checkpoint = torch.load(bestcheck, map_location='cpu') 325 | state_dict = checkpoint["state_dict"] 326 | if state_dict.get('module.pos_emb_mask') is None and doc_model.state_dict().get('module.pos_emb_mask') is not None: 327 | state_dict['module.pos_emb_mask'] = doc_model.state_dict().get('module.pos_emb_mask') 328 | doc_model.load_state_dict(state_dict) 329 | evaluate_doc_model(doc_model, val_loader, text_encoder, device, beam, gen_len, k, p, args.decoding_strategy, os.path.join(save_dir,'valeval.log'), 'gen','tgt', gen_len, [], args) 330 | 331 | 332 | 333 | 334 | if __name__ == "__main__": 335 | parser = argparse.ArgumentParser() 336 | parser.add_argument('--desc', type=str, help="Description") 337 | parser.add_argument('--output_hidden_states', action='store_true') 338 | parser.add_argument('--output_attentions', action='store_true') 339 | parser.add_argument('--output_past', action='store_true') 340 | parser.add_argument('--seed', type=int, default=42) 341 | parser.add_argument('--num_epochs', type=int, default=10) 342 | parser.add_argument('--n_batch', type=int, default=2) 343 | parser.add_argument('--max_grad_norm', type=int, default=1) 344 | parser.add_argument('--lr', type=float, default=6.25e-5) 345 | parser.add_argument('--lr_warmup', type=float, default=0.002) 346 | parser.add_argument('--n_embd', type=int, default=1024) 347 | parser.add_argument('--n_head', type=int, default=12) 348 | parser.add_argument('--n_layer', type=int, default=12) 349 | 
parser.add_argument('--embd_pdrop', type=float, default=0.1) 350 | parser.add_argument('--attn_pdrop', type=float, default=0.1) 351 | parser.add_argument('--resid_pdrop', type=float, default=0.1) 352 | parser.add_argument('--clf_pdrop', type=float, default=0.1) 353 | parser.add_argument('--l2', type=float, default=0.01) 354 | parser.add_argument('--vector_l2', action='store_true') 355 | parser.add_argument('--opt', type=str, default='adam') 356 | parser.add_argument('--afn', type=str, default='gelu') 357 | parser.add_argument('--lr_schedule', type=str, default='warmup_linear') 358 | parser.add_argument('--n_transfer', type=int, default=12) 359 | parser.add_argument('--lm_coef', type=float, default=0.5) 360 | parser.add_argument('--b1', type=float, default=0.9) 361 | parser.add_argument('--b2', type=float, default=0.999) 362 | parser.add_argument('--e', type=float, default=1e-8) 363 | # Custom 364 | parser.add_argument('--output_dir', type=str, default=os.getenv('PT_OUTPUT_DIR', '/tmp'), help='directory to save logs and checkpoints to') 365 | parser.add_argument('--experiment_name', type=str, required=True, help='name of this experiment; will be included in output paths') 366 | parser.add_argument('--data_dir', type=str, default='data', help='directory with train, dev, test files') 367 | parser.add_argument('--train_log_interval', type=int, default=100, help='number of train steps before logging training progress') 368 | parser.add_argument('--val_log_interval', type=int, default=2000, help='number of train steps before logging validation progress') 369 | parser.add_argument('--num_val_examples', type=int, default=None, help='max number of validation examples, or None: use all data') 370 | parser.add_argument('--beam', type=int, default=0, help='beam size for beam search - not in use') 371 | parser.add_argument('--k', type=int, default=0, help='k for TopK sampling') 372 | parser.add_argument('--p', type=int, default=0, help='p for Nucleus sampling') 373 | parser.add_argument('--decoding_strategy', type=int, default=0, help='not in use') 374 | parser.add_argument('--accum_iter', type=int, default=2, help='number of batches to accumulate gradients before doing backprop') 375 | parser.add_argument('--gen_len', type=int, default=922, help='max generation length + 1 for end token') 376 | parser.add_argument('--n_ctx', type=int, default=102, help='max outline length + 2 for delimiters') 377 | parser.add_argument('--show_progress', action='store_true') 378 | parser.add_argument('--exclude_kw', action='store_true', help='unconditional baseline') 379 | parser.add_argument('--max_ex', type=int, default=None, help='max number of train examples, or None: use all training data') 380 | parser.add_argument('--min_len', type=int, default=100, help='minimum generation length') 381 | parser.add_argument('--repeattheta', type=float, default=1.5, help='how much to penalize repetition (1 is not at all, > 1 is more penalty)') 382 | parser.add_argument('--memstatesize', type=int, default=100, help='size of global document state portion of memory (default:100)') 383 | parser.add_argument('--use_model', type=str, choices=['base', 'plotmachines'], help='full plotmachines (w/ memory) vs base gpt (no memory)') 384 | parser.add_argument('--use_neighbor_feat', action='store_true', help='use neighboring (previous) paragraph encoding as extra input') 385 | parser.add_argument('--use_discourse', action='store_true', help='use discourse tokens as extra input') 386 | parser.add_argument('--use_offline_gpt2', action='store_true') 387 
| parser.add_argument('--checkpoint', type=str, default=None, help='location of a previous checkpoint') 388 | parser.add_argument('--debug_mode', action='store_true') 389 | 390 | args = parser.parse_args() 391 | print(torch.__version__) 392 | print(args) 393 | main(args) 394 | -------------------------------------------------------------------------------- /src/preprocessing/README.md: -------------------------------------------------------------------------------- 1 | 2 | This repo includes the dataset extraction and preprocessing scripts. 3 | 4 | We construct three datasets for outline-conditioned generation. We focus on fictitious generation, but also include the news domain for generalization. We build on existing publicly available datasets for the target narratives, paired with automatically constructed input outlines as described in detail in our paper. Here we provide the dataset ids and the preprocessing scripts to construct the train/validation/test splits for experimentation. 5 | 6 | # Prerequisites 7 | 8 | numpy 9 | 10 | Rake 11 | 12 | nltk 13 | 14 | TfidfVectorizer 15 | 16 | # Pre-processing Data 17 | ## WikiPlots 18 | 19 | #### 1. Steps for downloading 20 | The Wikiplots corpus consists of plots of movies, TV shows, and books scraped from Wikipedia. 21 | Please use the scripts provided in the link to extract the dataset (we used Wikipedia from the 10/01/19 timestamp). 22 | 23 | You need to make changes to line 81 of their code to replace '\n' with paragraph markers instead of with spaces: 24 | 25 | `plot = plot.replace('\n\n', ' ##PM## ').replace('\r', '').strip()` 26 | 27 | After processing is complete, you should replace the '##PM##' markers with '<p>' before running the extract outlines script. 28 | 29 | #### 2. Steps for extracting outlines 30 | Run extract_outlines.py to extract the outline-labeled documents that can be used as input to train the PlotMachines fine-tuning models. 31 | 32 | The output will provide you with a csv of the outlines and stories where each row is a paragraph from a story. The columns are: 33 | - story id: our format is "storyid_{int}" with the {int} after the underscore being this paragraph's index in the story (starting at 0) 34 | - key/abstract: this is a binary signifier for us to know where the data came from, but it is just "K" for every row in wikiplots 35 | - outline: the outline with points delimited by [SEP] 36 | - discourse tag: I/B/C for intro, body, conclusion paragraphs respectively 37 | - num_paragraphs: total number of paragraphs in this story 38 | - paragraph: the paragraph text 39 | - previous paragraph: text from the previous paragraph in the story 40 | 41 | #### 3. Steps for splitting into train/dev/test splits 42 | Please use the splits from wikiplots_splits.txt to construct the train, validation, and test datasets that were used in the paper. Note that some stories may need to be removed (marked "flagged") due to potentially offensive and/or harmful content. 43 | 44 | #### 4. Steps for removing offensive content: 45 | ####      4a. Offensive story removal: 46 | 47 | Some plots should be excluded from the data and are marked as 'flagged' instead of train/dev/test in the splits file. These are stories that we have identified as coming from summaries of books/movies that are probably offensive. We mostly try to remove stories that attack someone's identity (e.g. racist propaganda) or are problematically lurid. 
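To make the flagged-story filtering concrete, below is a rough sketch (not part of the released scripts) of how one might bucket the extracted paragraph rows by split while dropping flagged stories. It assumes each line of wikiplots_splits.txt pairs a story id with one of train/dev/test/flagged; check the actual layout of the splits file and adjust the parsing before relying on it.

```python
import csv

def load_splits(splits_path):
    """Map story id -> split name. Assumes one 'story_id  split' pair per line (an assumption)."""
    splits = {}
    with open(splits_path, encoding='utf-8') as f:
        for line in f:
            parts = line.split()
            if len(parts) >= 2:
                splits[parts[0]] = parts[1].lower()
    return splits

def partition_rows(encoded_csv, splits_path):
    """Bucket paragraph rows into train/dev/test; flagged or unknown stories are silently dropped."""
    splits = load_splits(splits_path)
    out = {'train': [], 'dev': [], 'test': []}
    with open(encoded_csv, encoding='utf-8', newline='') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row of the extracted csv
        for row in reader:
            # row[0] is "storyid_{paragraph index}"; strip the index to look up the parent story
            story_id = row[0].rsplit('_', 1)[0]
            split = splits.get(story_id)
            if split in out:
                out[split].append(row)
    return out
```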
29 | #### 2. Steps for extracting outlines 30 | Run extract_outlines.py to extract the outline-labeled documents that are used as input for fine-tuning PlotMachines models. 31 | 32 | The output is a csv of the outlines and stories where each row is a paragraph from a story. The columns are: 33 | - story id: our format is "storyid_{int}" with the {int} after the underscore being this paragraph's index in the story (starting at 0) 34 | - key/abstract: this is a binary signifier for us to know where the data came from; it's just "K" for every row in wikiplots 35 | - outline: the outline with points delimited by [SEP] 36 | - discourse tag: I/B/C for intro, body, conclusion paragraphs respectively 37 | - num_paragraphs: total number of paragraphs in this story 38 | - paragraph: the paragraph text 39 | - previous paragraph: text from the previous paragraph in the story 40 | 41 | #### 3. Steps for splitting into train/dev/test splits 42 | Please use the splits from wikiplots_splits.txt to construct the train, validation and test datasets that were used in the paper (see the sketch below). Note that some stories may need to be removed (marked "flagged") due to potentially offensive and/or harmful content. 43 |
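A filtering pass along these lines could be used to build the three files. This is only a sketch: it assumes each line of wikiplots_splits.txt is a tab-separated story id and split label (train/val/test/flagged) and that those ids match the "plot-{int}" ids written by extract_outlines.py; check the file and adjust the parsing if your copy differs.

```python
# Sketch: route each paragraph row of the extracted csv into its split file and
# drop stories marked as 'flagged'. The splits-file format is an assumption here.
splits = {}
with open('wikiplots_splits.txt', encoding='ISO-8859-1') as f:
    for line in f:
        if not line.strip():
            continue
        story_id, split = line.strip().split('\t')
        splits[story_id] = split

outfiles = {s: open(s + '_encoded.csv', 'w', encoding='ISO-8859-1') for s in ('train', 'val', 'test')}
with open('wikiplot.kwRAKE.csv', encoding='ISO-8859-1') as f:
    header = next(f)
    for out in outfiles.values():
        out.write(header)
    for row in f:
        story_id = row.split('\t', 1)[0].rsplit('_', 1)[0]  # strip the paragraph index suffix
        if splits.get(story_id) in outfiles:                # 'flagged' and unlisted stories are skipped
            outfiles[splits.get(story_id)].write(row)
for out in outfiles.values():
    out.close()
```

The output names match what train.py expects (see "After Pre-Processing" below).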
44 | #### 4. Steps for removing offensive content: 45 | #### 4a. Offensive story removal: 46 | 47 | Some plots should be excluded from the data and are marked as 'flagged' instead of train/dev/test in the splits file. These are stories that we have identified as coming from summaries of books/movies that are probably offensive. We mostly try to remove stories that attack someone's identity (e.g. racist propaganda) or are problematically lurid. 48 | 49 | Details about how these stories were identified: we first used automatic toxicity models (using the Perspective API) to identify stories that had toxicity scores above a manually chosen threshold (about 1500 stories). We skimmed the automatically curated list of toxic stories and manually corrected the labels of about 200 of those stories that we believe were misclassified. The remaining 1300 stories have been flagged in the data splits files. There are limitations to this hybrid automatic/manual approach, and there may be some stories that were misclassified (either incorrectly labelled as inoffensive or incorrectly labelled as offensive). We are continuing to prune the data to remove examples of offensive stories, so please let us know if you find any that we've missed. 50 | 51 | #### 4b. Offensive words: 52 | 53 | Some stories may have instances of offensive words even though the overall story is not offensive. Before training, you should check for swear words, slurs, etc. We recommend adding a pre-processing step to replace these words with some special token. 54 | 55 | #### 4c. Additional precautions: 56 | 57 | Even when taking these steps, there may be a few underlying themes in some stories that don't match modern values (for example, many older stories may express outdated views about gender roles). There also may be stories containing sexual and violent content that - depending on the end use - may be inappropriate for a model to be trained on. We therefore caution anyone using this data to be very careful in how they use models that are trained using these stories. Please moderate output as necessary and appropriate for your end task. 58 | 59 | ## WritingPrompts 60 | 61 | This is a story generation dataset, presented in Hierarchical Neural Story Generation (Fan et al., 2018), collected from the /r/WritingPrompts subreddit - a forum where Reddit users compose short stories inspired by other users' prompts. It contains over 300k human-written (prompt, story) pairs. We use the same train/dev/test split as the original dataset paper. 62 | 63 | 64 | ## New York Times 65 | 66 | NYTimes, The New York Times Annotated Corpus (LDC2008T19), contains news articles. 67 | We parse the NYT corpus and then split it into train, validation and test sets using the list of keys provided in nyt_splits.txt. 68 | 69 | Lastly, run extract_outlines.py to extract the outline-labeled documents that are used as input for fine-tuning PlotMachines models. This is the same script used for the wikiplots data. 70 | 71 | 72 | # After Pre-Processing 73 | 74 | Once the files are generated, rename the preprocessed files to "train_encoded.csv", "val_encoded.csv", and "test_encoded.csv" and place them under the "data_dir" folder, which is specified as an input parameter to the fine-tuning script "train.py". 75 | 76 | ### Creating Pickle Files for h\_{i-1} 77 | 78 | Our model expects there to be a \*\_gpt.pkl or \*\_gpt2.pkl file in the directory (depending on command line settings). 79 | In our paper, we talk about using an encoded representation of the previous paragraph (h\_{i-1}) which we computed using either gpt or gpt2 (depending on the PlotMachines settings). To compute it, we used a function that computes an average output embedding of the paragraph. At training time, we precomputed h\_{i-1} from gold paragraphs and stored them in pickle (pkl) files. 80 | 81 | For each row in the input csv files, there is an entry in the pickle file which should contain a tuple of: 82 | - index in the csv data file (header row = 0) 83 | - string version of the previous paragraph (which has to match the last column from the input csv/jsonl files). 84 | - a vector representing the previous paragraph 85 | 86 | Please note: in order to match indices with the input files, there needs to be a "dummy" encoding at the beginning of the pickle file to line up with the header row of the input csv. That row will get ignored in the code. 87 | 88 |
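As a rough illustration, the sketch below builds one of these pickle files with the GPT-2 encoder from the transformers library. The mean pooling, the 512-token truncation, and the exact contents of the dummy header entry are our assumptions (the dummy entry only has to exist, since the code ignores it); whether you need the gpt or gpt2 variant depends on your command-line settings.

```python
# Sketch: build train_gpt2.pkl with one (row_index, previous_paragraph, vector)
# tuple per row of train_encoded.csv, plus a dummy entry aligned with the header row.
import csv
import pickle
import torch
from transformers import GPT2Model, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2')
model.eval()

def encode_avg(text):
    # "average output embedding": mean of the final-layer hidden states (truncated to 512 tokens)
    ids = torch.tensor([tokenizer.encode(text)[:512]])
    with torch.no_grad():
        hidden_states = model(ids)[0]        # shape (1, seq_len, hidden_size)
    return hidden_states.mean(dim=1).squeeze(0).numpy()

entries = [(0, 'NA', encode_avg('NA'))]      # dummy entry for the csv header row (ignored downstream)
with open('train_encoded.csv', encoding='ISO-8859-1') as f:
    reader = csv.reader(f, delimiter='\t')   # the extracted files are tab-separated
    next(reader)                             # skip the header row itself
    for i, row in enumerate(reader, start=1):
        prev_paragraph = row[-1]             # previous paragraph is the last column
        entries.append((i, prev_paragraph, encode_avg(prev_paragraph)))

with open('train_gpt2.pkl', 'wb') as f:
    pickle.dump(entries, f)
```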

-------------------------------------------------------------------------------- /src/preprocessing/extract_outlines.py: -------------------------------------------------------------------------------- 1 | from sklearn.feature_extraction.text import TfidfVectorizer 2 | import numpy as np 3 | from rake_nltk import Rake 4 | from nltk.tokenize import sent_tokenize 5 | import os 6 | 7 | def sorting(lst): 8 | # lst2=sorted(lst, key=len) 9 | lst2 = sorted(lst, key=len) 10 | return lst2 11 | 12 | def trim_body(body): 13 | paragraphs = body.replace('<p>', '\n').split('\n') 14 | body_new = [] 15 | par_length = 1 16 | 17 | for par in paragraphs: 18 | _par = par 19 | if _par.endswith(' .'): 20 | _par = _par[:-5] 21 | temp_body = _par.replace(' ', ' ').replace(' ', ' ').strip() 22 | sentences = _par.replace(' ', '\n').replace(' ', ' ').strip().split('\n') 23 | 24 | if len(paragraphs) == 1: 25 | s = 0 26 | first = True 27 | 28 | while len(sentences[s].split(' ')) < 4 or '::act ' in sentences[s].lower() or ' act:' in sentences[s].lower(): 29 | s+=1 30 | if s == len(sentences): 31 | return None 32 | body_new.append(' ' + sentences[s].replace(' ', ' ').strip()) 33 | s+=1 34 | 35 | while s < len(sentences) and len(body_new)< 5: 36 | body_new.append('') 37 | curr_len = 0 38 | while s < len(sentences) and curr_len + len(sentences[s].split(' ')) < 400: 39 | if ':act ' in sentences[s].lower() or 'act: ' in sentences[s].lower() : 40 | s+=1 41 | break 42 | 43 | if len(sentences[s]) > 10: 44 | curr_len += len(sentences[s].replace(' ', ' ').strip().split(' ')) 45 | body_new[len(body_new) - 1] += " " + sentences[s].replace(' ', ' ').strip() 46 | body_new[len(body_new) - 1] = body_new[len(body_new) - 1].strip() 47 | s += 1 48 | 49 | else: 50 | if par_length >5: 51 | s = 0 52 | while s < len(sentences) and len(sentences[s]) > 10 and (len(body_new[len(body_new)-1].split(' ')) + len(sentences[s].split(' '))) < 400: 53 | if len(sentences[s]) > 10: 54 | body_new[len(body_new) - 1] += " " + sentences[s].replace(' ', ' ').strip() 55 | body_new[len(body_new) - 1] = body_new[len(body_new) - 1].strip() 56 | s+=1 57 | else: 58 | if len(temp_body) > 10 and len(temp_body.split(' ')) <= 400: 59 | body_new.append(temp_body.replace(' ', ' ').replace(' ', ' ').strip()) 60 | 61 | elif len(temp_body.split(' ')) >400: 62 | curr_len = 0 63 | newstr = '' 64 | for sent in sentences: 65 | if len(newstr.split(' ')) + len(sent.split(' ')) <= 400: 66 | newstr += (' '+ sent).strip() 67 | else: 68 | break 69 | body_new.append(newstr.replace(' ', ' ').replace(' ', ' ').strip()) 70 | 71 | par_length+=1 72 | 73 | return body_new 74 | 75 | def clean_top_features(keywords, top=10): 76 | keywords = sorting(keywords) 77 | newkeys = [] 78 | newkeys.append(keywords[len(keywords)-1]) 79 | for i in range(len(keywords)-2,-1,-1): 80 | if newkeys[len(newkeys)-1].startswith(keywords[i]): 81 | continue 82 | newkeys.append(keywords[i]) 83 | 84 | if len(newkeys) > top: 85 | return newkeys[:top] 86 | return newkeys 87 | 88 | def convert_keys_to_str(key_list): 89 | newstr = key_list[0] 90 | for k in range(1, len(key_list)): 91 | if len(key_list[k].split(' ')) > 2 : 92 | newstr += '[SEP]' + key_list[k] 93 | return newstr.replace("(M)", "").strip() 94 | 95 | r = Rake() 96 | vectorizer = TfidfVectorizer(ngram_range=(1,3)) 97 | topK = 10 98 | 99 | infile = 'plot' 100 | infile_title = 'title' 101 | outfile = 'wikiplot.kwRAKE.csv' 102 | 103 | f = open(infile, 'r', encoding='ISO-8859-1') 104 | f_title = open(infile_title, 'r', encoding='ISO-8859-1') 105 | fout = open(outfile, 'a', encoding='ISO-8859-1') 106 | 107 | lines = f.readlines() 108 | lines_title = f_title.readlines() 109 | 110 | abstract_lens = {} 111 | 112 | sentences_to_write = [] 113 | w = 0 114 | total = 0 115 | sentences_to_write.append("[ID]\t[KEY/ABSTRACT]\t[KEYWORDS]\t[DISCOURSE (T/I/B/C)]\t[NUM_PARAGRAPHS]\t[PARAGRAPH]\t[PREVIOUS_PARAGRAPH]\n") 116 | 117 | title_id = 0 118 | for l in range(len(lines)): 119 | if lines[l].strip().startswith("<EOS>"): 120 | continue 121 | title = lines_title[title_id].strip() 122 | title_id+=1
123 | document = lines[l].replace('t outline . ', '').replace('<p>', ' ').replace(' ', ' ').strip().replace(' ', '\n').split('\n') 124 | body = lines[l].replace('t outline . ', '').strip() 125 | 126 | try: 127 | r = Rake() 128 | r.extract_keywords_from_sentences(document) 129 | top_features = r.get_ranked_phrases() 130 | top_features = clean_top_features(top_features, topK) 131 | except Exception: 132 | print(document) 133 | continue 134 | 135 | keywordsSTR = convert_keys_to_str(top_features) 136 | 137 | if len(title) > 2: 138 | title = title.lower().replace("paid notice :", "").replace("paid notice:", "").replace("journal;", "").strip() 139 | keywordsSTR = title + '[SEP]' + keywordsSTR 140 | if len(keywordsSTR.split(' ')) > 100: 141 | keywordsSTR = ' '.join(keywordsSTR.split(' ')[0:100]).strip() 142 | 143 | body_new = trim_body(body) 144 | 145 | if body_new is None or len(body_new) < 1 or len((' '.join(body_new)).split(' '))<15: 146 | continue 147 | 148 | id = 'plot-' + str(title_id) 149 | 150 | total+=1 151 | new_sentence = id + '_0\tK\t' + keywordsSTR + '\tI\t' + str(len(body_new)) + "\t" + body_new[0] + "\tNA" 152 | sentences_to_write.append(new_sentence + '\n') 153 | 154 | for d in range(1, len(body_new)-1): 155 | new_sentence = id + '_' + str(d) + '\tK\t' + keywordsSTR + '\tB\t' + str(len(body_new)) + "\t" + body_new[d] + "\t" + body_new[d-1] 156 | sentences_to_write.append(new_sentence + '\n') 157 | 158 | if len(body_new) > 1: 159 | new_sentence = id + '_' + str(len(body_new)-1) + '\tK\t' + keywordsSTR + '\tC\t' + str(len(body_new)) + "\t" + body_new[len(body_new)-1] + "\t" + body_new[len(body_new)-2] 160 | sentences_to_write.append(new_sentence + '\n') 161 | 162 | fout.writelines(sentences_to_write) 163 | print("Total=" + str(total)) 164 | --------------------------------------------------------------------------------