├── .gitignore
├── README.md
├── create_small_set.py
├── data_loader.py
├── evaluate.py
├── generate.py
├── logger.py
├── loss.py
├── model_pytorch.py
├── opt.py
├── parallel.py
├── parameters_names.json
├── scripts
│   ├── encode_cnndm.py
│   ├── encode_newsroom.py
│   └── encode_xsum.py
├── text_utils.py
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
1 | model
2 | *__pycache__
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Code for the paper "Efficient Adaption of Pretrained Transformers for Abstractive Summarization"
2 |
3 | ## Requirements
4 |
5 | To run the training script in [train.py](train.py) you will additionally need:
6 | - PyTorch (version >=0.4)
7 | - tqdm
8 | - pyrouge
9 | - [newsroom](https://github.com/clic-lab/newsroom)
10 | - tensorflow (cpu version is ok)
11 | - nltk
12 | - spacy (and 'en' model)
13 |
14 | You can download the weights of the OpenAI pre-trained version by cloning [Alec Radford's repo](https://github.com/openai/finetune-transformer-lm) and placing the `model` folder containing the pre-trained weights in the present repo.
15 |
16 |
17 | In order to run this code, you will need to pre-process the datasets with BPE using the scripts provided in [scripts](scripts).
18 | ## Dataset Preprocessing
19 | The training and evaluation scripts expect three output files in total: `train_encoded.jsonl`, `val_encoded.jsonl`, and `test_encoded.jsonl`.
20 |
21 | ### CNN/Daily Mail
22 | The data and splits used in the paper can be downloaded from [OpenNMT](http://opennmt.net/OpenNMT-py/Summarization.html).
23 | First, remove the start and end sentence tags using the sed command in the link provided.
24 | To process the data, run the following command:
25 | ```
26 | python scripts/encode_cnndm.py --src_file {source file} --tgt_file {target file} --out_file {output file}
27 | ```
28 |
29 | ### XSum
30 | The data and splits used in the paper can be scraped using [XSum](https://github.com/EdinburghNLP/XSum/tree/master/XSum-Dataset).
31 | Run the commands up through the `Extract text from HTML Files` section.
32 | To process the data, run the following command:
33 | ```
34 | python scripts/encode_xsum.py --summary_dir {summary directory} --splits_file {split file} --train_file {train file} --val_file {val file} --test_file {test_file}
35 | ```
36 |
37 | ### Newsroom
38 | The data and splits used in the paper can be downloaded from [Newsroom](https://summari.es/download/).
39 | To process the data, run the following command:
40 | ```
41 | python scripts/encode_newsroom.py --in_file {input split file} --out_file {output file}
42 | ```
43 |
44 | ## Training
45 | To train a model, run the following command:
46 | ```
47 | python train.py \
48 | --data_dir {directory containing encoded data} \
49 | --output_dir {name of folder to save data in} \
50 | --experiment_name {name of experiment to save data with} \
51 | --show_progress \
52 | --doc_model \
53 | --num_epochs_dat 10 \
54 | --num_epochs_ft 10 \
55 | --n_batch 16 \
56 | --accum_iter 4 \
57 | --use_pretrain
58 | ```
59 | to train the pre-trained document embedding model over the dataset for 10 epochs of domain-adaptive training followed by 10 epochs of fine-tuning. The model will
60 | be trained with an effective batch size of 64, since the actual batch size is 16 and gradients are accumulated over 4 batches.
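The gradient accumulation itself happens inside [train.py](train.py), which is not reproduced in this listing. As a rough, illustrative sketch of the idea (the names below are placeholders, not the script's actual code):

```python
import torch
import torch.nn as nn

# Toy sketch of gradient accumulation: each optimizer step effectively sees
# n_batch * accum_iter examples (16 * 4 = 64 here).
n_batch, accum_iter = 16, 4
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

optimizer.zero_grad()
for step in range(accum_iter):
    x = torch.randn(n_batch, 10)                 # stands in for one mini-batch from the loader
    loss = model(x).pow(2).mean() / accum_iter   # scale so the summed gradients form an average
    loss.backward()                              # gradients accumulate across backward passes
optimizer.step()                                 # a single update for the effective batch of 64
```

Dividing the loss by `accum_iter` keeps the accumulated update equivalent to one step over a single large batch.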
61 | Batch size must be divisible by the number of GPUs available. Training is currently optimized for multi-GPU usage and may not work on single-GPU machines.
62 |
63 | ## Evaluation
64 | To evaluate a model, run the following command:
65 | ```
66 | python evaluate.py \
67 | --data_file {path to encoded data file} \
68 | --checkpoint {checkpoint to load model weights from} \
69 | --beam {beam size to do beam search with} \
70 | --doc_model \
71 | --save_file {file to output results to} \
72 | --n_batch {batch size for evaluation, must be divisible by number of GPUs}
73 | ```
74 | to evaluate the document embedding model on the test set. Evaluation is currently optimized for multi-GPU usage and may not work on single-GPU machines.
75 | Since the evaluation script will leave out some examples if the number of data points isn't divisible by the number of GPUs, you may need to run the
76 | `create_small_set.py` script to recover the last few examples that would otherwise be left out, and aggregate the results at the end.
77 |
--------------------------------------------------------------------------------
/create_small_set.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from text_utils import TextEncoder
3 | from newsroom import jsonl
4 |
5 | def main(args):
6 |     text_encoder = TextEncoder(args.encoder_path, args.bpe_path)
7 |     with jsonl.open(args.original_file, gzip=True) as test_file:
8 |         data = test_file.read()
9 |
10 |     with jsonl.open(args.out_file, gzip=True) as out_file:
11 |         out_file.write(data[-args.n:])
12 |
13 | if __name__ == "__main__":
14 |     parser = argparse.ArgumentParser("Creates a data file using the last n examples from the original file. Used to cover the fact that the model drops the last few results")
15 |     parser.add_argument("--original_file", type=str, required=True)
16 |     parser.add_argument("--n", type=int, required=True)
17 |     parser.add_argument("--out_file", type=str, required=True)
18 |     parser.add_argument('--encoder_path', type=str, default='model/encoder_bpe_40000.json')
19 |     parser.add_argument('--bpe_path', type=str, default='model/vocab_40000.bpe')
20 |     args = parser.parse_args()
21 |     main(args)
22 |
--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 | from torch.utils.data import Dataset, DataLoader
5 | from newsroom import jsonl
6 |
7 | class CustomDataset(Dataset):
8 |     def __init__(self, data_file, encoder, max_size=None, subset=None):
9 |         with jsonl.open(data_file, gzip=True) as f:
10 |             self.data = f.read()
11 |         if subset is not None:
12 |             self.data = [x for x in self.data if x["density_bin"] == subset]
13 |         random.shuffle(self.data)
14 |         if max_size is not None:
15 |             self.data = self.data[:max_size]
16 |         self.encoder = encoder
17 |
18 |     def __getitem__(self, index):
19 |         json_data = self.data[index]
20 |         src_phrase = json_data["text"][:512]
21 |         tgt_phrase = json_data["summary"][:110]
22 |         start = torch.LongTensor([self.encoder['_start_']])
23 |         delim = torch.LongTensor([self.encoder['_delimiter_']])
24 |         end = torch.LongTensor([self.encoder['_classify_']])
25 |
26 |         pad_output = torch.zeros(512 + 110 + 3, 2).long()
27 |         mask_output = torch.zeros(512 + 110 + 3).long()
28 |
29 |         # Tokens
30 |         pad_output[0, 0] = start
31 |         pad_output[1:len(src_phrase)+1, 0] = torch.LongTensor(src_phrase)
32 |         pad_output[1+512, 0] =
delim 33 | pad_output[1+512+1:1+512+1+len(tgt_phrase), 0] = torch.LongTensor(tgt_phrase) 34 | pad_output[1+512+1+len(tgt_phrase), 0] = end 35 | 36 | # Positional Embedding 37 | pad_output[1:len(src_phrase)+1, 1] = torch.LongTensor(np.arange(len(self.encoder), len(self.encoder) + len(src_phrase))) 38 | pad_output[1+512:1+512+1+len(tgt_phrase), 1] = torch.LongTensor(np.arange(len(self.encoder), len(self.encoder) + len(tgt_phrase) + 1)) 39 | 40 | # Mask 41 | mask_output[:1+len(src_phrase)] = torch.ones(1 + len(src_phrase)).long() 42 | mask_output[1+512+1:1+512+1+len(tgt_phrase)+1] = torch.ones(len(tgt_phrase) + 1).long() 43 | return pad_output, mask_output 44 | 45 | def __len__(self): 46 | return len(self.data) 47 | 48 | def get_loader(data_file, batch_size, encoder, shuffle=True, num_workers=0, max_size=None, subset=None): 49 | dataset = CustomDataset(data_file, encoder, max_size=max_size, subset=subset) 50 | return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, drop_last=True) 51 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import glob 4 | import json 5 | import os 6 | import random 7 | import re 8 | 9 | from nltk.tokenize import sent_tokenize 10 | import numpy as np 11 | from pyrouge import Rouge155 12 | import torch 13 | import torch.nn as nn 14 | 15 | from opt import OpenAIAdam 16 | from text_utils import TextEncoder 17 | from data_loader import get_loader 18 | from tqdm import tqdm 19 | from generate import generate_outputs 20 | from model_pytorch import LMModel, load_openai_pretrained_model 21 | from parallel import DataParallelModel 22 | 23 | def clear_dirs(gen_dir, tgt_dir): 24 | for f in glob.glob("{}/*".format(tgt_dir)): 25 | os.remove(f) 26 | for f in glob.glob("{}/*".format(gen_dir)): 27 | os.remove(f) 28 | os.makedirs(tgt_dir, exist_ok=True) 29 | os.makedirs(gen_dir, exist_ok=True) 30 | 31 | def format_text(text, max_len, stop_words=[]): 32 | text = "\n".join(sent_tokenize(text)).replace("<", "<").replace(">", ">") 33 | for stop_word in stop_words: 34 | text = text.replace(" {} ".format(stop_word), " ") 35 | if max_len is not None: 36 | text = " ".join(text.split(" ")[:max_len]) 37 | return text 38 | 39 | def evaluate_model(model, val_loader, text_encoder, device, beam, gen_len, k, decoding_strategy, save_file, gen_dir="gen", tgt_dir="tgt", max_len=110, stop_words=[], args=None): 40 | data = {"src": [], "gen": [], "tgt": []} 41 | srcs, hyps, refs = [], [], [] 42 | 43 | model.eval() 44 | for pad_seq, mask_seq in tqdm(val_loader): 45 | with torch.no_grad(): 46 | # Generating outputs for evaluation 47 | src_strs, tgt_strs, gen_strs = generate_outputs(model, pad_seq, mask_seq, text_encoder, device, beam, gen_len, k, decoding_strategy, min_len=args.min_len) 48 | data["src"].extend(src_strs) 49 | data["gen"].extend(gen_strs) 50 | data["tgt"].extend(tgt_strs) 51 | 52 | for i in range(len(data["src"])): 53 | with open(os.path.join(gen_dir, "gen.{}.txt".format(i)), "w") as gen_file: 54 | gen_file.write(format_text(data["gen"][i], max_len, stop_words)) 55 | with open(os.path.join(tgt_dir, "tgt.{}.txt".format(i)), "w") as tgt_file: 56 | tgt_file.write(format_text(data["tgt"][i], max_len, stop_words)) 57 | 58 | with open(save_file, "w") as f: 59 | json.dump( 60 | get_rouge_scores(gen_dir, tgt_dir), 61 | f, 62 | indent=4, 63 | sort_keys=True 64 | ) 65 | 66 | def 
get_rouge_scores(gen_dir, tgt_dir, gen_pattern='gen.(\d+).txt', tgt_pattern='tgt.#ID#.txt'): 67 | r = Rouge155() 68 | r.system_dir = gen_dir 69 | r.model_dir = tgt_dir 70 | r.system_filename_pattern = gen_pattern 71 | r.model_filename_pattern = tgt_pattern 72 | output = r.convert_and_evaluate() 73 | return r.output_to_dict(output) 74 | 75 | def main(args): 76 | # Constants 77 | n_ctx = args.n_ctx 78 | desc = args.desc 79 | 80 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 81 | n_gpu = torch.cuda.device_count() 82 | print("device", device, "n_gpu", n_gpu) 83 | 84 | text_encoder = TextEncoder(args.encoder_path, args.bpe_path) 85 | encoder = text_encoder.encoder 86 | n_vocab = len(text_encoder.encoder) 87 | 88 | 89 | encoder['_start_'] = len(encoder) 90 | encoder['_delimiter_'] = len(encoder) 91 | encoder['_classify_'] = len(encoder) 92 | clf_token = encoder['_classify_'] 93 | n_special = 3 94 | 95 | print("Loading dataset...") 96 | test_loader = get_loader(args.data_file, args.n_batch, encoder, num_workers=1, shuffle=False, subset=args.subset) 97 | 98 | vocab = n_vocab + n_special + n_ctx 99 | dh_model = LMModel(args, vocab=vocab, n_ctx=n_ctx, doc_embed=args.doc_model) 100 | 101 | print("Loading model...") 102 | load_openai_pretrained_model(dh_model.transformer, n_ctx=n_ctx, n_special=n_special, path="./model/", path_names="./") 103 | if args.checkpoint != "none": 104 | checkpoint = torch.load(args.checkpoint, map_location='cpu') 105 | state_dict = checkpoint["state_dict"] 106 | for key in list(state_dict.keys()): 107 | state_dict[key[7:]] = state_dict[key] 108 | del state_dict[key] 109 | pos_emb_mask = torch.zeros(1, 1, vocab) 110 | pos_emb_mask[:, :, -n_ctx] = -1e12 111 | state_dict['pos_emb_mask'] = pos_emb_mask 112 | dh_model.load_state_dict(state_dict) 113 | 114 | dh_model.to(device) 115 | dh_model = DataParallelModel(dh_model) 116 | 117 | stop_words = [] 118 | if args.stop_words is not None: 119 | with open(args.stop_words) as f: 120 | for line in f: 121 | stop_words.append(line) 122 | evaluate_model(dh_model, test_loader, text_encoder, device, args.beam, args.gen_len, args.k, args.decoding_strategy, args.save_file, args.gen_dir, args.tgt_dir, args.max_len, stop_words, args) 123 | 124 | 125 | if __name__ == '__main__': 126 | parser = argparse.ArgumentParser() 127 | parser.add_argument('--desc', type=str, help="Description") 128 | parser.add_argument('--seed', type=int, default=42) 129 | parser.add_argument('--n_ctx', type=int, default=512) 130 | parser.add_argument('--n_embd', type=int, default=768) 131 | parser.add_argument('--n_head', type=int, default=12) 132 | parser.add_argument('--n_layer', type=int, default=12) 133 | parser.add_argument('--n_batch', type=int, default=1) 134 | parser.add_argument('--embd_pdrop', type=float, default=0.1) 135 | parser.add_argument('--attn_pdrop', type=float, default=0.1) 136 | parser.add_argument('--resid_pdrop', type=float, default=0.1) 137 | parser.add_argument('--clf_pdrop', type=float, default=0.1) 138 | parser.add_argument('--afn', type=str, default='gelu') 139 | parser.add_argument('--encoder_path', type=str, default='model/encoder_bpe_40000.json') 140 | parser.add_argument('--bpe_path', type=str, default='model/vocab_40000.bpe') 141 | parser.add_argument('--checkpoint', type=str, default="none") 142 | parser.add_argument('--data_file', type=str, required=True) 143 | parser.add_argument('--beam', type=int, default=0) 144 | parser.add_argument('--gen_len', type=int, default=110) 145 | parser.add_argument('--k', 
type=int, default=10)
146 |     parser.add_argument('--min_len', type=int, default=None)
147 |     parser.add_argument('--decoding_strategy', type=int, default=0)
148 |     parser.add_argument('--save_file', type=str, required=True)
149 |     parser.add_argument('--doc_model', action='store_true')
150 |     parser.add_argument('--gen_dir', type=str, default="gen")
151 |     parser.add_argument('--tgt_dir', type=str, default="tgt")
152 |     parser.add_argument('--max_len', type=int, default=110)
153 |     parser.add_argument('--stop_words', type=str, default=None)
154 |     parser.add_argument('--subset', type=str, default=None)
155 |
156 |     args = parser.parse_args()
157 |     print(args)
158 |
159 |
160 |     random.seed(args.seed)
161 |     np.random.seed(args.seed)
162 |     torch.manual_seed(args.seed)
163 |     torch.cuda.manual_seed_all(args.seed)
164 |     clear_dirs(args.gen_dir, args.tgt_dir)
165 |     main(args)
166 |
--------------------------------------------------------------------------------
/generate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 |
5 | import numpy as np
6 | import torch
7 | from tqdm import tqdm
8 |
9 |
10 | from data_loader import get_loader
11 | from model_pytorch import LMModel, load_openai_pretrained_model
12 | from parallel import DataParallelModel
13 | from text_utils import TextEncoder
14 |
15 | def generate_outputs(model, pad_output, mask_output, text_encoder, device, beam, gen_len, k, decoding_strategy, min_len=None):
16 |     src_strs, tgt_strs, gen_strs = [], [], []
17 |     mask = mask_output
18 |     outputs = model(pad_output, mask_output, text_encoder, device, beam=beam, gen_len=gen_len, k=k, decoding_strategy=decoding_strategy, generate=True, min_len=min_len)
19 |     for generated_toks, input_toks, target_toks in outputs:
20 |         for idx in range(generated_toks.size(0)):
21 |             src_str = toks_to_str(input_toks[idx], text_encoder, is_input=True, mask=mask[idx])
22 |             src_strs.append(src_str)
23 |             tgt_str = toks_to_str(target_toks[idx], text_encoder)
24 |             tgt_strs.append(tgt_str)
25 |             gen_str = toks_to_str(generated_toks[idx], text_encoder)
26 |             gen_strs.append(gen_str)
27 |     return src_strs, tgt_strs, gen_strs
28 |
29 | def toks_to_str(toks, text_encoder, is_input=False, mask=None):
30 |     str_rep = ''
31 |     end_tok = text_encoder.encoder['_delimiter_'] if is_input else text_encoder.encoder['_classify_']
32 |     for token in toks:
33 |         if token.item() == end_tok:
34 |             break
35 |         elif token.item() in text_encoder.decoder:
36 |             str_rep += text_encoder.decoder[token.item()].replace('</w>', ' ').replace('\n', '')  # '</w>' is the BPE end-of-word marker
37 |         else:
38 |             str_rep += 'unk '
39 |     # This makes sure the rouge scorer doesn't complain about no sentences
40 |     if not str_rep:
41 |         str_rep = "unk."
42 |     elif "." not in str_rep:
43 |         str_rep += "."
44 | return str_rep 45 | 46 | def init(args): 47 | random.seed(args.seed) 48 | np.random.seed(args.seed) 49 | torch.manual_seed(args.seed) 50 | torch.cuda.manual_seed_all(args.seed) 51 | 52 | def main(args): 53 | init(args) 54 | 55 | # Constants 56 | n_ctx = args.n_ctx 57 | data_dir = args.data_dir 58 | 59 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 60 | n_gpu = torch.cuda.device_count() 61 | print("device", device, "n_gpu", n_gpu) 62 | 63 | text_encoder = TextEncoder(args.encoder_path, args.bpe_path) 64 | encoder = text_encoder.encoder 65 | n_vocab = len(text_encoder.encoder) 66 | text_encoder.decoder[len(encoder)] = '_start_' 67 | encoder['_start_'] = len(encoder) 68 | text_encoder.decoder[len(encoder)] = '_delimiter_' 69 | encoder['_delimiter_'] = len(encoder) 70 | text_encoder.decoder[len(encoder)] = '_classify_' 71 | encoder['_classify_'] = len(encoder) 72 | 73 | n_special = 3 # XD: useless for language modeling task 74 | vocab = n_vocab + n_special + n_ctx 75 | 76 | lm_model = LMModel(args, vocab, n_ctx, return_probs=True, doc_embed=args.doc_model) 77 | load_openai_pretrained_model(lm_model.transformer, n_ctx=n_ctx, n_special=n_special) 78 | if args.checkpoint != "none": 79 | checkpoint = torch.load(args.checkpoint, map_location='cpu') 80 | state_dict = checkpoint["state_dict"] 81 | for key in list(state_dict.keys()): 82 | state_dict[key[7:]] = state_dict[key] 83 | del state_dict[key] 84 | pos_emb_mask = torch.zeros(1, 1, vocab) 85 | pos_emb_mask[:, :, -n_ctx] = -1e12 86 | state_dict['pos_emb_mask'] = pos_emb_mask 87 | lm_model.load_state_dict(state_dict) 88 | lm_model.to(device) 89 | lm_model = DataParallelModel(lm_model) 90 | 91 | train_bar = get_loader(os.path.join(data_dir, "val_encoded.jsonl"), n_gpu, encoder, num_workers=1, shuffle=True, max_size=args.n_iter) 92 | srcs, hyps, refs = [], [], [] 93 | with torch.no_grad(): 94 | lm_model.eval() 95 | for i, (pad_output, mask_output) in enumerate(tqdm(train_bar), 1): 96 | src_strs, tgt_strs, gen_strs = generate_outputs(lm_model, pad_output, mask_output, text_encoder, device, args.beam, args.gen_len, args.k, args.decoding_strategy) 97 | srcs.extend(src_strs) 98 | hyps.extend(gen_strs) 99 | refs.extend(tgt_strs) 100 | 101 | for i in range(len(hyps)): 102 | print("*" * 50) 103 | print("Source: {}".format(srcs[i])) 104 | print('Hypothesis: {}'.format(hyps[i])) 105 | print("Reference: {}".format(refs[i])) 106 | 107 | if __name__ == '__main__': 108 | parser = argparse.ArgumentParser() 109 | 110 | # Standard 111 | parser.add_argument('--data_dir', type=str, default='data/') 112 | parser.add_argument('--seed', type=int, default=42) 113 | parser.add_argument('--n_iter', type=int, default=1) 114 | parser.add_argument('--n_ctx', type=int, default=512) 115 | parser.add_argument('--n_embd', type=int, default=768) 116 | parser.add_argument('--n_head', type=int, default=12) 117 | parser.add_argument('--n_layer', type=int, default=12) 118 | parser.add_argument('--embd_pdrop', type=float, default=0.1) 119 | parser.add_argument('--attn_pdrop', type=float, default=0.1) 120 | parser.add_argument('--resid_pdrop', type=float, default=0.1) 121 | parser.add_argument('--clf_pdrop', type=float, default=0.1) 122 | parser.add_argument('--afn', type=str, default='gelu') 123 | parser.add_argument('--encoder_path', type=str, default='src/model/encoder_bpe_40000.json') 124 | parser.add_argument('--bpe_path', type=str, default='src/model/vocab_40000.bpe') 125 | parser.add_argument('--checkpoint', type=str, default="none") 126 | # Custom 
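# The decoding flags below map onto LMModel.generate in model_pytorch.py:
#   --beam 0          sample token by token; with --k > 0 and --decoding_strategy 0 this is
#                     top-k sampling with repeated-trigram blocking
#   --beam B (B > 0)  beam search of width B, also with trigram blocking
#   --min_len         discard finished beams shorter than this many words (beam search only)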
127 | parser.add_argument('--gen_len', type=int, default=110, 128 | help='Length of the generation') 129 | parser.add_argument('--k', type=int, default=10, 130 | help='How many tokens to sample for various decoding strategies') 131 | parser.add_argument('--inits', type=str, default='init.txt', 132 | help='Text file containing prefixes to continue') 133 | parser.add_argument('--decoding_strategy', type=int, default=0, 134 | help='Which decoding strategy to use, described in the comments') 135 | parser.add_argument('--beam', type=int, default=0, 136 | help='If this is 0, decoding_strategy will be used, if this is greater than 0 beam search will be used with the specified beam size') 137 | parser.add_argument('--doc_model', action='store_true', 138 | help='Set to use the document embedding model') 139 | parser.add_argument('--min_len', type=int, default=None, 140 | help='Set to use the document embedding model') 141 | 142 | args = parser.parse_args() 143 | print(args) 144 | main(args) 145 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | # Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514 2 | import tensorflow as tf 3 | 4 | class Logger(): 5 | 6 | def __init__(self, log_dir): 7 | """Create a summary writer logging to log_dir.""" 8 | self.writer = tf.summary.FileWriter(log_dir) 9 | 10 | def scalar_summary(self, tag, value, step): 11 | """Log a scalar variable.""" 12 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]) 13 | self.writer.add_summary(summary, step) 14 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class LMLoss(nn.Module): 5 | 6 | def __init__(self, lm_criterion, opt=None): 7 | super(LMLoss, self).__init__() 8 | self.lm_criterion = lm_criterion 9 | self.opt = opt 10 | 11 | def forward(self, lm_logits, X, mask): 12 | x_shifted = X[:, 1:, 0].contiguous().view(-1) 13 | mask = mask[:, 1:].view(-1, mask.size(-1) - 1).float() 14 | lm_logits = lm_logits[:, :-1, :].contiguous().view(-1, lm_logits.size(-1)) 15 | lm_losses = self.lm_criterion(lm_logits, x_shifted) 16 | lm_losses = lm_losses.view(X.size(0), X.size(1) - 1) 17 | lm_losses = lm_losses * mask 18 | lm_losses = lm_losses.sum(1) / torch.sum(mask, 1) 19 | return lm_losses 20 | 21 | class SummaryLoss(nn.Module): 22 | 23 | def __init__(self, lm_criterion, opt=None): 24 | super(SummaryLoss, self).__init__() 25 | self.lm_criterion = lm_criterion 26 | self.opt = opt 27 | 28 | def forward(self, lm_logits, X, mask): 29 | x_shifted = X[:, 1+512+1:, 0].contiguous().view(-1) 30 | mask = mask[:, 1+512+1:].view(-1, mask.size(-1) - 514).float() 31 | lm_logits = lm_logits[:, 1+512:-1, :].contiguous().view(-1, lm_logits.size(-1)) 32 | lm_losses = self.lm_criterion(lm_logits, x_shifted) 33 | lm_losses = lm_losses.view(X.size(0), -1) 34 | lm_losses = lm_losses * mask 35 | lm_losses = lm_losses.sum(1) / torch.sum(mask, 1) 36 | return lm_losses 37 | -------------------------------------------------------------------------------- /model_pytorch.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import math 4 | import re 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 
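# Input layout assumed throughout this file (built in data_loader.CustomDataset):
#   x has shape (batch, 1 + 512 + 110 + 2, 2). x[..., 0] holds BPE token ids laid out as
#   _start_, article (padded to 512), _delimiter_, summary (padded to 110), _classify_;
#   x[..., 1] holds position ids, which index the n_ctx extra rows appended after the word
#   and special-token embeddings, so summing embeddings over that axis adds token + position.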
10 | from torch.nn.parameter import Parameter 11 | 12 | 13 | def gelu(x): 14 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 15 | 16 | 17 | def swish(x): 18 | return x * torch.sigmoid(x) 19 | 20 | 21 | ACT_FNS = { 22 | 'relu': nn.ReLU, 23 | 'swish': swish, 24 | 'gelu': gelu 25 | } 26 | 27 | 28 | class LayerNorm(nn.Module): 29 | "Construct a layernorm module in the OpenAI style (epsilon inside the square root)." 30 | 31 | def __init__(self, n_state, e=1e-5): 32 | super(LayerNorm, self).__init__() 33 | self.g = nn.Parameter(torch.ones(n_state)) 34 | self.b = nn.Parameter(torch.zeros(n_state)) 35 | self.e = e 36 | 37 | def forward(self, x): 38 | u = x.mean(-1, keepdim=True) 39 | s = (x - u).pow(2).mean(-1, keepdim=True) 40 | x = (x - u) / torch.sqrt(s + self.e) 41 | return self.g * x + self.b 42 | 43 | 44 | class Conv1D(nn.Module): 45 | def __init__(self, nf, rf, nx): 46 | super(Conv1D, self).__init__() 47 | self.rf = rf 48 | self.nf = nf 49 | if rf == 1: # faster 1x1 conv 50 | w = torch.empty(nx, nf) 51 | nn.init.normal_(w, std=0.02) 52 | self.w = Parameter(w) 53 | self.b = Parameter(torch.zeros(nf)) 54 | else: # was used to train LM 55 | raise NotImplementedError 56 | 57 | def forward(self, x): 58 | if self.rf == 1: 59 | size_out = x.size()[:-1] + (self.nf,) 60 | x = torch.addmm(self.b, x.view(-1, x.size(-1)), self.w) 61 | x = x.view(*size_out) 62 | else: 63 | raise NotImplementedError 64 | return x 65 | 66 | 67 | class Attention(nn.Module): 68 | def __init__(self, nx, n_ctx, cfg, scale=False): 69 | super(Attention, self).__init__() 70 | n_state = nx # in Attention: n_state=768 (nx=n_embd) 71 | # [switch nx => n_state from Block to Attention to keep identical to TF implem] 72 | assert n_state % cfg.n_head == 0 73 | self.register_buffer('b', torch.tril(torch.ones(625, 625)).view(1, 1, 625, 625)) 74 | self.n_head = cfg.n_head 75 | self.split_size = n_state 76 | self.scale = scale 77 | self.c_attn = Conv1D(n_state * 3, 1, nx) 78 | self.c_proj = Conv1D(n_state, 1, nx) 79 | self.attn_dropout = nn.Dropout(cfg.attn_pdrop) 80 | self.resid_dropout = nn.Dropout(cfg.resid_pdrop) 81 | 82 | def _attn(self, q, k, v, mask): 83 | w = torch.matmul(q, k) 84 | if self.scale: 85 | w = w / math.sqrt(v.size(-1)) 86 | # w = w * self.b + -1e9 * (1 - self.b) # TF implem method: mask_attn_weights 87 | # XD: self.b may be larger than w, so we need to crop it 88 | b = self.b * mask.unsqueeze(1).unsqueeze(1).type_as(self.b) 89 | b = b[:, :, :w.size(-2), :w.size(-1)] 90 | w = w * b + -1e9 * (1 - b) 91 | 92 | w = nn.Softmax(dim=-1)(w) 93 | w = self.attn_dropout(w) 94 | return torch.matmul(w, v) 95 | 96 | def merge_heads(self, x): 97 | x = x.permute(0, 2, 1, 3).contiguous() 98 | new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),) 99 | return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states 100 | 101 | def split_heads(self, x, k=False): 102 | new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head) 103 | x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states 104 | if k: 105 | return x.permute(0, 2, 3, 1) 106 | else: 107 | return x.permute(0, 2, 1, 3) 108 | 109 | def forward(self, x, mask): 110 | x = self.c_attn(x) 111 | query, key, value = x.split(self.split_size, dim=2) 112 | query = self.split_heads(query) 113 | key = self.split_heads(key, k=True) 114 | value = self.split_heads(value) 115 | a = self._attn(query, key, value, mask) 116 | a = self.merge_heads(a) 117 | a = self.c_proj(a) 118 | a = self.resid_dropout(a) 119 | return a 
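# Masking recap for Attention._attn above: self.b is a fixed 625x625 lower-triangular (causal)
# pattern and `mask` is the per-example padding mask from the data loader; their product keeps
# each position attending only to earlier, non-padding positions, and w * b + -1e9 * (1 - b)
# pushes the masked attention logits to (effectively) -inf before the softmax.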
120 | 121 | 122 | class MLP(nn.Module): 123 | def __init__(self, n_state, cfg): # in MLP: n_state=3072 (4 * n_embd) 124 | super(MLP, self).__init__() 125 | nx = cfg.n_embd 126 | self.c_fc = Conv1D(n_state, 1, nx) 127 | self.c_proj = Conv1D(nx, 1, n_state) 128 | self.act = ACT_FNS[cfg.afn] 129 | self.dropout = nn.Dropout(cfg.resid_pdrop) 130 | 131 | def forward(self, x): 132 | h = self.act(self.c_fc(x)) 133 | h2 = self.c_proj(h) 134 | return self.dropout(h2) 135 | 136 | 137 | class Block(nn.Module): 138 | def __init__(self, n_ctx, cfg, scale=False): 139 | super(Block, self).__init__() 140 | nx = cfg.n_embd 141 | self.attn = Attention(nx, n_ctx, cfg, scale) 142 | self.ln_1 = LayerNorm(nx) 143 | self.mlp = MLP(4 * nx, cfg) 144 | self.ln_2 = LayerNorm(nx) 145 | 146 | def forward(self, x, mask): 147 | a = self.attn(x, mask) 148 | n = self.ln_1(x + a) 149 | m = self.mlp(n) 150 | h = self.ln_2(n + m) 151 | return h 152 | 153 | 154 | class TransformerModel(nn.Module): 155 | """ Transformer model """ 156 | 157 | def __init__(self, cfg, vocab=40990, n_ctx=512, doc_embed=True): 158 | super(TransformerModel, self).__init__() 159 | self.vocab = vocab 160 | self.embed = nn.Embedding(vocab, cfg.n_embd) 161 | self.is_doc_embed = doc_embed 162 | if doc_embed: 163 | self.article_embed = nn.Embedding(2, cfg.n_embd) 164 | self.summary_embed = nn.Embedding(2, cfg.n_embd) 165 | nn.init.normal_(self.article_embed.weight, 0, 0.01) 166 | nn.init.normal_(self.summary_embed.weight, 0, 0.01) 167 | self.drop = nn.Dropout(cfg.embd_pdrop) 168 | block = Block(n_ctx, cfg, scale=True) 169 | self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(cfg.n_layer)]) 170 | 171 | nn.init.normal_(self.embed.weight, std=0.02) 172 | 173 | def forward(self, x, mask): 174 | x = x.view(-1, x.size(-2), x.size(-1)) 175 | e = self.embed(x) 176 | # Add the position information to the input embeddings 177 | h = e.sum(dim=2) 178 | if self.is_doc_embed: 179 | doc_embed = torch.cat((self.article_embed(mask[:, :514]), self.summary_embed(mask[:, 514:])), dim=1) 180 | h += doc_embed[:, :h.size(1), :] 181 | for block in self.h: 182 | h = block(h, mask) 183 | return h 184 | 185 | 186 | class LMHead(nn.Module): 187 | """ Language Model Head for the transformer """ 188 | 189 | def __init__(self, model, cfg, trunc_and_reshape=True): 190 | super(LMHead, self).__init__() 191 | self.n_embd = cfg.n_embd 192 | embed_shape = model.embed.weight.shape 193 | self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False) 194 | self.decoder.weight = model.embed.weight # Tied weights 195 | self.trunc_and_reshape = trunc_and_reshape # XD 196 | 197 | def forward(self, h): 198 | # Truncated Language modeling logits (we remove the last token) 199 | h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd) \ 200 | if self.trunc_and_reshape else h # XD 201 | lm_logits = self.decoder(h_trunc) 202 | return lm_logits 203 | 204 | # XD 205 | class LMModel(nn.Module): 206 | """ Transformer with language model head only """ 207 | def __init__(self, cfg, vocab=40990, n_ctx=512, return_probs=False, doc_embed=True): 208 | super(LMModel, self).__init__() 209 | self.transformer = TransformerModel(cfg, vocab=vocab, n_ctx=n_ctx, doc_embed=doc_embed) 210 | self.lm_head = LMHead(self.transformer, cfg, trunc_and_reshape=False) 211 | self.return_probs = return_probs 212 | 213 | #if self.return_probs: 214 | pos_emb_mask = torch.zeros(1, 1, vocab) 215 | pos_emb_mask[:, :, -n_ctx:] = -1e12 216 | self.register_buffer('pos_emb_mask', pos_emb_mask) 217 | 218 | def forward(self, 
pad_output, mask_output=None, text_encoder=None, device=None, beam=0, gen_len=110, k=0, decoding_strategy=0, log=True, generate=False, min_len=None): 219 | if generate: 220 | return self.generate(pad_output, mask_output, text_encoder, device, beam, gen_len, k, decoding_strategy, min_len=min_len) 221 | return self._forward(pad_output, mask_output, log) 222 | 223 | 224 | def _forward(self, x, mask_output, log=True, return_probs=False): 225 | h = self.transformer(x, mask_output) 226 | lm_logits = self.lm_head(h) 227 | if self.return_probs or return_probs: 228 | if log: 229 | lm_logits = F.log_softmax((lm_logits + self.pos_emb_mask), dim=-1) 230 | else: 231 | lm_logits = F.softmax((lm_logits + self.pos_emb_mask), dim=-1) 232 | return lm_logits 233 | 234 | def append_batch(self, X, next_idx): 235 | next_pos = X[:, -1:, 1] + 1 236 | next_x = torch.cat((next_idx, next_pos), -1).unsqueeze(1) 237 | return torch.cat((X, next_x), 1) 238 | 239 | 240 | def sample(self, pad_output, mask, classify_idx, text_encoder, gen_len=110, k=0, decoding_strategy=0, min_len=None): 241 | XMB = pad_output 242 | seen_trigrams = [{} for _ in range(XMB.size(0))] 243 | for _ in range(gen_len): 244 | lm_probs = self._forward(XMB, mask, return_probs=True, log=False) 245 | dist = lm_probs[:, -1, :].squeeze() 246 | if k == 0: 247 | next_idx = torch.multinomial(lm_probs[:, -1, :], 1) 248 | else: 249 | if decoding_strategy == 0: 250 | # Sample from top k 251 | values, indices = dist.topk(k) 252 | next_idx = indices.gather(-1, torch.multinomial(values, 1)) 253 | if _ == 2: 254 | for i in range(XMB.size(0)): 255 | bigram = (XMB[i, -2, 0].item(), XMB[i, -1, 0].item()) 256 | seen_trigrams[i][bigram] = [next_idx[i].item()] 257 | elif _ > 2: 258 | for i in range(XMB.size(0)): 259 | bigram = (XMB[i, -2, 0].item(), XMB[i, -1, 0].item()) 260 | if bigram in seen_trigrams[i]: 261 | for value in seen_trigrams[i][bigram]: 262 | dist[i, value] = 0 263 | values, indices = dist.topk(k) 264 | next_idx = indices.gather(-1, torch.multinomial(values, 1)) 265 | for i in range(XMB.size(0)): 266 | bigram = (XMB[i, -2, 0].item(), XMB[i, -1, 0].item()) 267 | if bigram not in seen_trigrams[i]: 268 | seen_trigrams[i][bigram] = [] 269 | seen_trigrams[i][bigram].append(next_idx[i].item()) 270 | else: 271 | raise NotImplementedError 272 | XMB = self.append_batch(XMB, next_idx) 273 | return XMB[:, -gen_len:, 0] 274 | 275 | 276 | def beam_search(self, pad_output, mask, classify_idx, text_encoder, beam, gen_len=110, min_len=None): 277 | batch_size = pad_output.size(0) 278 | XMB = pad_output 279 | finished_beams = [[] for _ in range(batch_size)] 280 | 281 | """Initial run""" 282 | dist = self._forward(XMB, mask, log=True, return_probs=True)[:, -1, :] 283 | beam_lls, beam_toks = dist.topk(beam) 284 | beam_probs = beam_lls.view(-1, 1) 285 | beam_toks = beam_toks.view(-1, 1) 286 | XMB = XMB.repeat(1, beam, 1).view(-1, XMB.size(1), XMB.size(2)) 287 | next_x = torch.cat((beam_toks, XMB[:, -1:, 1] + 1), -1).unsqueeze(1) 288 | XMB = torch.cat((XMB, next_x), 1) 289 | mask = mask.repeat(1, beam).view(-1, mask.size(1)) 290 | 291 | finished_mask = beam_toks.eq(classify_idx) 292 | for i in range(finished_mask.size(0)): 293 | if finished_mask[i].item() == 1: 294 | finished_beams[i // beam].append((XMB[i, 1+512+1:, 0].cpu(), beam_probs[i].item() / XMB.size(1) - 513)) # 513 to include classify tok in error and avoid division by 0 295 | beam_probs[i] = -1e8 296 | 297 | for _ in range(gen_len - 1): 298 | top_values, top_beams = beam_probs.view(batch_size, -1).topk(beam) 299 | 
top_beams = (top_beams + torch.tensor([[i * beam_probs.size(0) / batch_size for j in range(beam)] for i in range(batch_size)]).type_as(top_beams)).view(-1) 300 | index = torch.cat([torch.ones(1, mask.size(1), XMB.size(2)).long() * torch.tensor([top_beams[i]]).long() for i in range(top_beams.size(0))], dim=0).type_as(XMB) 301 | XMB = torch.gather(XMB, 0, index[:, :XMB.size(1), :]) 302 | mask = torch.gather(mask, 0, index[:, :, 0]) 303 | beam_probs = torch.gather(beam_probs, 0, index[:, 0, 0].unsqueeze(1)) 304 | if _ > 2: 305 | seen_hashes = torch.gather(seen_hashes, 0, index[:, :seen_hashes.size(1), 0]) 306 | 307 | lm_probs = self._forward(XMB, mask, log=True, return_probs=True)[:, -1, :] 308 | if _ == 2: 309 | trigram = XMB[:, -3:, 0] 310 | trigram_hash = (2.0 * trigram[:, 0] + 3.3 * trigram[:, 1] + 7.8 * trigram[:, 2]).unsqueeze(1) 311 | seen_hashes = trigram_hash 312 | elif _ > 2: 313 | trigram = XMB[:, -3:, 0] 314 | trigram_hash = (2.0 * trigram[:, 0] + 3.3 * trigram[:, 1] + 7.8 * trigram[:, 2]).unsqueeze(1) 315 | trigram_mask = trigram_hash.eq(seen_hashes).sum(dim=1, keepdim=True).byte() 316 | lm_probs.masked_fill_(trigram_mask, -1e8) 317 | seen_hashes = torch.cat((seen_hashes, trigram_hash), dim=1) 318 | 319 | beam_lls, beam_toks = lm_probs.topk(beam) 320 | beam_lls = beam_lls.view(-1, 1) 321 | beam_toks = beam_toks.view(-1, 1) 322 | beam_probs = beam_probs.repeat(1, beam).view(-1, 1) 323 | XMB = XMB.repeat(1, beam, 1).view(-1, XMB.size(1), XMB.size(2)) 324 | mask = mask.repeat(1, beam).view(-1, mask.size(1)) 325 | if _ >= 2: 326 | seen_hashes = seen_hashes.repeat(1, beam).view(-1, seen_hashes.size(1)) 327 | next_x = torch.cat((beam_toks, XMB[:, -1:, 1] + 1), -1).unsqueeze(1) 328 | XMB = torch.cat((XMB, next_x), 1) 329 | beam_probs += beam_lls 330 | finished_mask = beam_toks.eq(classify_idx) 331 | #TODO Might be able to batch this 332 | for i in range(finished_mask.size(0)): 333 | if finished_mask[i].item() == 1: 334 | tokens = [] 335 | for tok in XMB[i, 1+512+1:-1, 0]: 336 | if tok.item() in text_encoder.decoder: 337 | tokens.append(text_encoder.decoder[tok.item()].replace('', ' ').replace('\n', '')) 338 | else: 339 | tokens.append(" ") 340 | phrase = ' '.join(''.join(tokens).split()) 341 | if min_len is None or len(phrase.split(" ")) >= min_len: 342 | finished_beams[i // (beam * beam)].append((XMB[i, 1+512+1:, 0], beam_probs[i].item() / XMB.size(1) - 513)) # 513 to include classify tok in error and avoid division by 0 343 | beam_probs[i] = -1e8 344 | finished_mask = beam_toks.eq(classify_idx) 345 | beam_seqs = [sorted(finished_beam, key=lambda x: x[1], reverse=True) for finished_beam in finished_beams] 346 | tokens = torch.zeros(len(beam_seqs), gen_len) 347 | for i in range(len(beam_seqs)): 348 | beam_seq = beam_seqs[i][0][0] if len(beam_seqs[i]) != 0 else torch.tensor([classify_idx]).unsqueeze(1).type_as(XMB) 349 | tokens[i, :beam_seq.size(0)] = beam_seq 350 | return tokens 351 | 352 | def generate(self, pad_output, mask, text_encoder, device, beam=0, gen_len=110, k=0, decoding_strategy=0, min_len=None): 353 | classify_idx = text_encoder.encoder['_classify_'] 354 | input_toks = pad_output[:, :1+512+1, 0] # includes delimiter 355 | target_toks = pad_output[:, -(gen_len+1):, 0] 356 | mask_pad = torch.ones(mask.size()).type_as(mask) 357 | mask_pad[:, :1 + 512 + 1] = mask[:, :1 + 512 + 1] 358 | mask = mask_pad 359 | 360 | pad_output = pad_output.to(device) 361 | XMB = pad_output[:, :1+512+1] 362 | if beam == 0: 363 | generated_toks = self.sample(XMB, mask, classify_idx, text_encoder, 
gen_len, k, decoding_strategy, min_len=min_len) 364 | else: 365 | generated_toks = self.beam_search(XMB, mask, classify_idx, text_encoder, beam=beam, gen_len=gen_len, min_len=min_len) 366 | return generated_toks.type_as(XMB), input_toks.type_as(XMB), target_toks.type_as(XMB) 367 | 368 | 369 | def load_openai_pretrained_model(model, n_ctx=-1, n_special=-1, n_transfer=12, n_embd=768, path='./model/', 370 | path_names='./'): 371 | # Load weights from TF model 372 | print("Loading weights...") 373 | names = json.load(open(path_names + 'parameters_names.json')) 374 | shapes = json.load(open(path + 'params_shapes.json')) 375 | offsets = np.cumsum([np.prod(shape) for shape in shapes]) 376 | init_params = [np.load(path + 'params_{}.npy'.format(n)) for n in range(10)] 377 | init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1] 378 | init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)] 379 | if n_ctx > 0: 380 | init_params[0] = init_params[0][:n_ctx] 381 | if n_special > 0: 382 | init_params[0] = np.concatenate( 383 | [init_params[1], 384 | (np.random.randn(n_special, n_embd) * 0.02).astype(np.float32), 385 | init_params[0] 386 | ], 0) 387 | else: 388 | init_params[0] = np.concatenate( 389 | [init_params[1], 390 | init_params[0] 391 | ], 0) 392 | del init_params[1] 393 | if n_transfer == -1: 394 | n_transfer = 0 395 | else: 396 | n_transfer = 1 + n_transfer * 12 397 | init_params = [arr.squeeze() for arr in init_params] 398 | 399 | try: 400 | assert model.embed.weight.shape == init_params[0].shape 401 | except AssertionError as e: 402 | e.args += (model.embed.weight.shape, init_params[0].shape) 403 | raise 404 | 405 | model.embed.weight.data = torch.from_numpy(init_params[0]) 406 | 407 | for name, ip in zip(names[1:n_transfer], init_params[1:n_transfer]): 408 | name = name[6:] # skip "model/" 409 | assert name[-2:] == ":0" 410 | name = name[:-2] 411 | name = name.split('/') 412 | pointer = model 413 | for m_name in name: 414 | if re.fullmatch(r'[A-Za-z]+\d+', m_name): 415 | l = re.split(r'(\d+)', m_name) 416 | else: 417 | l = [m_name] 418 | pointer = getattr(pointer, l[0]) 419 | if len(l) >= 2: 420 | num = int(l[1]) 421 | pointer = pointer[num] 422 | try: 423 | assert pointer.shape == ip.shape 424 | except AssertionError as e: 425 | e.args += (pointer.shape, ip.shape) 426 | raise 427 | pointer.data = torch.from_numpy(ip) 428 | 429 | 430 | class dotdict(dict): 431 | """dot.notation access to dictionary attributes""" 432 | __getattr__ = dict.get 433 | __setattr__ = dict.__setitem__ 434 | __delattr__ = dict.__delitem__ 435 | 436 | 437 | DEFAULT_CONFIG = dotdict({ 438 | 'n_embd': 768, 439 | 'n_head': 12, 440 | 'n_layer': 12, 441 | 'embd_pdrop': 0.1, 442 | 'attn_pdrop': 0.1, 443 | 'resid_pdrop': 0.1, 444 | 'afn': 'gelu', 445 | 'clf_pdrop': 0.1}) 446 | -------------------------------------------------------------------------------- /opt.py: -------------------------------------------------------------------------------- 1 | # Code referenced from https://github.com/huggingface/pytorch-openai-transformer-lm/blob/master/opt.py 2 | 3 | import math 4 | import torch 5 | from torch.optim import Optimizer 6 | from torch.nn.utils import clip_grad_norm_ 7 | 8 | def warmup_cosine(x, warmup=0.002): 9 | s = 1 if x <= warmup else 0 10 | return s*(x/warmup) + (1-s)*(0.5 * (1 + torch.cos(math.pi * x))) 11 | 12 | def warmup_constant(x, warmup=0.002): 13 | s = 1 if x <= warmup else 0 14 | return s*(x/warmup) + (1-s)*1 15 | 16 | def warmup_linear(x, warmup=0.002): 17 | s = 1 if x 
<= warmup else 0 18 | return (s*(x/warmup) + (1-s))*(1-x) 19 | 20 | SCHEDULES = { 21 | 'warmup_cosine':warmup_cosine, 22 | 'warmup_constant':warmup_constant, 23 | 'warmup_linear':warmup_linear, 24 | } 25 | 26 | 27 | class OpenAIAdam(Optimizer): 28 | """Implements Open AI version of Adam algorithm with weight decay fix. 29 | """ 30 | def __init__(self, params, lr, schedule, warmup, t_total, 31 | b1=0.9, b2=0.999, e=1e-8, l2=0, 32 | vector_l2=False, max_grad_norm=-1, **kwargs): 33 | if not 0.0 <= lr: 34 | raise ValueError("Invalid learning rate: {}".format(lr)) 35 | if schedule not in SCHEDULES: 36 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 37 | if not 0 <= warmup: 38 | raise ValueError("Invalid warmup: {}".format(warmup)) 39 | if not 0.0 <= b1 < 1.0: 40 | raise ValueError("Invalid b1 parameter: {}".format(b1)) 41 | if not 0.0 <= b2 < 1.0: 42 | raise ValueError("Invalid b2 parameter: {}".format(b2)) 43 | if not 0.0 <= e: 44 | raise ValueError("Invalid epsilon value: {}".format(e)) 45 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 46 | b1=b1, b2=b2, e=e, l2=l2, vector_l2=vector_l2, 47 | max_grad_norm=max_grad_norm) 48 | super(OpenAIAdam, self).__init__(params, defaults) 49 | 50 | def step(self, closure=None): 51 | """Performs a single optimization step. 52 | 53 | Arguments: 54 | closure (callable, optional): A closure that reevaluates the model 55 | and returns the loss. 56 | """ 57 | loss = None 58 | if closure is not None: 59 | loss = closure() 60 | 61 | for group in self.param_groups: 62 | for p in group['params']: 63 | if p.grad is None: 64 | continue 65 | grad = p.grad.data 66 | if grad.is_sparse: 67 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 68 | 69 | state = self.state[p] 70 | 71 | # State initialization 72 | if len(state) == 0: 73 | state['step'] = 0 74 | # Exponential moving average of gradient values 75 | state['exp_avg'] = torch.zeros_like(p.data) 76 | # Exponential moving average of squared gradient values 77 | state['exp_avg_sq'] = torch.zeros_like(p.data) 78 | 79 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 80 | beta1, beta2 = group['b1'], group['b2'] 81 | 82 | state['step'] += 1 83 | 84 | # Add grad clipping 85 | if group['max_grad_norm'] > 0: 86 | clip_grad_norm_(p, group['max_grad_norm']) 87 | 88 | # Decay the first and second moment running average coefficient 89 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 90 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 91 | denom = exp_avg_sq.sqrt().add_(group['e']) 92 | 93 | bias_correction1 = 1 - beta1 ** state['step'] 94 | bias_correction2 = 1 - beta2 ** state['step'] 95 | 96 | schedule_fct = SCHEDULES[group['schedule']] 97 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 98 | step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 99 | 100 | p.data.addcdiv_(-step_size, exp_avg, denom) 101 | 102 | # Add weight decay at the end (fixed version) 103 | if (len(p.size()) > 1 or group['vector_l2']) and group['l2'] > 0: 104 | p.data.add_(-lr_scheduled * group['l2'], p.data) 105 | 106 | return loss 107 | -------------------------------------------------------------------------------- /parallel.py: -------------------------------------------------------------------------------- 1 | # Code referenced from https://gist.github.com/thomwolf/7e2407fbd5945f07821adae3d9fd1312 2 | 3 | """Encoding Data Parallel""" 4 | import threading 5 | import functools 
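# Intended usage (see the class docstrings below): wrap the model in DataParallelModel, whose
# gather() deliberately returns the per-GPU outputs without concatenating them, and wrap the
# loss in DataParallelCriterion, which scatters the targets, computes the loss on each GPU
# against that GPU's outputs, and only then gathers the per-GPU losses.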
6 | import torch 7 | from torch.autograd import Variable, Function 8 | import torch.cuda.comm as comm 9 | from torch.nn.parallel.data_parallel import DataParallel 10 | from torch.nn.parallel.parallel_apply import get_a_var 11 | from torch.nn.parallel.scatter_gather import gather 12 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 13 | 14 | torch_ver = torch.__version__[:3] 15 | 16 | __all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion', 17 | 'patch_replication_callback'] 18 | 19 | def allreduce(*inputs): 20 | """Cross GPU all reduce autograd operation for calculate mean and 21 | variance in SyncBN. 22 | """ 23 | return AllReduce.apply(*inputs) 24 | 25 | class AllReduce(Function): 26 | @staticmethod 27 | def forward(ctx, num_inputs, *inputs): 28 | ctx.num_inputs = num_inputs 29 | ctx.target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)] 30 | inputs = [inputs[i:i + num_inputs] 31 | for i in range(0, len(inputs), num_inputs)] 32 | # sort before reduce sum 33 | inputs = sorted(inputs, key=lambda i: i[0].get_device()) 34 | results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0]) 35 | outputs = comm.broadcast_coalesced(results, ctx.target_gpus) 36 | return tuple([t for tensors in outputs for t in tensors]) 37 | 38 | @staticmethod 39 | def backward(ctx, *inputs): 40 | inputs = [i.data for i in inputs] 41 | inputs = [inputs[i:i + ctx.num_inputs] 42 | for i in range(0, len(inputs), ctx.num_inputs)] 43 | results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0]) 44 | outputs = comm.broadcast_coalesced(results, ctx.target_gpus) 45 | return (None,) + tuple([Variable(t) for tensors in outputs for t in tensors]) 46 | 47 | 48 | class Reduce(Function): 49 | @staticmethod 50 | def forward(ctx, *inputs): 51 | ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))] 52 | inputs = sorted(inputs, key=lambda i: i.get_device()) 53 | return comm.reduce_add(inputs) 54 | 55 | @staticmethod 56 | def backward(ctx, gradOutput): 57 | return Broadcast.apply(ctx.target_gpus, gradOutput) 58 | 59 | class DistributedDataParallelModel(torch.nn.parallel.distributed.DistributedDataParallel): 60 | """Implements data parallelism at the module level for the DistributedDataParallel module. 61 | This container parallelizes the application of the given module by 62 | splitting the input across the specified devices by chunking in the 63 | batch dimension. 64 | In the forward pass, the module is replicated on each device, 65 | and each replica handles a portion of the input. During the backwards pass, 66 | gradients from each replica are summed into the original module. 67 | Note that the outputs are not gathered, please use compatible 68 | :class:`encoding.parallel.DataParallelCriterion`. 69 | The batch size should be larger than the number of GPUs used. It should 70 | also be an integer multiple of the number of GPUs so that each chunk is 71 | the same size (so that each GPU processes the same number of samples). 72 | Args: 73 | module: module to be parallelized 74 | device_ids: CUDA devices (default: all devices) 75 | Reference: 76 | Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, 77 | Amit Agrawal. “Context Encoding for Semantic Segmentation. 
78 | *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018* 79 | Example:: 80 | >>> net = encoding.nn.DistributedDataParallelModel(model, device_ids=[0, 1, 2]) 81 | >>> y = net(x) 82 | """ 83 | def gather(self, outputs, output_device): 84 | return outputs 85 | 86 | class DataParallelModel(DataParallel): 87 | """Implements data parallelism at the module level. 88 | 89 | This container parallelizes the application of the given module by 90 | splitting the input across the specified devices by chunking in the 91 | batch dimension. 92 | In the forward pass, the module is replicated on each device, 93 | and each replica handles a portion of the input. During the backwards pass, 94 | gradients from each replica are summed into the original module. 95 | Note that the outputs are not gathered, please use compatible 96 | :class:`encoding.parallel.DataParallelCriterion`. 97 | 98 | The batch size should be larger than the number of GPUs used. It should 99 | also be an integer multiple of the number of GPUs so that each chunk is 100 | the same size (so that each GPU processes the same number of samples). 101 | 102 | Args: 103 | module: module to be parallelized 104 | device_ids: CUDA devices (default: all devices) 105 | 106 | Reference: 107 | Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, 108 | Amit Agrawal. “Context Encoding for Semantic Segmentation. 109 | *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018* 110 | 111 | Example:: 112 | 113 | >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2]) 114 | >>> y = net(x) 115 | """ 116 | def gather(self, outputs, output_device): 117 | return outputs 118 | 119 | def replicate(self, module, device_ids): 120 | modules = super(DataParallelModel, self).replicate(module, device_ids) 121 | execute_replication_callbacks(modules) 122 | return modules 123 | 124 | 125 | class DataParallelCriterion(DataParallel): 126 | """ 127 | Calculate loss in multiple-GPUs, which balance the memory usage. 128 | The targets are splitted across the specified devices by chunking in 129 | the batch dimension. Please use together with :class:`encoding.parallel.DataParallelModel`. 130 | 131 | Reference: 132 | Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, 133 | Amit Agrawal. “Context Encoding for Semantic Segmentation. 
134 | *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018* 135 | 136 | Example:: 137 | 138 | >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2]) 139 | >>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2]) 140 | >>> y = net(x) 141 | >>> loss = criterion(y, target) 142 | """ 143 | #def forward(self, inputs, *targets, **kwargs): 144 | def forward(self, inputs, *targets, only_return_losses=False, **kwargs): 145 | # input should be already scatterd 146 | # scattering the targets instead 147 | if not self.device_ids: 148 | return self.module(inputs, *targets, only_return_losses=only_return_losses, **kwargs) 149 | targets, kwargs = self.scatter(targets, kwargs, self.device_ids) 150 | if len(self.device_ids) == 1: 151 | return self.module(inputs, *targets[0], **kwargs[0]) 152 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) 153 | outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs) 154 | return self.gather(outputs, self.output_device) 155 | 156 | 157 | def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None): 158 | assert len(modules) == len(inputs) 159 | assert len(targets) == len(inputs) 160 | if kwargs_tup: 161 | assert len(modules) == len(kwargs_tup) 162 | else: 163 | kwargs_tup = ({},) * len(modules) 164 | if devices is not None: 165 | assert len(modules) == len(devices) 166 | else: 167 | devices = [None] * len(modules) 168 | 169 | lock = threading.Lock() 170 | results = {} 171 | if torch_ver != "0.3": 172 | grad_enabled = torch.is_grad_enabled() 173 | 174 | def _worker(i, module, input, target, kwargs, device=None): 175 | if torch_ver != "0.3": 176 | torch.set_grad_enabled(grad_enabled) 177 | if device is None: 178 | device = get_a_var(input).get_device() 179 | try: 180 | with torch.cuda.device(device): 181 | # this also avoids accidental slicing of `input` if it is a Tensor 182 | if not isinstance(input, (list, tuple)): 183 | input = (input,) 184 | if not isinstance(target, (list, tuple)): 185 | target = (target,) 186 | output = module(*(input + target), **kwargs) 187 | with lock: 188 | results[i] = output 189 | except Exception as e: 190 | with lock: 191 | results[i] = e 192 | 193 | if len(modules) > 1: 194 | threads = [threading.Thread(target=_worker, 195 | args=(i, module, input, target, 196 | kwargs, device),) 197 | for i, (module, input, target, kwargs, device) in 198 | enumerate(zip(modules, inputs, targets, kwargs_tup, devices))] 199 | 200 | for thread in threads: 201 | thread.start() 202 | for thread in threads: 203 | thread.join() 204 | else: 205 | _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0]) 206 | 207 | outputs = [] 208 | for i in range(len(inputs)): 209 | output = results[i] 210 | if isinstance(output, Exception): 211 | raise output 212 | outputs.append(output) 213 | return outputs 214 | 215 | 216 | ########################################################################### 217 | # Adapted from Synchronized-BatchNorm-PyTorch. 218 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 219 | # 220 | class CallbackContext(object): 221 | pass 222 | 223 | 224 | def execute_replication_callbacks(modules): 225 | """ 226 | Execute an replication callback `__data_parallel_replicate__` on each module created 227 | by original replication. 
228 | 229 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 230 | 231 | Note that, as all modules are isomorphism, we assign each sub-module with a context 232 | (shared among multiple copies of this module on different devices). 233 | Through this context, different copies can share some information. 234 | 235 | We guarantee that the callback on the master copy (the first copy) will be called ahead 236 | of calling the callback of any slave copies. 237 | """ 238 | master_copy = modules[0] 239 | nr_modules = len(list(master_copy.modules())) 240 | ctxs = [CallbackContext() for _ in range(nr_modules)] 241 | 242 | for i, module in enumerate(modules): 243 | for j, m in enumerate(module.modules()): 244 | if hasattr(m, '__data_parallel_replicate__'): 245 | m.__data_parallel_replicate__(ctxs[j], i) 246 | 247 | 248 | def patch_replication_callback(data_parallel): 249 | """ 250 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 251 | Useful when you have customized `DataParallel` implementation. 252 | 253 | Examples: 254 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 255 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 256 | > patch_replication_callback(sync_bn) 257 | # this is equivalent to 258 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 259 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 260 | """ 261 | 262 | assert isinstance(data_parallel, DataParallel) 263 | 264 | old_replicate = data_parallel.replicate 265 | 266 | @functools.wraps(old_replicate) 267 | def new_replicate(module, device_ids): 268 | modules = old_replicate(module, device_ids) 269 | execute_replication_callbacks(modules) 270 | return modules 271 | 272 | data_parallel.replicate = new_replicate 273 | -------------------------------------------------------------------------------- /parameters_names.json: -------------------------------------------------------------------------------- 1 | ["model/we:0", "model/h0/attn/c_attn/w:0", "model/h0/attn/c_attn/b:0", "model/h0/attn/c_proj/w:0", "model/h0/attn/c_proj/b:0", "model/h0/ln_1/g:0", "model/h0/ln_1/b:0", "model/h0/mlp/c_fc/w:0", "model/h0/mlp/c_fc/b:0", "model/h0/mlp/c_proj/w:0", "model/h0/mlp/c_proj/b:0", "model/h0/ln_2/g:0", "model/h0/ln_2/b:0", "model/h1/attn/c_attn/w:0", "model/h1/attn/c_attn/b:0", "model/h1/attn/c_proj/w:0", "model/h1/attn/c_proj/b:0", "model/h1/ln_1/g:0", "model/h1/ln_1/b:0", "model/h1/mlp/c_fc/w:0", "model/h1/mlp/c_fc/b:0", "model/h1/mlp/c_proj/w:0", "model/h1/mlp/c_proj/b:0", "model/h1/ln_2/g:0", "model/h1/ln_2/b:0", "model/h2/attn/c_attn/w:0", "model/h2/attn/c_attn/b:0", "model/h2/attn/c_proj/w:0", "model/h2/attn/c_proj/b:0", "model/h2/ln_1/g:0", "model/h2/ln_1/b:0", "model/h2/mlp/c_fc/w:0", "model/h2/mlp/c_fc/b:0", "model/h2/mlp/c_proj/w:0", "model/h2/mlp/c_proj/b:0", "model/h2/ln_2/g:0", "model/h2/ln_2/b:0", "model/h3/attn/c_attn/w:0", "model/h3/attn/c_attn/b:0", "model/h3/attn/c_proj/w:0", "model/h3/attn/c_proj/b:0", "model/h3/ln_1/g:0", "model/h3/ln_1/b:0", "model/h3/mlp/c_fc/w:0", "model/h3/mlp/c_fc/b:0", "model/h3/mlp/c_proj/w:0", "model/h3/mlp/c_proj/b:0", "model/h3/ln_2/g:0", "model/h3/ln_2/b:0", "model/h4/attn/c_attn/w:0", "model/h4/attn/c_attn/b:0", "model/h4/attn/c_proj/w:0", "model/h4/attn/c_proj/b:0", "model/h4/ln_1/g:0", "model/h4/ln_1/b:0", "model/h4/mlp/c_fc/w:0", "model/h4/mlp/c_fc/b:0", "model/h4/mlp/c_proj/w:0", "model/h4/mlp/c_proj/b:0", "model/h4/ln_2/g:0", "model/h4/ln_2/b:0", "model/h5/attn/c_attn/w:0", 
"model/h5/attn/c_attn/b:0", "model/h5/attn/c_proj/w:0", "model/h5/attn/c_proj/b:0", "model/h5/ln_1/g:0", "model/h5/ln_1/b:0", "model/h5/mlp/c_fc/w:0", "model/h5/mlp/c_fc/b:0", "model/h5/mlp/c_proj/w:0", "model/h5/mlp/c_proj/b:0", "model/h5/ln_2/g:0", "model/h5/ln_2/b:0", "model/h6/attn/c_attn/w:0", "model/h6/attn/c_attn/b:0", "model/h6/attn/c_proj/w:0", "model/h6/attn/c_proj/b:0", "model/h6/ln_1/g:0", "model/h6/ln_1/b:0", "model/h6/mlp/c_fc/w:0", "model/h6/mlp/c_fc/b:0", "model/h6/mlp/c_proj/w:0", "model/h6/mlp/c_proj/b:0", "model/h6/ln_2/g:0", "model/h6/ln_2/b:0", "model/h7/attn/c_attn/w:0", "model/h7/attn/c_attn/b:0", "model/h7/attn/c_proj/w:0", "model/h7/attn/c_proj/b:0", "model/h7/ln_1/g:0", "model/h7/ln_1/b:0", "model/h7/mlp/c_fc/w:0", "model/h7/mlp/c_fc/b:0", "model/h7/mlp/c_proj/w:0", "model/h7/mlp/c_proj/b:0", "model/h7/ln_2/g:0", "model/h7/ln_2/b:0", "model/h8/attn/c_attn/w:0", "model/h8/attn/c_attn/b:0", "model/h8/attn/c_proj/w:0", "model/h8/attn/c_proj/b:0", "model/h8/ln_1/g:0", "model/h8/ln_1/b:0", "model/h8/mlp/c_fc/w:0", "model/h8/mlp/c_fc/b:0", "model/h8/mlp/c_proj/w:0", "model/h8/mlp/c_proj/b:0", "model/h8/ln_2/g:0", "model/h8/ln_2/b:0", "model/h9/attn/c_attn/w:0", "model/h9/attn/c_attn/b:0", "model/h9/attn/c_proj/w:0", "model/h9/attn/c_proj/b:0", "model/h9/ln_1/g:0", "model/h9/ln_1/b:0", "model/h9/mlp/c_fc/w:0", "model/h9/mlp/c_fc/b:0", "model/h9/mlp/c_proj/w:0", "model/h9/mlp/c_proj/b:0", "model/h9/ln_2/g:0", "model/h9/ln_2/b:0", "model/h10/attn/c_attn/w:0", "model/h10/attn/c_attn/b:0", "model/h10/attn/c_proj/w:0", "model/h10/attn/c_proj/b:0", "model/h10/ln_1/g:0", "model/h10/ln_1/b:0", "model/h10/mlp/c_fc/w:0", "model/h10/mlp/c_fc/b:0", "model/h10/mlp/c_proj/w:0", "model/h10/mlp/c_proj/b:0", "model/h10/ln_2/g:0", "model/h10/ln_2/b:0", "model/h11/attn/c_attn/w:0", "model/h11/attn/c_attn/b:0", "model/h11/attn/c_proj/w:0", "model/h11/attn/c_proj/b:0", "model/h11/ln_1/g:0", "model/h11/ln_1/b:0", "model/h11/mlp/c_fc/w:0", "model/h11/mlp/c_fc/b:0", "model/h11/mlp/c_proj/w:0", "model/h11/mlp/c_proj/b:0", "model/h11/ln_2/g:0", "model/h11/ln_2/b:0", "model/clf/w:0", "model/clf/b:0"] -------------------------------------------------------------------------------- /scripts/encode_cnndm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from newsroom import jsonl 3 | from tqdm import tqdm 4 | from text_utils import TextEncoder 5 | 6 | def encode_line(line, encoder): 7 | encoding = encoder.encode([line]) 8 | return encoding[0] 9 | 10 | def main(args): 11 | text_encoder = TextEncoder(args.encoder_path, args.bpe_path) 12 | num_summaries = 0 13 | out_data = [] 14 | with open(args.src_file) as src_file, open(args.tgt_file) as tgt_file: 15 | src_lines = src_file.readlines() 16 | tgt_lines = tgt_file.readlines() 17 | for i in tqdm(range(len(src_lines))): 18 | num_summaries += 1 19 | out_data.append({ 20 | "summary": encode_line(tgt_lines[i].strip(), text_encoder), 21 | "text": encode_line(src_lines[i].strip(), text_encoder) 22 | }) 23 | with jsonl.open(args.out_file, gzip=True) as out_file: 24 | out_file.write(out_data) 25 | print("Number of successful conversions: {}".format(num_summaries)) 26 | 27 | 28 | if __name__ == '__main__': 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--src_file', type=str, required=True) 31 | parser.add_argument('--tgt_file', type=str, required=True) 32 | parser.add_argument('--out_file', type=str, required=True) 33 | parser.add_argument('--encoder_path', type=str, 
default='model/encoder_bpe_40000.json') 34 | parser.add_argument('--bpe_path', type=str, default='model/vocab_40000.bpe') 35 | args = parser.parse_args() 36 | 37 | main(args) 38 | -------------------------------------------------------------------------------- /scripts/encode_newsroom.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from newsroom import jsonl 3 | from tqdm import tqdm 4 | from text_utils import TextEncoder 5 | 6 | def encode_line(line, encoder): 7 | encoding = encoder.encode([line]) 8 | return encoding[0] 9 | 10 | def main(args): 11 | text_encoder = TextEncoder(args.encoder_path, args.bpe_path) 12 | num_summaries = 0 13 | out_data = [] 14 | with jsonl.open(args.in_file, gzip=True) as in_file: 15 | data = in_file.read() 16 | for entry in tqdm(data): 17 | if entry["summary"] is None or entry["text"] is None: 18 | continue 19 | entry["summary"] = encode_line(entry["summary"], text_encoder) 20 | entry["text"] = encode_line(entry["text"], text_encoder) 21 | num_summaries += 1 22 | out_data.append(entry) 23 | with jsonl.open(args.out_file, gzip=True) as out_file: 24 | out_file.write(out_data) 25 | print("Number of successful conversions: {}".format(num_summaries)) 26 | 27 | 28 | if __name__ == '__main__': 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--in_file', type=str, required=True) 31 | parser.add_argument('--out_file', type=str, required=True) 32 | parser.add_argument('--encoder_path', type=str, default='model/encoder_bpe_40000.json') 33 | parser.add_argument('--bpe_path', type=str, default='model/vocab_40000.bpe') 34 | args = parser.parse_args() 35 | 36 | main(args) 37 | -------------------------------------------------------------------------------- /scripts/encode_xsum.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from newsroom import jsonl 5 | from tqdm import tqdm 6 | from text_utils import TextEncoder 7 | 8 | def load_summary(file_name): 9 | summary = [] 10 | text = [] 11 | is_summary = False 12 | is_text = False 13 | with open(file_name, encoding="latin-1") as f: 14 | for line in f: 15 | line = line.strip() 16 | if line[:4] == "[SN]" and line[-4:] == "[SN]": 17 | if line[4:-4] == "FIRST-SENTENCE": 18 | is_summary = True 19 | elif line[4:-4] == "RESTBODY": 20 | is_text = True 21 | is_summary = False 22 | elif is_summary: 23 | if line: 24 | summary.append(line) 25 | elif is_text: 26 | if line: 27 | text.append(line) 28 | if len(summary) != 1 or not text: 29 | return {"summary": "", "text": ""} 30 | return {"summary": " ".join(summary), "text": " ".join(text)} 31 | 32 | def load_splits(splits_file): 33 | with open(args.splits_file) as f: 34 | splits = json.load(f) 35 | train_split = set(splits["train"]) 36 | val_split = set(splits["validation"]) 37 | test_split = set(splits["test"]) 38 | return train_split, val_split, test_split 39 | 40 | def encode_line(line, encoder): 41 | encoding = encoder.encode([line]) 42 | return encoding[0] 43 | 44 | def main(args): 45 | text_encoder = TextEncoder(args.encoder_path, args.bpe_path) 46 | train_split, val_split, test_split = load_splits(args.splits_file) 47 | summaries = os.listdir(args.summary_dir) 48 | 49 | num_summaries = 0 50 | train_data, val_data, test_data = [], [], [] 51 | for file_name in tqdm(summaries): 52 | summary_data = load_summary(os.path.join(args.summary_dir, file_name)) 53 | if len(summary_data["summary"]) == 0 or len(summary_data["text"]) == 
0: 54 | continue 55 | summary_data["summary"] = encode_line(summary_data["summary"], text_encoder) 56 | summary_data["text"] = encode_line(summary_data["text"], text_encoder) 57 | file_id = file_name.split(".")[0] 58 | if file_id in train_split: 59 | train_data.append(summary_data) 60 | num_summaries += 1 61 | elif file_id in val_split: 62 | val_data.append(summary_data) 63 | num_summaries += 1 64 | elif file_id in test_split: 65 | test_data.append(summary_data) 66 | num_summaries += 1 67 | 68 | with jsonl.open(args.train_file, gzip=True) as train_file: 69 | train_file.write(train_data) 70 | with jsonl.open(args.val_file, gzip=True) as val_file: 71 | val_file.write(val_data) 72 | with jsonl.open(args.test_file, gzip=True) as test_file: 73 | test_file.write(test_data) 74 | print("Number of successful conversions: {}".format(num_summaries)) 75 | 76 | if __name__ == '__main__': 77 | parser = argparse.ArgumentParser() 78 | parser.add_argument('--summary_dir', type=str, required=True) 79 | parser.add_argument('--splits_file', type=str, required=True) 80 | parser.add_argument('--train_file', type=str, required=True) 81 | parser.add_argument('--val_file', type=str, required=True) 82 | parser.add_argument('--test_file', type=str, required=True) 83 | parser.add_argument('--encoder_path', type=str, default='model/encoder_bpe_40000.json') 84 | parser.add_argument('--bpe_path', type=str, default='model/vocab_40000.bpe') 85 | args = parser.parse_args() 86 | 87 | main(args) 88 | -------------------------------------------------------------------------------- /text_utils.py: -------------------------------------------------------------------------------- 1 | # Code referenced from https://github.com/huggingface/pytorch-openai-transformer-lm/blob/master/text_utils.py 2 | 3 | import re 4 | import json 5 | import ftfy 6 | import spacy 7 | from tqdm import tqdm 8 | 9 | def get_pairs(word): 10 | """ 11 | Return set of symbol pairs in a word. 
12 | word is represented as tuple of symbols (symbols being variable-length strings) 13 | """ 14 | pairs = set() 15 | prev_char = word[0] 16 | for char in word[1:]: 17 | pairs.add((prev_char, char)) 18 | prev_char = char 19 | return pairs 20 | 21 | def text_standardize(text): 22 | """ 23 | fixes some issues the spacy tokenizer had on books corpus 24 | also does some whitespace standardization 25 | """ 26 | text = text.replace('—', '-') 27 | text = text.replace('–', '-') 28 | text = text.replace('―', '-') 29 | text = text.replace('…', '...') 30 | text = text.replace('´', "'") 31 | text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text) 32 | text = re.sub(r'\s*\n\s*', ' \n ', text) 33 | text = re.sub(r'[^\S\n]+', ' ', text) 34 | return text.strip() 35 | 36 | class TextEncoder(): 37 | """ 38 | mostly a wrapper for a public python bpe tokenizer 39 | """ 40 | 41 | def __init__(self, encoder_path, bpe_path): 42 | self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat']) 43 | self.nlp.max_length = 1500000 44 | self.encoder = json.load(open(encoder_path)) 45 | self.decoder = {v:k for k, v in self.encoder.items()} 46 | merges = open(bpe_path, encoding='utf-8').read().split('\n')[1:-1] 47 | merges = [tuple(merge.split()) for merge in merges] 48 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 49 | self.cache = {} 50 | 51 | def bpe(self, token): 52 | word = tuple(token[:-1]) + (token[-1] + '</w>',) 53 | if token in self.cache: 54 | return self.cache[token] 55 | pairs = get_pairs(word) 56 | 57 | if not pairs: 58 | return token+'</w>' 59 | 60 | while True: 61 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) 62 | if bigram not in self.bpe_ranks: 63 | break 64 | first, second = bigram 65 | new_word = [] 66 | i = 0 67 | while i < len(word): 68 | try: 69 | j = word.index(first, i) 70 | new_word.extend(word[i:j]) 71 | i = j 72 | except ValueError: 73 | new_word.extend(word[i:]) 74 | break 75 | 76 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 77 | new_word.append(first+second) 78 | i += 2 79 | else: 80 | new_word.append(word[i]) 81 | i += 1 82 | new_word = tuple(new_word) 83 | word = new_word 84 | if len(word) == 1: 85 | break 86 | else: 87 | pairs = get_pairs(word) 88 | word = ' '.join(word) 89 | if word == '\n  </w>': 90 | word = '\n</w>' 91 | self.cache[token] = word 92 | return word 93 | 94 | def encode(self, texts, verbose=True): 95 | texts_tokens = [] 96 | if verbose: 97 | for text in tqdm(texts, ncols=80, leave=False): 98 | text = self.nlp(text_standardize(ftfy.fix_text(text))) 99 | text_tokens = [] 100 | for token in text: 101 | text_tokens.extend([self.encoder.get(t, 0) for t in self.bpe(token.text.lower()).split(' ')]) 102 | texts_tokens.append(text_tokens) 103 | else: 104 | for text in texts: 105 | text = self.nlp(text_standardize(ftfy.fix_text(text))) 106 | text_tokens = [] 107 | for token in text: 108 | text_tokens.extend([self.encoder.get(t, 0) for t in self.bpe(token.text.lower()).split(' ')]) 109 | texts_tokens.append(text_tokens) 110 | return texts_tokens 111 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | 5 | import numpy as np 6 | import rouge 7 | import torch 8 | from torch import nn 9 | from tqdm import tqdm 10 | 11 | from data_loader import get_loader 12 | from generate import generate_outputs 13 | from logger
import Logger 14 | from loss import LMLoss, SummaryLoss 15 | from model_pytorch import LMModel, load_openai_pretrained_model 16 | from opt import OpenAIAdam 17 | from parallel import DataParallelModel, DataParallelCriterion 18 | from text_utils import TextEncoder 19 | 20 | def load_checkpoint(checkpoint_file, model, model_opt, vocab, n_ctx): 21 | """ 22 | Loads a checkpoint including model state and running loss for continued training 23 | """ 24 | if checkpoint_file is not None: 25 | checkpoint = torch.load(checkpoint_file) 26 | state_dict = checkpoint["state_dict"] 27 | start_iter = checkpoint['iter'] 28 | running_loss = checkpoint['running_loss'] 29 | opt_state_dict = checkpoint['optimizer'] 30 | model_opt.load_state_dict(opt_state_dict) 31 | for state in model_opt.state.values(): 32 | for key, value in state.items(): 33 | if isinstance(value, torch.Tensor): 34 | state[key] = value.cuda() 35 | for key in list(state_dict.keys()): 36 | state_dict[key[7:]] = state_dict[key] 37 | del state_dict[key] 38 | pos_emb_mask = torch.zeros(1, 1, vocab) 39 | pos_emb_mask[:, :, -n_ctx] = -1e12 40 | model.load_state_dict(state_dict) 41 | else: 42 | start_iter = 1 43 | running_loss = 0 44 | return start_iter, running_loss 45 | 46 | def get_average_scores(hyps, refs): 47 | rouge_scorer = rouge.Rouge(metrics=['rouge-n', 'rouge-l'], 48 | max_n=4, 49 | limit_length=True, 50 | length_limit=110, 51 | length_limit_type='words', 52 | apply_avg=False, 53 | apply_best=False, 54 | alpha=0.5, # Default F1_score 55 | weight_factor=1.2, 56 | stemming=True) 57 | 58 | averaged_scores = {'rouge-1': {'f': 0, 'p': 0, 'r': 0}, 59 | 'rouge-2': {'f': 0, 'p': 0, 'r': 0}, 60 | 'rouge-l': {'f': 0, 'p': 0, 'r': 0}} 61 | scores = rouge_scorer.get_scores(hyps, refs) 62 | for metric in averaged_scores.keys(): 63 | for values in scores[metric]: 64 | for sub_metric in averaged_scores[metric]: 65 | averaged_scores[metric][sub_metric] += values[sub_metric][0] 66 | for key in averaged_scores.keys(): 67 | for sub_key in averaged_scores[key].keys(): 68 | averaged_scores[key][sub_key] /= len(hyps) 69 | return averaged_scores 70 | 71 | def run_batch(model, pad_seq, mask_seq, device, compute_loss_fct): 72 | pad_seq = pad_seq.to(device) 73 | mask_seq = mask_seq.to(device) 74 | lm_logits = model(pad_seq, mask_seq) 75 | loss = compute_loss_fct(lm_logits, pad_seq, mask_seq).mean() 76 | return loss 77 | 78 | def save_checkpoint(num_updates, iter_num, running_loss, model_state_dict, optimizer_state_dict, save_dir): 79 | torch.save({ 80 | "iter": iter_num, 81 | "running_loss": running_loss, 82 | "state_dict": model_state_dict, 83 | "optimizer": optimizer_state_dict 84 | }, os.path.join(save_dir, "checkpoint_{0:05d}.pt".format(num_updates))) 85 | 86 | def evaluate(val_loader, train_log_interval, model, text_encoder, device, beam, gen_len, k, decoding_strategy, compute_loss_fct): 87 | hyps, refs = [], [] 88 | val_loss = 0 89 | for j, (pad_seq, mask_seq) in enumerate(val_loader): 90 | with torch.no_grad(): 91 | if j == train_log_interval: 92 | break 93 | if j <= 20: 94 | model.eval() 95 | # Generating outputs for evaluation 96 | src_strs, new_refs, new_hyps = generate_outputs(model, pad_seq, mask_seq, text_encoder, device, beam, gen_len, k, decoding_strategy) 97 | hyps.extend(new_hyps) 98 | refs.extend(new_refs) 99 | # Calculating loss 100 | val_loss += run_batch(model, pad_seq, mask_seq, device, compute_loss_fct).item() 101 | scores = get_average_scores(hyps, refs) 102 | return val_loss, scores 103 | 104 | def run_epoch(start_iter, running_loss, 
model, compute_loss_fct, model_opt, train_loader, val_loader, train_log_interval, val_log_interval, device, beam, gen_len, k, decoding_strategy, accum_iter, desc_str, save_dir, logger, text_encoder, show_progress=False, summary_loss=None): 105 | if show_progress: 106 | train_bar = tqdm(iterable=train_loader, desc=desc_str) 107 | else: 108 | train_bar = train_loader 109 | 110 | for i, (pad_seq, mask_seq) in enumerate(train_bar, start_iter): 111 | num_updates = i // accum_iter 112 | model.train() 113 | loss = run_batch(model, pad_seq, mask_seq, device, compute_loss_fct) 114 | torch.cuda.empty_cache() 115 | loss.backward() 116 | running_loss += loss.item() 117 | 118 | if show_progress: 119 | train_bar.set_postfix(loss=running_loss / ((train_log_interval * accum_iter) if num_updates % train_log_interval == 0 and num_updates != 0 else i % (train_log_interval * accum_iter))) 120 | 121 | if i % accum_iter == 0: 122 | model_opt.step() 123 | model_opt.zero_grad() 124 | 125 | if num_updates % train_log_interval == 0 and i % accum_iter == 0: 126 | logger.scalar_summary("Training/Loss", running_loss / (train_log_interval * accum_iter), num_updates) 127 | running_loss = 0 128 | 129 | if num_updates % val_log_interval == 0 and i % accum_iter == 0: 130 | val_loss, scores = evaluate(val_loader, train_log_interval, model, text_encoder, device, beam, gen_len, k, decoding_strategy, summary_loss if summary_loss else compute_loss_fct) 131 | for key, value in scores.items(): 132 | for key2, value2 in value.items(): 133 | logger.scalar_summary("{}/{}".format(key, key2), value2, num_updates) 134 | logger.scalar_summary("Validation/Loss", val_loss / train_log_interval, num_updates) 135 | torch.cuda.empty_cache() 136 | 137 | # Saving the model 138 | if num_updates % val_log_interval == 0 and i % accum_iter == 0: 139 | save_checkpoint(num_updates, i + 1, running_loss, model.state_dict(), model_opt.state_dict(), save_dir) 140 | save_checkpoint(num_updates, i + 1, running_loss, model.state_dict(), model_opt.state_dict(), save_dir) 141 | return i + 1, running_loss 142 | 143 | def init(args): 144 | print("Creating directories") 145 | os.makedirs(args.output_dir, exist_ok=True) 146 | os.makedirs(os.path.join(args.output_dir, args.experiment_name), exist_ok=True) 147 | os.makedirs(os.path.join(args.output_dir, args.experiment_name), exist_ok=True) 148 | 149 | random.seed(args.seed) 150 | np.random.seed(args.seed) 151 | torch.manual_seed(args.seed) 152 | torch.cuda.manual_seed_all(args.seed) 153 | 154 | def main(args): 155 | init(args) 156 | 157 | # Constants 158 | n_ctx = args.n_ctx 159 | save_dir = os.path.join(args.output_dir, args.experiment_name, "checkpoints") 160 | desc = args.desc 161 | data_dir = args.data_dir 162 | log_dir = os.path.join(args.output_dir, args.experiment_name, "logs") 163 | train_log_interval = args.train_log_interval 164 | val_log_interval = args.val_log_interval 165 | beam = args.beam 166 | gen_len = args.gen_len 167 | k = args.k 168 | decoding_strategy = args.decoding_strategy 169 | accum_iter = args.accum_iter 170 | 171 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 172 | n_gpu = torch.cuda.device_count() 173 | print("device", device, "n_gpu", n_gpu) 174 | logger = Logger(log_dir) 175 | 176 | text_encoder = TextEncoder(args.encoder_path, args.vocab_path) 177 | encoder = text_encoder.encoder 178 | n_vocab = len(text_encoder.encoder) 179 | encoder['_start_'] = len(encoder) 180 | encoder['_delimiter_'] = len(encoder) 181 | encoder['_classify_'] = len(encoder) 182 | 
clf_token = encoder['_classify_'] 183 | n_special = 3 184 | 185 | print("Loading dataset...") 186 | train_loader = get_loader(os.path.join(data_dir, "train_encoded.jsonl"), args.n_batch, encoder, num_workers=3, shuffle=True) 187 | val_loader = get_loader(os.path.join(data_dir, "val_encoded.jsonl"), n_gpu, encoder, num_workers=0, shuffle=False, max_size=args.num_val_examples) 188 | print("Train length: {}, Validation length: {}".format(len(train_loader), len(val_loader))) 189 | 190 | vocab = n_vocab + n_special + n_ctx 191 | n_updates_total = (len(train_loader) // args.accum_iter) * (args.num_epochs_dat + args.num_epochs_ft) 192 | 193 | dh_model = LMModel(args, vocab=vocab, n_ctx=n_ctx, doc_embed=args.doc_model) 194 | 195 | criterion = nn.CrossEntropyLoss(reduction="none") 196 | model_opt = OpenAIAdam(dh_model.parameters(), 197 | lr=args.lr, 198 | schedule=args.lr_schedule, 199 | warmup=args.lr_warmup, 200 | t_total=n_updates_total, 201 | b1=args.b1, 202 | b2=args.b2, 203 | e=args.e, 204 | l2=args.l2, 205 | vector_l2=args.vector_l2, 206 | max_grad_norm=args.max_grad_norm) 207 | 208 | lm_loss = LMLoss(criterion) 209 | summary_loss = SummaryLoss(criterion) 210 | 211 | print("Loading Model") 212 | if args.use_pretrain: 213 | load_openai_pretrained_model(dh_model.transformer, n_ctx=n_ctx, n_special=n_special, path="./model/", path_names="./") 214 | start_iter, running_loss = load_checkpoint(args.checkpoint, dh_model, model_opt, vocab, n_ctx) 215 | 216 | dh_model.to(device) 217 | dh_model = DataParallelModel(dh_model) 218 | lm_loss = DataParallelCriterion(lm_loss) 219 | summary_loss = DataParallelCriterion(summary_loss) 220 | 221 | for i in range(args.num_epochs_dat): 222 | start_iter, running_loss = run_epoch(start_iter, running_loss, dh_model, lm_loss, model_opt, train_loader, val_loader, train_log_interval, val_log_interval, device, beam, gen_len, k, decoding_strategy, accum_iter, "DAT Training Epoch [{}/{}]".format(i + 1, args.num_epochs_dat), save_dir, logger, text_encoder, show_progress=args.show_progress, summary_loss=summary_loss) 223 | for i in range(args.num_epochs_ft): 224 | start_iter, running_loss = run_epoch(start_iter, running_loss, dh_model, summary_loss, model_opt, train_loader, val_loader, train_log_interval, val_log_interval, device, beam, gen_len, k, decoding_strategy, accum_iter, "FT Training Epoch [{}/{}]".format(i + 1, args.num_epochs_ft), save_dir, logger, text_encoder, show_progress=args.show_progress) 225 | 226 | if __name__ == "__main__": 227 | parser = argparse.ArgumentParser() 228 | parser.add_argument('--desc', type=str, help="Description") 229 | parser.add_argument('--seed', type=int, default=42) 230 | parser.add_argument('--num_epochs_dat', type=int, default=0) 231 | parser.add_argument('--num_epochs_ft', type=int, default=20) 232 | parser.add_argument('--n_batch', type=int, default=32) 233 | parser.add_argument('--max_grad_norm', type=int, default=1) 234 | parser.add_argument('--lr', type=float, default=6.25e-5) 235 | parser.add_argument('--lr_warmup', type=float, default=0.002) 236 | parser.add_argument('--n_ctx', type=int, default=512) 237 | parser.add_argument('--n_embd', type=int, default=768) 238 | parser.add_argument('--n_head', type=int, default=12) 239 | parser.add_argument('--n_layer', type=int, default=12) 240 | parser.add_argument('--embd_pdrop', type=float, default=0.1) 241 | parser.add_argument('--attn_pdrop', type=float, default=0.1) 242 | parser.add_argument('--resid_pdrop', type=float, default=0.1) 243 | parser.add_argument('--clf_pdrop', 
type=float, default=0.1) 244 | parser.add_argument('--l2', type=float, default=0.01) 245 | parser.add_argument('--vector_l2', action='store_true') 246 | parser.add_argument('--opt', type=str, default='adam') 247 | parser.add_argument('--afn', type=str, default='gelu') 248 | parser.add_argument('--lr_schedule', type=str, default='warmup_linear') 249 | parser.add_argument('--encoder_path', type=str, default='model/encoder_bpe_40000.json') 250 | parser.add_argument('--vocab_path', type=str, default='model/vocab_40000.bpe') 251 | parser.add_argument('--n_transfer', type=int, default=12) 252 | parser.add_argument('--lm_coef', type=float, default=0.5) 253 | parser.add_argument('--b1', type=float, default=0.9) 254 | parser.add_argument('--b2', type=float, default=0.999) 255 | parser.add_argument('--e', type=float, default=1e-8) 256 | # Custom 257 | parser.add_argument('--output_dir', type=str, default="output") 258 | parser.add_argument('--checkpoint', type=str, default=None) 259 | parser.add_argument('--experiment_name', type=str, required=True) 260 | parser.add_argument('--data_dir', type=str, default='data') 261 | parser.add_argument('--train_log_interval', type=int, default=100) 262 | parser.add_argument('--val_log_interval', type=int, default=2000) 263 | parser.add_argument('--num_val_examples', type=int, default=500) 264 | parser.add_argument('--beam', type=int, default=3) 265 | parser.add_argument('--gen_len', type=int, default=110) 266 | parser.add_argument('--k', type=int, default=10) 267 | parser.add_argument('--decoding_strategy', type=int, default=0) 268 | parser.add_argument('--accum_iter', type=int, default=2) 269 | parser.add_argument('--show_progress', action='store_true') 270 | parser.add_argument('--doc_model', action='store_true') 271 | parser.add_argument('--use_pretrain', action='store_true') 272 | args = parser.parse_args() 273 | print(args) 274 | main(args) 275 | --------------------------------------------------------------------------------
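
Usage sketch (not part of the repository): parallel.py above defines a replication-callback protocol in which a sub-module may implement `__data_parallel_replicate__(ctx, copy_id)`; `ctx` is a `CallbackContext` shared by all device copies of that sub-module, and the master copy (copy_id 0) is called before any slave copy. The minimal module below only illustrates that contract under those assumptions; the class name, its attributes, and the wiring shown in the trailing comments are hypothetical, while `__data_parallel_replicate__`, `patch_replication_callback`, and `DataParallelWithCallback` come from parallel.py.

```
from torch import nn

class ReplicationAware(nn.Module):
    """Toy module illustrating the replication-callback contract (hypothetical example)."""

    def __init__(self, n_features):
        super().__init__()
        self.linear = nn.Linear(n_features, n_features)
        self.copy_id = 0  # the master copy keeps id 0

    def __data_parallel_replicate__(self, ctx, copy_id):
        # Invoked once per device copy after replicate(). `ctx` is shared by
        # every copy of this sub-module; the master copy (copy_id == 0) is
        # called first, so state it places on `ctx` is visible to slave copies.
        self.copy_id = copy_id
        if copy_id == 0:
            ctx.master_initialized = True
        assert getattr(ctx, 'master_initialized', False)

    def forward(self, x):
        return self.linear(x)

# Hypothetical wiring, assuming at least two GPUs are available:
#   model = DataParallelWithCallback(ReplicationAware(8), device_ids=[0, 1])
# or, equivalently, patch a plain DataParallel instance:
#   model = nn.DataParallel(ReplicationAware(8), device_ids=[0, 1])
#   patch_replication_callback(model)
```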