├── LICENSE ├── README.md ├── comet └── README.md ├── comet_utils.py ├── csk_models.py ├── data ├── aan │ ├── .gitkeep │ └── csk │ │ └── .gitkeep ├── nips │ ├── .gitkeep │ └── csk │ │ └── .gitkeep ├── nsf │ ├── .gitkeep │ └── csk │ │ └── .gitkeep ├── roc │ ├── .gitkeep │ ├── csk │ │ └── .gitkeep │ ├── test.tsv │ ├── train.tsv │ └── valid.tsv └── sind │ ├── .gitkeep │ └── csk │ └── .gitkeep ├── dataloader.py ├── modified_transformers └── modeling_roberta.py ├── prepare_csk.py ├── prepare_data.py ├── results ├── aan │ └── .gitkeep ├── nips │ └── .gitkeep ├── nsf │ └── .gitkeep ├── roc │ └── .gitkeep └── sind │ └── .gitkeep ├── saved ├── aan │ └── .gitkeep ├── nips │ └── .gitkeep ├── nsf │ └── .gitkeep ├── roc │ └── .gitkeep └── sind │ └── .gitkeep ├── stack.png ├── topological_sort.py └── train_csk.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Deep Cognition and Language Research (DeCLaRe) Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # STaCK: Sentence Ordering with Temporal Commonsense Knowledge 2 | 3 | This repository contains the pytorch implementation of the paper [STaCK: Sentence Ordering with Temporal Commonsense Knowledge](https://arxiv.org/abs/2109.02247) appearing at EMNLP 2021. 4 | 5 | 6 | ![Alt text](stack.png?raw=true "Illustration of STaCK.") 7 | 8 | Sentence ordering is the task of finding the correct order of sentences in a randomly ordered document. Correctly ordering the sentences requires an understanding of coherence with respect to the chronological sequence of events described in the text. Document-level contextual understanding and commonsense knowledge centered around these events is often essential in uncovering this coherence and predicting the exact chronological order. In this paper, we introduce STaCK --- a framework based on graph neural networks and temporal commonsense knowledge to model global information and predict the relative order of sentences. Our graph network accumulates temporal evidence using knowledge of past and future and formulates sentence ordering as a constrained edge classification problem. We report results on five different datasets, and empirically show that the proposed method is naturally suitable for order prediction. 
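Since the framework scores relative order on edges, the final ordering is decoded from these pairwise predictions. The snippet below is a minimal sketch of that decoding step, not the repository's exact implementation (the repository's own decoding utilities live in `topological_sort.py`): it assumes a hypothetical callable `score(i, j)` that returns a positive value when sentence `i` is predicted to precede sentence `j`, builds the corresponding directed graph, and reads off an order with a topological sort.

```
# Minimal sketch (not the repository's exact code) of turning pairwise
# order predictions into a full sentence order via topological sorting.
# `score(i, j)` is a hypothetical stand-in for the trained edge classifier:
# a positive value means "sentence i comes before sentence j".
def order_from_pairwise_scores(n_sents, score):
    successors = {i: set() for i in range(n_sents)}
    in_degree = [0] * n_sents
    for i in range(n_sents):
        for j in range(i + 1, n_sents):
            first, second = (i, j) if score(i, j) > 0 else (j, i)
            successors[first].add(second)
            in_degree[second] += 1

    order, remaining = [], set(range(n_sents))
    while remaining:
        # Pick a sentence with no unplaced predecessors; if the pairwise
        # predictions contain a cycle, fall back to the fewest unmet ones.
        nxt = min(remaining, key=lambda k: in_degree[k])
        remaining.remove(nxt)
        order.append(nxt)
        for succ in successors[nxt]:
            if succ in remaining:
                in_degree[succ] -= 1
    return order

# Example: order_from_pairwise_scores(4, lambda i, j: 1.0 if i < j else -1.0) -> [0, 1, 2, 3]
```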
9 | 10 | ## Data 11 | 12 | Contact the authors of the paper [Sentence Ordering and Coherence Modeling using Recurrent Neural Networks](https://arxiv.org/pdf/1611.02654.pdf) to obtain the AAN, NIPS, and NSF datasets. 13 | 14 | Download the SIND (Stories of Images in Sequence, SIS) dataset from the [Visual Storytelling](http://visionandlanguage.net/VIST/dataset.html) website. 15 | 16 | Keep the files in the appropriate directories under `data/`. 17 | 18 | The ROC dataset, with train, validation, and test splits, is provided in this repository. 19 | 20 | ## Prepare Datasets 21 | 22 | Download the COMET model by following the instructions in the `comet/` directory. Then, run the following: 23 | 24 | ``` 25 | python prepare_data.py 26 | CUDA_VISIBLE_DEVICES=0 python prepare_csk.py 27 | ``` 28 | 29 | ## Experiments 30 | 31 | Train and evaluate using: 32 | 33 | ``` 34 | CUDA_VISIBLE_DEVICES=0 python train_csk.py --lr 1e-6 --dataset nips --epochs 10 --hdim 200 --batch-size 8 --pfd 35 | ``` 36 | 37 | For other datasets, you can use the argument `--dataset [aan|nsf|roc|sind]`. The `--pfd` argument ensures that the past and future commonsense knowledge nodes have different relations. Remove this argument to use the same relation. 38 | 39 | We recommend using a learning rate of 1e-6 for all the datasets. Run the experiments multiple times and average the scores to reproduce the results reported in the paper. 40 | 41 | ## Citation 42 | 43 | Please cite the following paper if you use this code in your work: 44 | 45 | Deepanway Ghosal, Navonil Majumder, Rada Mihalcea, Soujanya Poria. "STaCK: Sentence Ordering with Temporal Commonsense Knowledge." In Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing (EMNLP). 46 | 47 | ## Credits 48 | Some of the code in this repository is borrowed from https://github.com/shrimai/Topological-Sort-for-Sentence-Ordering and https://github.com/allenai/comet-atomic-2020 49 | 50 |
-------------------------------------------------------------------------------- /comet/README.md: -------------------------------------------------------------------------------- 1 | Download and unzip the BART COMET model in this directory using: 2 | 3 | ``` 4 | wget https://storage.googleapis.com/ai2-mosaic-public/projects/mosaic-kgs/comet-atomic_2020_BART.zip 5 | unzip comet-atomic_2020_BART.zip 6 | ```
-------------------------------------------------------------------------------- /comet_utils.py: -------------------------------------------------------------------------------- 1 | # Credits: This code is used unmodified from https://github.com/allenai/comet-atomic-2020/tree/master/models/comet_atomic2020_bart/utils.py 2 | 3 | import itertools 4 | import json 5 | import linecache 6 | import os 7 | import pickle 8 | import warnings 9 | from logging import getLogger 10 | from pathlib import Path 11 | from typing import Callable, Dict, Iterable, List 12 | 13 | import git 14 | import numpy as np 15 | import torch 16 | from rouge_score import rouge_scorer, scoring 17 | from sacrebleu import corpus_bleu 18 | from torch import nn 19 | from torch.utils.data import Dataset, Sampler 20 | 21 | from transformers import BartTokenizer 22 | 23 | 24 | def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"): 25 | extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {} 26 | return tokenizer( 27 | [line], 28 | max_length=max_length, 29 | padding="max_length" if pad_to_max_length else None, 30 | truncation=True, 31 |
return_tensors=return_tensors, 32 | **extra_kw, 33 | ) 34 | 35 | 36 | def lmap(f: Callable, x: Iterable) -> List: 37 | """list(map(f, x))""" 38 | return list(map(f, x)) 39 | 40 | 41 | def calculate_bleu_score(output_lns, refs_lns, **kwargs) -> dict: 42 | """Uses sacrebleu's corpus_bleu implementation.""" 43 | return {"bleu": corpus_bleu(output_lns, [refs_lns], **kwargs).score} 44 | 45 | 46 | def trim_batch( 47 | input_ids, pad_token_id, attention_mask=None, 48 | ): 49 | """Remove columns that are populated exclusively by pad_token_id""" 50 | keep_column_mask = input_ids.ne(pad_token_id).any(dim=0) 51 | if attention_mask is None: 52 | return input_ids[:, keep_column_mask] 53 | else: 54 | return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask]) 55 | 56 | 57 | class Seq2SeqDataset(Dataset): 58 | def __init__( 59 | self, 60 | tokenizer, 61 | data_dir, 62 | max_source_length, 63 | max_target_length, 64 | type_path="train", 65 | n_obs=None, 66 | src_lang=None, 67 | tgt_lang=None, 68 | prefix="", 69 | ): 70 | super().__init__() 71 | self.src_file = Path(data_dir).joinpath(type_path + ".source") 72 | self.tgt_file = Path(data_dir).joinpath(type_path + ".target") 73 | self.src_lens = self.get_char_lens(self.src_file) 74 | self.max_source_length = max_source_length 75 | self.max_target_length = max_target_length 76 | assert min(self.src_lens) > 0, f"found empty line in {self.src_file}" 77 | self.tokenizer = tokenizer 78 | self.prefix = prefix 79 | if n_obs is not None: 80 | self.src_lens = self.src_lens[:n_obs] 81 | self.pad_token_id = self.tokenizer.pad_token_id 82 | self.src_lang = src_lang 83 | self.tgt_lang = tgt_lang 84 | 85 | def __len__(self): 86 | return len(self.src_lens) 87 | 88 | def __getitem__(self, index) -> Dict[str, torch.Tensor]: 89 | index = index + 1 # linecache starts at 1 90 | source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n") 91 | tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") 92 | assert source_line, f"empty source line for index {index}" 93 | assert tgt_line, f"empty tgt line for index {index}" 94 | source_inputs = encode_line(self.tokenizer, source_line, self.max_source_length) 95 | target_inputs = encode_line(self.tokenizer, tgt_line, self.max_target_length) 96 | 97 | source_ids = source_inputs["input_ids"].squeeze() 98 | target_ids = target_inputs["input_ids"].squeeze() 99 | src_mask = source_inputs["attention_mask"].squeeze() 100 | return { 101 | "input_ids": source_ids, 102 | "attention_mask": src_mask, 103 | "decoder_input_ids": target_ids, 104 | } 105 | 106 | @staticmethod 107 | def get_char_lens(data_file): 108 | return [len(x) for x in Path(data_file).open().readlines()] 109 | 110 | @staticmethod 111 | def trim_seq2seq_batch(batch, pad_token_id) -> tuple: 112 | y = trim_batch(batch["decoder_input_ids"], pad_token_id) 113 | source_ids, source_mask = trim_batch(batch["input_ids"], pad_token_id, attention_mask=batch["attention_mask"]) 114 | return source_ids, source_mask, y 115 | 116 | def collate_fn(self, batch) -> Dict[str, torch.Tensor]: 117 | input_ids = torch.stack([x["input_ids"] for x in batch]) 118 | masks = torch.stack([x["attention_mask"] for x in batch]) 119 | target_ids = torch.stack([x["decoder_input_ids"] for x in batch]) 120 | pad_token_id = self.pad_token_id 121 | y = trim_batch(target_ids, pad_token_id) 122 | source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks) 123 | batch = { 124 | "input_ids": source_ids, 125 | "attention_mask": source_mask, 126 
| "decoder_input_ids": y, 127 | } 128 | return batch 129 | 130 | def make_sortish_sampler(self, batch_size): 131 | return SortishSampler(self.src_lens, batch_size) 132 | 133 | 134 | class MBartDataset(Seq2SeqDataset): 135 | def __init__(self, *args, **kwargs): 136 | super().__init__(*args, **kwargs) 137 | if self.max_source_length != self.max_target_length: 138 | warnings.warn( 139 | f"Mbart will ignore max_target_length = {self.max_target_length} and use {self.max_source_length} for both sides." 140 | ) 141 | 142 | def __getitem__(self, index) -> Dict[str, str]: 143 | index = index + 1 # linecache starts at 1 144 | source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n") 145 | tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") 146 | assert source_line, f"empty source line for index {index}" 147 | assert tgt_line, f"empty tgt line for index {index}" 148 | return { 149 | "tgt_texts": source_line, 150 | "src_texts": tgt_line, 151 | } 152 | 153 | def collate_fn(self, batch) -> Dict[str, torch.Tensor]: 154 | batch_encoding = self.tokenizer.prepare_translation_batch( 155 | [x["src_texts"] for x in batch], 156 | src_lang=self.src_lang, 157 | tgt_texts=[x["tgt_texts"] for x in batch], 158 | tgt_lang=self.tgt_lang, 159 | max_length=self.max_source_length, 160 | ) 161 | return batch_encoding.data 162 | 163 | 164 | class SortishSampler(Sampler): 165 | "Go through the text data by order of src length with a bit of randomness. From fastai repo." 166 | 167 | def __init__(self, data, batch_size): 168 | self.data, self.bs = data, batch_size 169 | 170 | def key(self, i): 171 | return self.data[i] 172 | 173 | def __len__(self) -> int: 174 | return len(self.data) 175 | 176 | def __iter__(self): 177 | idxs = np.random.permutation(len(self.data)) 178 | sz = self.bs * 50 179 | ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)] 180 | sort_idx = np.concatenate([sorted(s, key=self.key, reverse=True) for s in ck_idx]) 181 | sz = self.bs 182 | ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)] 183 | max_ck = np.argmax([self.key(ck[0]) for ck in ck_idx]) # find the chunk with the largest key, 184 | ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0] # then make sure it goes first. 
185 | sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int) 186 | sort_idx = np.concatenate((ck_idx[0], sort_idx)) 187 | return iter(sort_idx) 188 | 189 | 190 | logger = getLogger(__name__) 191 | 192 | 193 | def use_task_specific_params(model, task): 194 | """Update config with summarization specific params.""" 195 | task_specific_params = model.config.task_specific_params 196 | 197 | if task_specific_params is not None: 198 | pars = task_specific_params.get(task, {}) 199 | logger.info(f"using task specific params for {task}: {pars}") 200 | model.config.update(pars) 201 | 202 | 203 | def pickle_load(path): 204 | """pickle.load(path)""" 205 | with open(path, "rb") as f: 206 | return pickle.load(f) 207 | 208 | 209 | def pickle_save(obj, path): 210 | """pickle.dump(obj, path)""" 211 | with open(path, "wb") as f: 212 | return pickle.dump(obj, f) 213 | 214 | 215 | def flatten_list(summary_ids: List[List]): 216 | return [x for x in itertools.chain.from_iterable(summary_ids)] 217 | 218 | 219 | def save_git_info(folder_path: str) -> None: 220 | """Save git information to output_dir/git_log.json""" 221 | repo_infos = get_git_info() 222 | save_json(repo_infos, os.path.join(folder_path, "git_log.json")) 223 | 224 | 225 | def save_json(content, path): 226 | with open(path, "w") as f: 227 | json.dump(content, f, indent=4) 228 | 229 | 230 | def load_json(path): 231 | with open(path) as f: 232 | return json.load(f) 233 | 234 | 235 | def get_git_info(): 236 | repo = git.Repo(search_parent_directories=True) 237 | repo_infos = { 238 | "repo_id": str(repo), 239 | "repo_sha": str(repo.head.object.hexsha), 240 | "repo_branch": str(repo.active_branch), 241 | } 242 | return repo_infos 243 | 244 | 245 | ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"] 246 | 247 | 248 | def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer=True) -> Dict: 249 | scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer) 250 | aggregator = scoring.BootstrapAggregator() 251 | 252 | for reference_ln, output_ln in zip(reference_lns, output_lns): 253 | scores = scorer.score(reference_ln, output_ln) 254 | aggregator.add_scores(scores) 255 | 256 | result = aggregator.aggregate() 257 | return {k: v.mid.fmeasure for k, v in result.items()} 258 | 259 | 260 | def freeze_params(model: nn.Module): 261 | for par in model.parameters(): 262 | par.requires_grad = False 263 | 264 | 265 | def grad_status(model: nn.Module) -> Iterable: 266 | return (par.requires_grad for par in model.parameters()) 267 | 268 | 269 | def any_requires_grad(model: nn.Module) -> bool: 270 | return any(grad_status(model)) 271 | 272 | 273 | def assert_all_frozen(model): 274 | model_grads: List[bool] = list(grad_status(model)) 275 | n_require_grad = sum(lmap(int, model_grads)) 276 | npars = len(model_grads) 277 | assert not any(model_grads), f"{n_require_grad/npars:.1%} of {npars} weights require grad" 278 | 279 | 280 | def assert_not_all_frozen(model): 281 | model_grads: List[bool] = list(grad_status(model)) 282 | npars = len(model_grads) 283 | assert any(model_grads), f"none of {npars} weights require grad" -------------------------------------------------------------------------------- /csk_models.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn.utils.rnn import pad_sequence 6 | import dgl 7 | import dgl.nn as dglnn 8 | import 
dgl.function as fn 9 | from sentence_transformers import SentenceTransformer 10 | from transformers import AutoModelForSequenceClassification, AutoTokenizer 11 | from modified_transformers.modeling_roberta import RobertaForSequenceClassificationWoPositional 12 | 13 | 14 | class TransformerModel(nn.Module): 15 | def __init__(self, model_name): 16 | super().__init__() 17 | self.tokenizer = AutoTokenizer.from_pretrained(model_name) 18 | self.model = AutoModelForSequenceClassification.from_pretrained(model_name) 19 | 20 | def forward(self, sentences): 21 | max_len = 512 22 | 23 | if len(sentences) <= 40: 24 | batch = self.tokenizer(sentences, padding=True, return_tensors="pt") 25 | input_ids = batch['input_ids'][:, :max_len].cuda() 26 | attention_mask = batch['attention_mask'][:, :max_len].cuda() 27 | output = self.model(input_ids, attention_mask, output_hidden_states=True) 28 | embeddings = output['hidden_states'][-1][:, 0, :] 29 | else: 30 | embeddings = [] 31 | batch_size = 40 32 | for k in range(0, len(sentences), batch_size): 33 | batch = self.tokenizer(sentences[k:k+batch_size], padding=True, return_tensors="pt") 34 | input_ids = batch['input_ids'][:, :max_len].cuda() 35 | attention_mask = batch['attention_mask'][:, :max_len].cuda() 36 | output = self.model(input_ids, attention_mask, output_hidden_states=True) 37 | embeddings.append(output['hidden_states'][-1][:, 0, :]) 38 | embeddings = torch.cat(embeddings) 39 | return embeddings 40 | 41 | 42 | class NonPositionalTransformerModel(nn.Module): 43 | def __init__(self, model_name): 44 | super().__init__() 45 | 46 | if 'base' in model_name: 47 | name = 'roberta-base' 48 | elif 'large' in model_name: 49 | name = 'roberta-large' 50 | self.model = RobertaForSequenceClassificationWoPositional.from_pretrained(name) 51 | self.tokenizer = AutoTokenizer.from_pretrained(name) 52 | 53 | def forward(self, sentences): 54 | max_len = 512 55 | original, mask = [], [] 56 | 57 | for item in sentences: 58 | t1 = [item for sublist in [self.tokenizer(sent)["input_ids"] for sent in item] for item in sublist] 59 | original.append(torch.tensor(t1)); mask.append(torch.tensor([1]*len(t1))) 60 | 61 | original = pad_sequence(original, batch_first=True, padding_value=self.tokenizer.pad_token_id)[:, :max_len].cuda() 62 | mask = pad_sequence(mask, batch_first=True, padding_value=0)[:, :max_len].cuda() 63 | output = self.model(original, mask, output_hidden_states=True) 64 | embeddings = output['hidden_states'][-1][:, 0, :] 65 | return embeddings 66 | 67 | 68 | class SinePredictor(nn.Module): 69 | def __init__(self, in_features): 70 | super().__init__() 71 | self.W = nn.Linear(in_features, 1) 72 | 73 | def forward(self, graph, h): 74 | s = h[graph.edges()[0]] 75 | o = h[graph.edges()[1]] 76 | score = self.W(torch.sin(s-o)) 77 | return score 78 | 79 | 80 | class GraphNetwork(nn.Module): 81 | def __init__(self, encoder_name, hidden_features, out_features, rel_types=5): 82 | super().__init__() 83 | if 'base' in encoder_name: 84 | in_features = 768 85 | elif 'large' in encoder_name: 86 | in_features = 1024 87 | 88 | self.in_features = in_features 89 | self.transformer = TransformerModel(encoder_name) 90 | self.document_transformer = NonPositionalTransformerModel(encoder_name) 91 | 92 | self.gcn1 = dglnn.RelGraphConv(in_features, hidden_features, rel_types, regularizer='basis', num_bases=2) 93 | self.gcn2 = dglnn.RelGraphConv(hidden_features, out_features, rel_types, regularizer='basis', num_bases=2) 94 | self.scorer = SinePredictor(in_features+out_features) 95 | 96 | def 
forward(self, x, sentence_nodes, document_nodes, csk_nodes, sentences, csk): 97 | 98 | all_sentences = [sent for instance in sentences for sent in instance] 99 | sentence_embed = self.transformer(all_sentences) 100 | 101 | embeddings = torch.zeros(len(sentence_nodes)+len(document_nodes)+len(csk_nodes), self.in_features).cuda() 102 | embeddings[sentence_nodes] = sentence_embed 103 | # csk features from BART are 1024 dimensional, so if the sentence encoder model is a base model then we take only 768 csk dimensions from the end 104 | embeddings[csk_nodes] = torch.tensor(csk[:, -self.in_features:]).float().cuda() 105 | 106 | document_embed = self.document_transformer(sentences) 107 | embeddings[document_nodes] = document_embed 108 | 109 | g = dgl.graph((x[0], x[1])).to('cuda') 110 | etype = torch.tensor(x[2]).to('cuda') 111 | 112 | hidden = F.relu(self.gcn1(g, embeddings, etype)) 113 | hidden = F.relu(self.gcn2(g, hidden, etype)) 114 | 115 | out = torch.cat([embeddings, hidden], -1) 116 | y = self.scorer(g, out) 117 | return y, embeddings, hidden, out 118 | 119 |
-------------------------------------------------------------------------------- /data/aan/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/declare-lab/sentence-ordering/c6524778f60b7093166b3cf5f003b655becef498/data/aan/.gitkeep
-------------------------------------------------------------------------------- /data/aan/csk/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/declare-lab/sentence-ordering/c6524778f60b7093166b3cf5f003b655becef498/data/aan/csk/.gitkeep
-------------------------------------------------------------------------------- /data/nips/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/declare-lab/sentence-ordering/c6524778f60b7093166b3cf5f003b655becef498/data/nips/.gitkeep
-------------------------------------------------------------------------------- /data/nips/csk/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/declare-lab/sentence-ordering/c6524778f60b7093166b3cf5f003b655becef498/data/nips/csk/.gitkeep
-------------------------------------------------------------------------------- /data/nsf/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/declare-lab/sentence-ordering/c6524778f60b7093166b3cf5f003b655becef498/data/nsf/.gitkeep
-------------------------------------------------------------------------------- /data/nsf/csk/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/declare-lab/sentence-ordering/c6524778f60b7093166b3cf5f003b655becef498/data/nsf/csk/.gitkeep
-------------------------------------------------------------------------------- /data/roc/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/declare-lab/sentence-ordering/c6524778f60b7093166b3cf5f003b655becef498/data/roc/.gitkeep
-------------------------------------------------------------------------------- /data/roc/csk/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/declare-lab/sentence-ordering/c6524778f60b7093166b3cf5f003b655becef498/data/roc/csk/.gitkeep
-------------------------------------------------------------------------------- /data/sind/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/declare-lab/sentence-ordering/c6524778f60b7093166b3cf5f003b655becef498/data/sind/.gitkeep -------------------------------------------------------------------------------- /data/sind/csk/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/declare-lab/sentence-ordering/c6524778f60b7093166b3cf5f003b655becef498/data/sind/csk/.gitkeep -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from torch.utils.data import Dataset, DataLoader 3 | 4 | class SentenceOrderingDataset(Dataset): 5 | 6 | def __init__(self, filename): 7 | 8 | id_, context1, context2 = [], [], [] 9 | with open(filename, 'r') as f: 10 | for i, line in enumerate(f): 11 | sents = line.strip().split('\t') 12 | if len(sents) > 1: 13 | id_.append(i) 14 | context1.append(sents) 15 | context2.append(' '.join(sents)) 16 | 17 | self.id = id_ 18 | self.context1 = context1 19 | self.context2 = context2 20 | 21 | def __len__(self): 22 | return len(self.context1) 23 | 24 | def __getitem__(self, index): 25 | i = self.id[index] 26 | c1 = self.context1[index] 27 | c2 = self.context2[index] 28 | return i, c1, c2 29 | 30 | def collate_fn(self, data): 31 | dat = pd.DataFrame(data) 32 | return [dat[i].tolist() for i in dat] 33 | 34 | 35 | def SentenceOrderingLoader(filename, batch_size, shuffle): 36 | dataset = SentenceOrderingDataset(filename) 37 | loader = DataLoader(dataset, shuffle=shuffle, batch_size=batch_size, collate_fn=dataset.collate_fn) 38 | return loader -------------------------------------------------------------------------------- /modified_transformers/modeling_roberta.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """PyTorch RoBERTa model. 
""" 17 | 18 | import math 19 | 20 | import torch 21 | import torch.nn as nn 22 | import torch.utils.checkpoint 23 | from torch.nn import CrossEntropyLoss, MSELoss 24 | 25 | from transformers.activations import ACT2FN, gelu 26 | from transformers.file_utils import ( 27 | add_code_sample_docstrings, 28 | add_start_docstrings, 29 | add_start_docstrings_to_model_forward, 30 | replace_return_docstrings, 31 | ) 32 | from transformers.modeling_outputs import ( 33 | BaseModelOutputWithPastAndCrossAttentions, 34 | BaseModelOutputWithPoolingAndCrossAttentions, 35 | CausalLMOutputWithCrossAttentions, 36 | MaskedLMOutput, 37 | MultipleChoiceModelOutput, 38 | QuestionAnsweringModelOutput, 39 | SequenceClassifierOutput, 40 | TokenClassifierOutput, 41 | ) 42 | from transformers.modeling_utils import ( 43 | PreTrainedModel, 44 | apply_chunking_to_forward, 45 | find_pruneable_heads_and_indices, 46 | prune_linear_layer, 47 | ) 48 | from transformers.utils import logging 49 | from transformers.models.roberta.configuration_roberta import RobertaConfig 50 | 51 | 52 | logger = logging.get_logger(__name__) 53 | 54 | _CHECKPOINT_FOR_DOC = "roberta-base" 55 | _CONFIG_FOR_DOC = "RobertaConfig" 56 | _TOKENIZER_FOR_DOC = "RobertaTokenizer" 57 | 58 | ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [ 59 | "roberta-base", 60 | "roberta-large", 61 | "roberta-large-mnli", 62 | "distilroberta-base", 63 | "roberta-base-openai-detector", 64 | "roberta-large-openai-detector", 65 | # See all RoBERTa models at https://huggingface.co/models?filter=roberta 66 | ] 67 | 68 | 69 | class RobertaEmbeddings(nn.Module): 70 | """ 71 | Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. 72 | """ 73 | 74 | # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__ 75 | def __init__(self, config): 76 | super().__init__() 77 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) 78 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 79 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 80 | 81 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 82 | # any TensorFlow checkpoint file 83 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 84 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 85 | 86 | # position_ids (1, len position emb) is contiguous in memory and exported when serialized 87 | self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) 88 | self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") 89 | 90 | # End copy 91 | self.padding_idx = config.pad_token_id 92 | self.position_embeddings = nn.Embedding( 93 | config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx 94 | ) 95 | 96 | def forward( 97 | self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 98 | ): 99 | if position_ids is None: 100 | if input_ids is not None: 101 | # Create the position ids from the input token ids. Any padded tokens remain padded. 
102 | position_ids = create_position_ids_from_input_ids( 103 | input_ids, self.padding_idx, past_key_values_length 104 | ).to(input_ids.device) 105 | else: 106 | position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) 107 | 108 | if input_ids is not None: 109 | input_shape = input_ids.size() 110 | else: 111 | input_shape = inputs_embeds.size()[:-1] 112 | 113 | if token_type_ids is None: 114 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) 115 | 116 | if inputs_embeds is None: 117 | inputs_embeds = self.word_embeddings(input_ids) 118 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 119 | 120 | embeddings = inputs_embeds + token_type_embeddings 121 | # if self.position_embedding_type == "absolute": 122 | # position_embeddings = self.position_embeddings(position_ids) 123 | # embeddings += position_embeddings 124 | embeddings = self.LayerNorm(embeddings) 125 | embeddings = self.dropout(embeddings) 126 | return embeddings 127 | 128 | def create_position_ids_from_inputs_embeds(self, inputs_embeds): 129 | """ 130 | We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. 131 | 132 | Args: 133 | inputs_embeds: torch.Tensor 134 | 135 | Returns: torch.Tensor 136 | """ 137 | input_shape = inputs_embeds.size()[:-1] 138 | sequence_length = input_shape[1] 139 | 140 | position_ids = torch.arange( 141 | self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device 142 | ) 143 | return position_ids.unsqueeze(0).expand(input_shape) 144 | 145 | 146 | # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Roberta 147 | class RobertaSelfAttention(nn.Module): 148 | def __init__(self, config): 149 | super().__init__() 150 | if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): 151 | raise ValueError( 152 | f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " 153 | f"heads ({config.num_attention_heads})" 154 | ) 155 | 156 | self.num_attention_heads = config.num_attention_heads 157 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 158 | self.all_head_size = self.num_attention_heads * self.attention_head_size 159 | 160 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 161 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 162 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 163 | 164 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 165 | self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") 166 | if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": 167 | self.max_position_embeddings = config.max_position_embeddings 168 | self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) 169 | 170 | self.is_decoder = config.is_decoder 171 | 172 | def transpose_for_scores(self, x): 173 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 174 | x = x.view(*new_x_shape) 175 | return x.permute(0, 2, 1, 3) 176 | 177 | def forward( 178 | self, 179 | hidden_states, 180 | attention_mask=None, 181 | head_mask=None, 182 | encoder_hidden_states=None, 183 | encoder_attention_mask=None, 184 | past_key_value=None, 185 | output_attentions=False, 186 | ): 187 | mixed_query_layer = 
self.query(hidden_states) 188 | 189 | # If this is instantiated as a cross-attention module, the keys 190 | # and values come from an encoder; the attention mask needs to be 191 | # such that the encoder's padding tokens are not attended to. 192 | is_cross_attention = encoder_hidden_states is not None 193 | 194 | if is_cross_attention and past_key_value is not None: 195 | # reuse k,v, cross_attentions 196 | key_layer = past_key_value[0] 197 | value_layer = past_key_value[1] 198 | attention_mask = encoder_attention_mask 199 | elif is_cross_attention: 200 | key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) 201 | value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) 202 | attention_mask = encoder_attention_mask 203 | elif past_key_value is not None: 204 | key_layer = self.transpose_for_scores(self.key(hidden_states)) 205 | value_layer = self.transpose_for_scores(self.value(hidden_states)) 206 | key_layer = torch.cat([past_key_value[0], key_layer], dim=2) 207 | value_layer = torch.cat([past_key_value[1], value_layer], dim=2) 208 | else: 209 | key_layer = self.transpose_for_scores(self.key(hidden_states)) 210 | value_layer = self.transpose_for_scores(self.value(hidden_states)) 211 | 212 | query_layer = self.transpose_for_scores(mixed_query_layer) 213 | 214 | if self.is_decoder: 215 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 216 | # Further calls to cross_attention layer can then reuse all cross-attention 217 | # key/value_states (first "if" case) 218 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 219 | # all previous decoder key/value_states. Further calls to uni-directional self-attention 220 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 221 | # if encoder bi-directional self-attention `past_key_value` is always `None` 222 | past_key_value = (key_layer, value_layer) 223 | 224 | # Take the dot product between "query" and "key" to get the raw attention scores. 
225 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 226 | 227 | if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": 228 | seq_length = hidden_states.size()[1] 229 | position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) 230 | position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) 231 | distance = position_ids_l - position_ids_r 232 | positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) 233 | positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility 234 | 235 | if self.position_embedding_type == "relative_key": 236 | relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) 237 | attention_scores = attention_scores + relative_position_scores 238 | elif self.position_embedding_type == "relative_key_query": 239 | relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) 240 | relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) 241 | attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key 242 | 243 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 244 | if attention_mask is not None: 245 | # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function) 246 | attention_scores = attention_scores + attention_mask 247 | 248 | # Normalize the attention scores to probabilities. 249 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 250 | 251 | # This is actually dropping out entire tokens to attend to, which might 252 | # seem a bit unusual, but is taken from the original Transformer paper. 
253 | attention_probs = self.dropout(attention_probs) 254 | 255 | # Mask heads if we want to 256 | if head_mask is not None: 257 | attention_probs = attention_probs * head_mask 258 | 259 | context_layer = torch.matmul(attention_probs, value_layer) 260 | 261 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 262 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 263 | context_layer = context_layer.view(*new_context_layer_shape) 264 | 265 | outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) 266 | 267 | if self.is_decoder: 268 | outputs = outputs + (past_key_value,) 269 | return outputs 270 | 271 | 272 | # Copied from transformers.models.bert.modeling_bert.BertSelfOutput 273 | class RobertaSelfOutput(nn.Module): 274 | def __init__(self, config): 275 | super().__init__() 276 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 277 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 278 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 279 | 280 | def forward(self, hidden_states, input_tensor): 281 | hidden_states = self.dense(hidden_states) 282 | hidden_states = self.dropout(hidden_states) 283 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 284 | return hidden_states 285 | 286 | 287 | # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Roberta 288 | class RobertaAttention(nn.Module): 289 | def __init__(self, config): 290 | super().__init__() 291 | self.self = RobertaSelfAttention(config) 292 | self.output = RobertaSelfOutput(config) 293 | self.pruned_heads = set() 294 | 295 | def prune_heads(self, heads): 296 | if len(heads) == 0: 297 | return 298 | heads, index = find_pruneable_heads_and_indices( 299 | heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads 300 | ) 301 | 302 | # Prune linear layers 303 | self.self.query = prune_linear_layer(self.self.query, index) 304 | self.self.key = prune_linear_layer(self.self.key, index) 305 | self.self.value = prune_linear_layer(self.self.value, index) 306 | self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) 307 | 308 | # Update hyper params and store pruned heads 309 | self.self.num_attention_heads = self.self.num_attention_heads - len(heads) 310 | self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads 311 | self.pruned_heads = self.pruned_heads.union(heads) 312 | 313 | def forward( 314 | self, 315 | hidden_states, 316 | attention_mask=None, 317 | head_mask=None, 318 | encoder_hidden_states=None, 319 | encoder_attention_mask=None, 320 | past_key_value=None, 321 | output_attentions=False, 322 | ): 323 | self_outputs = self.self( 324 | hidden_states, 325 | attention_mask, 326 | head_mask, 327 | encoder_hidden_states, 328 | encoder_attention_mask, 329 | past_key_value, 330 | output_attentions, 331 | ) 332 | attention_output = self.output(self_outputs[0], hidden_states) 333 | outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them 334 | return outputs 335 | 336 | 337 | # Copied from transformers.models.bert.modeling_bert.BertIntermediate 338 | class RobertaIntermediate(nn.Module): 339 | def __init__(self, config): 340 | super().__init__() 341 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 342 | if isinstance(config.hidden_act, str): 343 | self.intermediate_act_fn = ACT2FN[config.hidden_act] 344 | else: 345 | self.intermediate_act_fn = config.hidden_act 346 | 347 
| def forward(self, hidden_states): 348 | hidden_states = self.dense(hidden_states) 349 | hidden_states = self.intermediate_act_fn(hidden_states) 350 | return hidden_states 351 | 352 | 353 | # Copied from transformers.models.bert.modeling_bert.BertOutput 354 | class RobertaOutput(nn.Module): 355 | def __init__(self, config): 356 | super().__init__() 357 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 358 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 359 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 360 | 361 | def forward(self, hidden_states, input_tensor): 362 | hidden_states = self.dense(hidden_states) 363 | hidden_states = self.dropout(hidden_states) 364 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 365 | return hidden_states 366 | 367 | 368 | # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Roberta 369 | class RobertaLayer(nn.Module): 370 | def __init__(self, config): 371 | super().__init__() 372 | self.chunk_size_feed_forward = config.chunk_size_feed_forward 373 | self.seq_len_dim = 1 374 | self.attention = RobertaAttention(config) 375 | self.is_decoder = config.is_decoder 376 | self.add_cross_attention = config.add_cross_attention 377 | if self.add_cross_attention: 378 | assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added" 379 | self.crossattention = RobertaAttention(config) 380 | self.intermediate = RobertaIntermediate(config) 381 | self.output = RobertaOutput(config) 382 | 383 | def forward( 384 | self, 385 | hidden_states, 386 | attention_mask=None, 387 | head_mask=None, 388 | encoder_hidden_states=None, 389 | encoder_attention_mask=None, 390 | past_key_value=None, 391 | output_attentions=False, 392 | ): 393 | # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 394 | self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None 395 | self_attention_outputs = self.attention( 396 | hidden_states, 397 | attention_mask, 398 | head_mask, 399 | output_attentions=output_attentions, 400 | past_key_value=self_attn_past_key_value, 401 | ) 402 | attention_output = self_attention_outputs[0] 403 | 404 | # if decoder, the last output is tuple of self-attn cache 405 | if self.is_decoder: 406 | outputs = self_attention_outputs[1:-1] 407 | present_key_value = self_attention_outputs[-1] 408 | else: 409 | outputs = self_attention_outputs[1:] # add self attentions if we output attention weights 410 | 411 | cross_attn_present_key_value = None 412 | if self.is_decoder and encoder_hidden_states is not None: 413 | assert hasattr( 414 | self, "crossattention" 415 | ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`" 416 | 417 | # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple 418 | cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None 419 | cross_attention_outputs = self.crossattention( 420 | attention_output, 421 | attention_mask, 422 | head_mask, 423 | encoder_hidden_states, 424 | encoder_attention_mask, 425 | cross_attn_past_key_value, 426 | output_attentions, 427 | ) 428 | attention_output = cross_attention_outputs[0] 429 | outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights 430 | 431 | # add cross-attn cache to positions 3,4 of present_key_value tuple 432 | cross_attn_present_key_value 
= cross_attention_outputs[-1] 433 | present_key_value = present_key_value + cross_attn_present_key_value 434 | 435 | layer_output = apply_chunking_to_forward( 436 | self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output 437 | ) 438 | outputs = (layer_output,) + outputs 439 | 440 | # if decoder, return the attn key/values as the last output 441 | if self.is_decoder: 442 | outputs = outputs + (present_key_value,) 443 | 444 | return outputs 445 | 446 | def feed_forward_chunk(self, attention_output): 447 | intermediate_output = self.intermediate(attention_output) 448 | layer_output = self.output(intermediate_output, attention_output) 449 | return layer_output 450 | 451 | 452 | # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Roberta 453 | class RobertaEncoder(nn.Module): 454 | def __init__(self, config): 455 | super().__init__() 456 | self.config = config 457 | self.layer = nn.ModuleList([RobertaLayer(config) for _ in range(config.num_hidden_layers)]) 458 | 459 | def forward( 460 | self, 461 | hidden_states, 462 | attention_mask=None, 463 | head_mask=None, 464 | encoder_hidden_states=None, 465 | encoder_attention_mask=None, 466 | past_key_values=None, 467 | use_cache=None, 468 | output_attentions=False, 469 | output_hidden_states=False, 470 | return_dict=True, 471 | ): 472 | all_hidden_states = () if output_hidden_states else None 473 | all_self_attentions = () if output_attentions else None 474 | all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None 475 | 476 | next_decoder_cache = () if use_cache else None 477 | for i, layer_module in enumerate(self.layer): 478 | if output_hidden_states: 479 | all_hidden_states = all_hidden_states + (hidden_states,) 480 | 481 | layer_head_mask = head_mask[i] if head_mask is not None else None 482 | past_key_value = past_key_values[i] if past_key_values is not None else None 483 | 484 | if getattr(self.config, "gradient_checkpointing", False) and self.training: 485 | 486 | if use_cache: 487 | logger.warn( 488 | "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " 489 | "`use_cache=False`..." 
490 | ) 491 | use_cache = False 492 | 493 | def create_custom_forward(module): 494 | def custom_forward(*inputs): 495 | return module(*inputs, past_key_value, output_attentions) 496 | 497 | return custom_forward 498 | 499 | layer_outputs = torch.utils.checkpoint.checkpoint( 500 | create_custom_forward(layer_module), 501 | hidden_states, 502 | attention_mask, 503 | layer_head_mask, 504 | encoder_hidden_states, 505 | encoder_attention_mask, 506 | ) 507 | else: 508 | layer_outputs = layer_module( 509 | hidden_states, 510 | attention_mask, 511 | layer_head_mask, 512 | encoder_hidden_states, 513 | encoder_attention_mask, 514 | past_key_value, 515 | output_attentions, 516 | ) 517 | 518 | hidden_states = layer_outputs[0] 519 | if use_cache: 520 | next_decoder_cache += (layer_outputs[-1],) 521 | if output_attentions: 522 | all_self_attentions = all_self_attentions + (layer_outputs[1],) 523 | if self.config.add_cross_attention: 524 | all_cross_attentions = all_cross_attentions + (layer_outputs[2],) 525 | 526 | if output_hidden_states: 527 | all_hidden_states = all_hidden_states + (hidden_states,) 528 | 529 | if not return_dict: 530 | return tuple( 531 | v 532 | for v in [ 533 | hidden_states, 534 | next_decoder_cache, 535 | all_hidden_states, 536 | all_self_attentions, 537 | all_cross_attentions, 538 | ] 539 | if v is not None 540 | ) 541 | return BaseModelOutputWithPastAndCrossAttentions( 542 | last_hidden_state=hidden_states, 543 | past_key_values=next_decoder_cache, 544 | hidden_states=all_hidden_states, 545 | attentions=all_self_attentions, 546 | cross_attentions=all_cross_attentions, 547 | ) 548 | 549 | 550 | # Copied from transformers.models.bert.modeling_bert.BertPooler 551 | class RobertaPooler(nn.Module): 552 | def __init__(self, config): 553 | super().__init__() 554 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 555 | self.activation = nn.Tanh() 556 | 557 | def forward(self, hidden_states): 558 | # We "pool" the model by simply taking the hidden state corresponding 559 | # to the first token. 560 | first_token_tensor = hidden_states[:, 0] 561 | pooled_output = self.dense(first_token_tensor) 562 | pooled_output = self.activation(pooled_output) 563 | return pooled_output 564 | 565 | 566 | class RobertaPreTrainedModel(PreTrainedModel): 567 | """ 568 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 569 | models. 570 | """ 571 | 572 | config_class = RobertaConfig 573 | base_model_prefix = "roberta" 574 | 575 | # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights 576 | def _init_weights(self, module): 577 | """ Initialize the weights """ 578 | if isinstance(module, nn.Linear): 579 | # Slightly different from the TF version which uses truncated_normal for initialization 580 | # cf https://github.com/pytorch/pytorch/pull/5617 581 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 582 | if module.bias is not None: 583 | module.bias.data.zero_() 584 | elif isinstance(module, nn.Embedding): 585 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 586 | if module.padding_idx is not None: 587 | module.weight.data[module.padding_idx].zero_() 588 | elif isinstance(module, nn.LayerNorm): 589 | module.bias.data.zero_() 590 | module.weight.data.fill_(1.0) 591 | 592 | 593 | ROBERTA_START_DOCSTRING = r""" 594 | 595 | This model inherits from :class:`~transformers.PreTrainedModel`. 
Check the superclass documentation for the generic 596 | methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, 597 | pruning heads etc.) 598 | 599 | This model is also a PyTorch `torch.nn.Module `__ 600 | subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to 601 | general usage and behavior. 602 | 603 | Parameters: 604 | config (:class:`~transformers.RobertaConfig`): Model configuration class with all the parameters of the 605 | model. Initializing with a config file does not load the weights associated with the model, only the 606 | configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model 607 | weights. 608 | """ 609 | 610 | ROBERTA_INPUTS_DOCSTRING = r""" 611 | Args: 612 | input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`): 613 | Indices of input sequence tokens in the vocabulary. 614 | 615 | Indices can be obtained using :class:`~transformers.RobertaTokenizer`. See 616 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for 617 | details. 618 | 619 | `What are input IDs? <../glossary.html#input-ids>`__ 620 | attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`): 621 | Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: 622 | 623 | - 1 for tokens that are **not masked**, 624 | - 0 for tokens that are **masked**. 625 | 626 | `What are attention masks? <../glossary.html#attention-mask>`__ 627 | token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): 628 | Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, 629 | 1]``: 630 | 631 | - 0 corresponds to a `sentence A` token, 632 | - 1 corresponds to a `sentence B` token. 633 | 634 | `What are token type IDs? <../glossary.html#token-type-ids>`_ 635 | position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): 636 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, 637 | config.max_position_embeddings - 1]``. 638 | 639 | `What are position IDs? <../glossary.html#position-ids>`_ 640 | head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): 641 | Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: 642 | 643 | - 1 indicates the head is **not masked**, 644 | - 0 indicates the head is **masked**. 645 | 646 | inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`): 647 | Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. 648 | This is useful if you want more control over how to convert :obj:`input_ids` indices into associated 649 | vectors than the model's internal embedding lookup matrix. 650 | output_attentions (:obj:`bool`, `optional`): 651 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned 652 | tensors for more detail. 653 | output_hidden_states (:obj:`bool`, `optional`): 654 | Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for 655 | more detail. 656 | return_dict (:obj:`bool`, `optional`): 657 | Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. 
658 | """ 659 | 660 | 661 | @add_start_docstrings( 662 | "The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.", 663 | ROBERTA_START_DOCSTRING, 664 | ) 665 | class RobertaModel(RobertaPreTrainedModel): 666 | """ 667 | 668 | The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of 669 | cross-attention is added between the self-attention layers, following the architecture described in `Attention is 670 | all you need`_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz 671 | Kaiser and Illia Polosukhin. 672 | 673 | To behave as an decoder the model needs to be initialized with the :obj:`is_decoder` argument of the configuration 674 | set to :obj:`True`. To be used in a Seq2Seq model, the model needs to initialized with both :obj:`is_decoder` 675 | argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an 676 | input to the forward pass. 677 | 678 | .. _`Attention is all you need`: https://arxiv.org/abs/1706.03762 679 | 680 | """ 681 | 682 | _keys_to_ignore_on_load_missing = [r"position_ids"] 683 | 684 | # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Roberta 685 | def __init__(self, config, add_pooling_layer=True): 686 | super().__init__(config) 687 | self.config = config 688 | 689 | self.embeddings = RobertaEmbeddings(config) 690 | self.encoder = RobertaEncoder(config) 691 | 692 | self.pooler = RobertaPooler(config) if add_pooling_layer else None 693 | 694 | self.init_weights() 695 | 696 | def get_input_embeddings(self): 697 | return self.embeddings.word_embeddings 698 | 699 | def set_input_embeddings(self, value): 700 | self.embeddings.word_embeddings = value 701 | 702 | def _prune_heads(self, heads_to_prune): 703 | """ 704 | Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base 705 | class PreTrainedModel 706 | """ 707 | for layer, heads in heads_to_prune.items(): 708 | self.encoder.layer[layer].attention.prune_heads(heads) 709 | 710 | @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) 711 | @add_code_sample_docstrings( 712 | tokenizer_class=_TOKENIZER_FOR_DOC, 713 | checkpoint=_CHECKPOINT_FOR_DOC, 714 | output_type=BaseModelOutputWithPoolingAndCrossAttentions, 715 | config_class=_CONFIG_FOR_DOC, 716 | ) 717 | # Copied from transformers.models.bert.modeling_bert.BertModel.forward 718 | def forward( 719 | self, 720 | input_ids=None, 721 | attention_mask=None, 722 | token_type_ids=None, 723 | position_ids=None, 724 | head_mask=None, 725 | inputs_embeds=None, 726 | encoder_hidden_states=None, 727 | encoder_attention_mask=None, 728 | past_key_values=None, 729 | use_cache=None, 730 | output_attentions=None, 731 | output_hidden_states=None, 732 | return_dict=None, 733 | ): 734 | r""" 735 | encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): 736 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if 737 | the model is configured as a decoder. 738 | encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 739 | Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in 740 | the cross-attention if the model is configured as a decoder. 
Mask values selected in ``[0, 1]``: 741 | 742 | - 1 for tokens that are **not masked**, 743 | - 0 for tokens that are **masked**. 744 | past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): 745 | Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. 746 | 747 | If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` 748 | (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` 749 | instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. 750 | use_cache (:obj:`bool`, `optional`): 751 | If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up 752 | decoding (see :obj:`past_key_values`). 753 | """ 754 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 755 | output_hidden_states = ( 756 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 757 | ) 758 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 759 | 760 | if self.config.is_decoder: 761 | use_cache = use_cache if use_cache is not None else self.config.use_cache 762 | else: 763 | use_cache = False 764 | 765 | if input_ids is not None and inputs_embeds is not None: 766 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 767 | elif input_ids is not None: 768 | input_shape = input_ids.size() 769 | batch_size, seq_length = input_shape 770 | elif inputs_embeds is not None: 771 | input_shape = inputs_embeds.size()[:-1] 772 | batch_size, seq_length = input_shape 773 | else: 774 | raise ValueError("You have to specify either input_ids or inputs_embeds") 775 | 776 | device = input_ids.device if input_ids is not None else inputs_embeds.device 777 | 778 | # past_key_values_length 779 | past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 780 | 781 | if attention_mask is None: 782 | attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) 783 | if token_type_ids is None: 784 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) 785 | 786 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] 787 | # ourselves in which case we just need to make it broadcastable to all heads. 
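# For intuition (roughly, as implemented in the transformers base ModuleUtilsMixin from which
# this model inherits): `get_extended_attention_mask` turns a padding mask such as
#     attention_mask = [[1, 1, 0]]                      # (batch_size, seq_length)
# into an additive mask of shape (batch_size, 1, 1, seq_length) whose masked positions hold a
# large negative value, so they receive near-zero weight after the attention softmax; in
# decoder mode a causal (lower-triangular) component is folded in as well.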
788 | extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device) 789 | 790 | # If a 2D or 3D attention mask is provided for the cross-attention 791 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] 792 | if self.config.is_decoder and encoder_hidden_states is not None: 793 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() 794 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) 795 | if encoder_attention_mask is None: 796 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) 797 | encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) 798 | else: 799 | encoder_extended_attention_mask = None 800 | 801 | # Prepare head mask if needed 802 | # 1.0 in head_mask indicate we keep the head 803 | # attention_probs has shape bsz x n_heads x N x N 804 | # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] 805 | # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] 806 | head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) 807 | 808 | embedding_output = self.embeddings( 809 | input_ids=input_ids, 810 | position_ids=position_ids, 811 | token_type_ids=token_type_ids, 812 | inputs_embeds=inputs_embeds, 813 | past_key_values_length=past_key_values_length, 814 | ) 815 | encoder_outputs = self.encoder( 816 | embedding_output, 817 | attention_mask=extended_attention_mask, 818 | head_mask=head_mask, 819 | encoder_hidden_states=encoder_hidden_states, 820 | encoder_attention_mask=encoder_extended_attention_mask, 821 | past_key_values=past_key_values, 822 | use_cache=use_cache, 823 | output_attentions=output_attentions, 824 | output_hidden_states=output_hidden_states, 825 | return_dict=return_dict, 826 | ) 827 | sequence_output = encoder_outputs[0] 828 | pooled_output = self.pooler(sequence_output) if self.pooler is not None else None 829 | 830 | if not return_dict: 831 | return (sequence_output, pooled_output) + encoder_outputs[1:] 832 | 833 | return BaseModelOutputWithPoolingAndCrossAttentions( 834 | last_hidden_state=sequence_output, 835 | pooler_output=pooled_output, 836 | past_key_values=encoder_outputs.past_key_values, 837 | hidden_states=encoder_outputs.hidden_states, 838 | attentions=encoder_outputs.attentions, 839 | cross_attentions=encoder_outputs.cross_attentions, 840 | ) 841 | 842 | 843 | @add_start_docstrings( 844 | """RoBERTa Model with a `language modeling` head on top for CLM fine-tuning. 
""", ROBERTA_START_DOCSTRING 845 | ) 846 | class RobertaForCausalLM(RobertaPreTrainedModel): 847 | _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.bias"] 848 | _keys_to_ignore_on_load_unexpected = [r"pooler"] 849 | 850 | def __init__(self, config): 851 | super().__init__(config) 852 | 853 | if not config.is_decoder: 854 | logger.warning("If you want to use `RobertaLMHeadModel` as a standalone, add `is_decoder=True.`") 855 | 856 | self.roberta = RobertaModel(config, add_pooling_layer=False) 857 | self.lm_head = RobertaLMHead(config) 858 | 859 | self.init_weights() 860 | 861 | def get_output_embeddings(self): 862 | return self.lm_head.decoder 863 | 864 | def set_output_embeddings(self, new_embeddings): 865 | self.lm_head.decoder = new_embeddings 866 | 867 | @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 868 | @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) 869 | def forward( 870 | self, 871 | input_ids=None, 872 | attention_mask=None, 873 | token_type_ids=None, 874 | position_ids=None, 875 | head_mask=None, 876 | inputs_embeds=None, 877 | encoder_hidden_states=None, 878 | encoder_attention_mask=None, 879 | labels=None, 880 | past_key_values=None, 881 | use_cache=None, 882 | output_attentions=None, 883 | output_hidden_states=None, 884 | return_dict=None, 885 | ): 886 | r""" 887 | encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): 888 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if 889 | the model is configured as a decoder. 890 | encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 891 | Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in 892 | the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: 893 | 894 | - 1 for tokens that are **not masked**, 895 | - 0 for tokens that are **masked**. 896 | 897 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 898 | Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in 899 | ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are 900 | ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` 901 | past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): 902 | Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. 903 | 904 | If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` 905 | (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` 906 | instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. 907 | use_cache (:obj:`bool`, `optional`): 908 | If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up 909 | decoding (see :obj:`past_key_values`). 
910 | 911 | Returns: 912 | 913 | Example:: 914 | 915 | >>> from transformers import RobertaTokenizer, RobertaForCausalLM, RobertaConfig 916 | >>> import torch 917 | 918 | >>> tokenizer = RobertaTokenizer.from_pretrained('roberta-base') 919 | >>> config = RobertaConfig.from_pretrained("roberta-base") 920 | >>> config.is_decoder = True 921 | >>> model = RobertaForCausalLM.from_pretrained('roberta-base', config=config) 922 | 923 | >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") 924 | >>> outputs = model(**inputs) 925 | 926 | >>> prediction_logits = outputs.logits 927 | """ 928 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 929 | if labels is not None: 930 | use_cache = False 931 | 932 | outputs = self.roberta( 933 | input_ids, 934 | attention_mask=attention_mask, 935 | token_type_ids=token_type_ids, 936 | position_ids=position_ids, 937 | head_mask=head_mask, 938 | inputs_embeds=inputs_embeds, 939 | encoder_hidden_states=encoder_hidden_states, 940 | encoder_attention_mask=encoder_attention_mask, 941 | past_key_values=past_key_values, 942 | use_cache=use_cache, 943 | output_attentions=output_attentions, 944 | output_hidden_states=output_hidden_states, 945 | return_dict=return_dict, 946 | ) 947 | 948 | sequence_output = outputs[0] 949 | prediction_scores = self.lm_head(sequence_output) 950 | 951 | lm_loss = None 952 | if labels is not None: 953 | # we are doing next-token prediction; shift prediction scores and input ids by one 954 | shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() 955 | labels = labels[:, 1:].contiguous() 956 | loss_fct = CrossEntropyLoss() 957 | lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) 958 | 959 | if not return_dict: 960 | output = (prediction_scores,) + outputs[2:] 961 | return ((lm_loss,) + output) if lm_loss is not None else output 962 | 963 | return CausalLMOutputWithCrossAttentions( 964 | loss=lm_loss, 965 | logits=prediction_scores, 966 | past_key_values=outputs.past_key_values, 967 | hidden_states=outputs.hidden_states, 968 | attentions=outputs.attentions, 969 | cross_attentions=outputs.cross_attentions, 970 | ) 971 | 972 | def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs): 973 | input_shape = input_ids.shape 974 | # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly 975 | if attention_mask is None: 976 | attention_mask = input_ids.new_ones(input_shape) 977 | 978 | # cut decoder_input_ids if past is used 979 | if past is not None: 980 | input_ids = input_ids[:, -1:] 981 | 982 | return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past} 983 | 984 | def _reorder_cache(self, past, beam_idx): 985 | reordered_past = () 986 | for layer_past in past: 987 | reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) 988 | return reordered_past 989 | 990 | 991 | @add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. 
""", ROBERTA_START_DOCSTRING) 992 | class RobertaForMaskedLM(RobertaPreTrainedModel): 993 | _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.bias"] 994 | _keys_to_ignore_on_load_unexpected = [r"pooler"] 995 | 996 | def __init__(self, config): 997 | super().__init__(config) 998 | 999 | if config.is_decoder: 1000 | logger.warning( 1001 | "If you want to use `RobertaForMaskedLM` make sure `config.is_decoder=False` for " 1002 | "bi-directional self-attention." 1003 | ) 1004 | 1005 | self.roberta = RobertaModel(config, add_pooling_layer=False) 1006 | self.lm_head = RobertaLMHead(config) 1007 | 1008 | self.init_weights() 1009 | 1010 | def get_output_embeddings(self): 1011 | return self.lm_head.decoder 1012 | 1013 | def set_output_embeddings(self, new_embeddings): 1014 | self.lm_head.decoder = new_embeddings 1015 | 1016 | @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 1017 | @add_code_sample_docstrings( 1018 | tokenizer_class=_TOKENIZER_FOR_DOC, 1019 | checkpoint=_CHECKPOINT_FOR_DOC, 1020 | output_type=MaskedLMOutput, 1021 | config_class=_CONFIG_FOR_DOC, 1022 | mask="", 1023 | ) 1024 | def forward( 1025 | self, 1026 | input_ids=None, 1027 | attention_mask=None, 1028 | token_type_ids=None, 1029 | position_ids=None, 1030 | head_mask=None, 1031 | inputs_embeds=None, 1032 | encoder_hidden_states=None, 1033 | encoder_attention_mask=None, 1034 | labels=None, 1035 | output_attentions=None, 1036 | output_hidden_states=None, 1037 | return_dict=None, 1038 | ): 1039 | r""" 1040 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 1041 | Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., 1042 | config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored 1043 | (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` 1044 | kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`): 1045 | Used to hide legacy arguments that have been deprecated. 
1046 | """ 1047 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1048 | 1049 | outputs = self.roberta( 1050 | input_ids, 1051 | attention_mask=attention_mask, 1052 | token_type_ids=token_type_ids, 1053 | position_ids=position_ids, 1054 | head_mask=head_mask, 1055 | inputs_embeds=inputs_embeds, 1056 | encoder_hidden_states=encoder_hidden_states, 1057 | encoder_attention_mask=encoder_attention_mask, 1058 | output_attentions=output_attentions, 1059 | output_hidden_states=output_hidden_states, 1060 | return_dict=return_dict, 1061 | ) 1062 | sequence_output = outputs[0] 1063 | prediction_scores = self.lm_head(sequence_output) 1064 | 1065 | masked_lm_loss = None 1066 | if labels is not None: 1067 | loss_fct = CrossEntropyLoss() 1068 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) 1069 | 1070 | if not return_dict: 1071 | output = (prediction_scores,) + outputs[2:] 1072 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output 1073 | 1074 | return MaskedLMOutput( 1075 | loss=masked_lm_loss, 1076 | logits=prediction_scores, 1077 | hidden_states=outputs.hidden_states, 1078 | attentions=outputs.attentions, 1079 | ) 1080 | 1081 | 1082 | class RobertaLMHead(nn.Module): 1083 | """Roberta Head for masked language modeling.""" 1084 | 1085 | def __init__(self, config): 1086 | super().__init__() 1087 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 1088 | self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 1089 | 1090 | self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 1091 | self.bias = nn.Parameter(torch.zeros(config.vocab_size)) 1092 | 1093 | # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` 1094 | self.decoder.bias = self.bias 1095 | 1096 | def forward(self, features, **kwargs): 1097 | x = self.dense(features) 1098 | x = gelu(x) 1099 | x = self.layer_norm(x) 1100 | 1101 | # project back to size of vocabulary with bias 1102 | x = self.decoder(x) 1103 | 1104 | return x 1105 | 1106 | 1107 | @add_start_docstrings( 1108 | """ 1109 | RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the 1110 | pooled output) e.g. for GLUE tasks. 
1111 | """, 1112 | ROBERTA_START_DOCSTRING, 1113 | ) 1114 | class RobertaForSequenceClassificationWoPositional(RobertaPreTrainedModel): 1115 | _keys_to_ignore_on_load_missing = [r"position_ids"] 1116 | 1117 | def __init__(self, config): 1118 | super().__init__(config) 1119 | self.num_labels = config.num_labels 1120 | 1121 | self.roberta = RobertaModel(config, add_pooling_layer=False) 1122 | self.classifier = RobertaClassificationHead(config) 1123 | 1124 | self.init_weights() 1125 | 1126 | @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 1127 | @add_code_sample_docstrings( 1128 | tokenizer_class=_TOKENIZER_FOR_DOC, 1129 | checkpoint=_CHECKPOINT_FOR_DOC, 1130 | output_type=SequenceClassifierOutput, 1131 | config_class=_CONFIG_FOR_DOC, 1132 | ) 1133 | def forward( 1134 | self, 1135 | input_ids=None, 1136 | attention_mask=None, 1137 | token_type_ids=None, 1138 | position_ids=None, 1139 | head_mask=None, 1140 | inputs_embeds=None, 1141 | labels=None, 1142 | output_attentions=None, 1143 | output_hidden_states=None, 1144 | return_dict=None, 1145 | ): 1146 | r""" 1147 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): 1148 | Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., 1149 | config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), 1150 | If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 1151 | """ 1152 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1153 | 1154 | outputs = self.roberta( 1155 | input_ids, 1156 | attention_mask=attention_mask, 1157 | token_type_ids=token_type_ids, 1158 | position_ids=position_ids, 1159 | head_mask=head_mask, 1160 | inputs_embeds=inputs_embeds, 1161 | output_attentions=output_attentions, 1162 | output_hidden_states=output_hidden_states, 1163 | return_dict=return_dict, 1164 | ) 1165 | sequence_output = outputs[0] 1166 | logits = self.classifier(sequence_output) 1167 | 1168 | loss = None 1169 | if labels is not None: 1170 | if self.num_labels == 1: 1171 | # We are doing regression 1172 | loss_fct = MSELoss() 1173 | loss = loss_fct(logits.view(-1), labels.view(-1)) 1174 | else: 1175 | loss_fct = CrossEntropyLoss() 1176 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 1177 | 1178 | if not return_dict: 1179 | output = (logits,) + outputs[2:] 1180 | return ((loss,) + output) if loss is not None else output 1181 | 1182 | return SequenceClassifierOutput( 1183 | loss=loss, 1184 | logits=logits, 1185 | hidden_states=outputs.hidden_states, 1186 | attentions=outputs.attentions, 1187 | ) 1188 | 1189 | 1190 | @add_start_docstrings( 1191 | """ 1192 | Roberta Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a 1193 | softmax) e.g. for RocStories/SWAG tasks. 
1194 | """, 1195 | ROBERTA_START_DOCSTRING, 1196 | ) 1197 | class RobertaForMultipleChoice(RobertaPreTrainedModel): 1198 | _keys_to_ignore_on_load_missing = [r"position_ids"] 1199 | 1200 | def __init__(self, config): 1201 | super().__init__(config) 1202 | 1203 | self.roberta = RobertaModel(config) 1204 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 1205 | self.classifier = nn.Linear(config.hidden_size, 1) 1206 | 1207 | self.init_weights() 1208 | 1209 | @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) 1210 | @add_code_sample_docstrings( 1211 | tokenizer_class=_TOKENIZER_FOR_DOC, 1212 | checkpoint=_CHECKPOINT_FOR_DOC, 1213 | output_type=MultipleChoiceModelOutput, 1214 | config_class=_CONFIG_FOR_DOC, 1215 | ) 1216 | def forward( 1217 | self, 1218 | input_ids=None, 1219 | token_type_ids=None, 1220 | attention_mask=None, 1221 | labels=None, 1222 | position_ids=None, 1223 | head_mask=None, 1224 | inputs_embeds=None, 1225 | output_attentions=None, 1226 | output_hidden_states=None, 1227 | return_dict=None, 1228 | ): 1229 | r""" 1230 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): 1231 | Labels for computing the multiple choice classification loss. Indices should be in ``[0, ..., 1232 | num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See 1233 | :obj:`input_ids` above) 1234 | """ 1235 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1236 | num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] 1237 | 1238 | flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None 1239 | flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None 1240 | flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None 1241 | flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None 1242 | flat_inputs_embeds = ( 1243 | inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) 1244 | if inputs_embeds is not None 1245 | else None 1246 | ) 1247 | 1248 | outputs = self.roberta( 1249 | flat_input_ids, 1250 | position_ids=flat_position_ids, 1251 | token_type_ids=flat_token_type_ids, 1252 | attention_mask=flat_attention_mask, 1253 | head_mask=head_mask, 1254 | inputs_embeds=flat_inputs_embeds, 1255 | output_attentions=output_attentions, 1256 | output_hidden_states=output_hidden_states, 1257 | return_dict=return_dict, 1258 | ) 1259 | pooled_output = outputs[1] 1260 | 1261 | pooled_output = self.dropout(pooled_output) 1262 | logits = self.classifier(pooled_output) 1263 | reshaped_logits = logits.view(-1, num_choices) 1264 | 1265 | loss = None 1266 | if labels is not None: 1267 | loss_fct = CrossEntropyLoss() 1268 | loss = loss_fct(reshaped_logits, labels) 1269 | 1270 | if not return_dict: 1271 | output = (reshaped_logits,) + outputs[2:] 1272 | return ((loss,) + output) if loss is not None else output 1273 | 1274 | return MultipleChoiceModelOutput( 1275 | loss=loss, 1276 | logits=reshaped_logits, 1277 | hidden_states=outputs.hidden_states, 1278 | attentions=outputs.attentions, 1279 | ) 1280 | 1281 | 1282 | @add_start_docstrings( 1283 | """ 1284 | Roberta Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for 1285 | Named-Entity-Recognition (NER) tasks. 
1286 | """, 1287 | ROBERTA_START_DOCSTRING, 1288 | ) 1289 | class RobertaForTokenClassification(RobertaPreTrainedModel): 1290 | _keys_to_ignore_on_load_unexpected = [r"pooler"] 1291 | _keys_to_ignore_on_load_missing = [r"position_ids"] 1292 | 1293 | def __init__(self, config): 1294 | super().__init__(config) 1295 | self.num_labels = config.num_labels 1296 | 1297 | self.roberta = RobertaModel(config, add_pooling_layer=False) 1298 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 1299 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 1300 | 1301 | self.init_weights() 1302 | 1303 | @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 1304 | @add_code_sample_docstrings( 1305 | tokenizer_class=_TOKENIZER_FOR_DOC, 1306 | checkpoint=_CHECKPOINT_FOR_DOC, 1307 | output_type=TokenClassifierOutput, 1308 | config_class=_CONFIG_FOR_DOC, 1309 | ) 1310 | def forward( 1311 | self, 1312 | input_ids=None, 1313 | attention_mask=None, 1314 | token_type_ids=None, 1315 | position_ids=None, 1316 | head_mask=None, 1317 | inputs_embeds=None, 1318 | labels=None, 1319 | output_attentions=None, 1320 | output_hidden_states=None, 1321 | return_dict=None, 1322 | ): 1323 | r""" 1324 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 1325 | Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - 1326 | 1]``. 1327 | """ 1328 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1329 | 1330 | outputs = self.roberta( 1331 | input_ids, 1332 | attention_mask=attention_mask, 1333 | token_type_ids=token_type_ids, 1334 | position_ids=position_ids, 1335 | head_mask=head_mask, 1336 | inputs_embeds=inputs_embeds, 1337 | output_attentions=output_attentions, 1338 | output_hidden_states=output_hidden_states, 1339 | return_dict=return_dict, 1340 | ) 1341 | 1342 | sequence_output = outputs[0] 1343 | 1344 | sequence_output = self.dropout(sequence_output) 1345 | logits = self.classifier(sequence_output) 1346 | 1347 | loss = None 1348 | if labels is not None: 1349 | loss_fct = CrossEntropyLoss() 1350 | # Only keep active parts of the loss 1351 | if attention_mask is not None: 1352 | active_loss = attention_mask.view(-1) == 1 1353 | active_logits = logits.view(-1, self.num_labels) 1354 | active_labels = torch.where( 1355 | active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) 1356 | ) 1357 | loss = loss_fct(active_logits, active_labels) 1358 | else: 1359 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 1360 | 1361 | if not return_dict: 1362 | output = (logits,) + outputs[2:] 1363 | return ((loss,) + output) if loss is not None else output 1364 | 1365 | return TokenClassifierOutput( 1366 | loss=loss, 1367 | logits=logits, 1368 | hidden_states=outputs.hidden_states, 1369 | attentions=outputs.attentions, 1370 | ) 1371 | 1372 | 1373 | class RobertaClassificationHead(nn.Module): 1374 | """Head for sentence-level classification tasks.""" 1375 | 1376 | def __init__(self, config): 1377 | super().__init__() 1378 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 1379 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 1380 | self.out_proj = nn.Linear(config.hidden_size, config.num_labels) 1381 | 1382 | def forward(self, features, **kwargs): 1383 | x = features[:, 0, :] # take token (equiv. 
to [CLS]) 1384 | x = self.dropout(x) 1385 | x = self.dense(x) 1386 | x = torch.tanh(x) 1387 | x = self.dropout(x) 1388 | x = self.out_proj(x) 1389 | return x 1390 | 1391 | 1392 | @add_start_docstrings( 1393 | """ 1394 | Roberta Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear 1395 | layers on top of the hidden-states output to compute `span start logits` and `span end logits`). 1396 | """, 1397 | ROBERTA_START_DOCSTRING, 1398 | ) 1399 | class RobertaForQuestionAnswering(RobertaPreTrainedModel): 1400 | _keys_to_ignore_on_load_unexpected = [r"pooler"] 1401 | _keys_to_ignore_on_load_missing = [r"position_ids"] 1402 | 1403 | def __init__(self, config): 1404 | super().__init__(config) 1405 | self.num_labels = config.num_labels 1406 | 1407 | self.roberta = RobertaModel(config, add_pooling_layer=False) 1408 | self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) 1409 | 1410 | self.init_weights() 1411 | 1412 | @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 1413 | @add_code_sample_docstrings( 1414 | tokenizer_class=_TOKENIZER_FOR_DOC, 1415 | checkpoint=_CHECKPOINT_FOR_DOC, 1416 | output_type=QuestionAnsweringModelOutput, 1417 | config_class=_CONFIG_FOR_DOC, 1418 | ) 1419 | def forward( 1420 | self, 1421 | input_ids=None, 1422 | attention_mask=None, 1423 | token_type_ids=None, 1424 | position_ids=None, 1425 | head_mask=None, 1426 | inputs_embeds=None, 1427 | start_positions=None, 1428 | end_positions=None, 1429 | output_attentions=None, 1430 | output_hidden_states=None, 1431 | return_dict=None, 1432 | ): 1433 | r""" 1434 | start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): 1435 | Labels for position (index) of the start of the labelled span for computing the token classification loss. 1436 | Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the 1437 | sequence are not taken into account for computing the loss. 1438 | end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): 1439 | Labels for position (index) of the end of the labelled span for computing the token classification loss. 1440 | Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the 1441 | sequence are not taken into account for computing the loss. 
1442 | """ 1443 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1444 | 1445 | outputs = self.roberta( 1446 | input_ids, 1447 | attention_mask=attention_mask, 1448 | token_type_ids=token_type_ids, 1449 | position_ids=position_ids, 1450 | head_mask=head_mask, 1451 | inputs_embeds=inputs_embeds, 1452 | output_attentions=output_attentions, 1453 | output_hidden_states=output_hidden_states, 1454 | return_dict=return_dict, 1455 | ) 1456 | 1457 | sequence_output = outputs[0] 1458 | 1459 | logits = self.qa_outputs(sequence_output) 1460 | start_logits, end_logits = logits.split(1, dim=-1) 1461 | start_logits = start_logits.squeeze(-1) 1462 | end_logits = end_logits.squeeze(-1) 1463 | 1464 | total_loss = None 1465 | if start_positions is not None and end_positions is not None: 1466 | # If we are on multi-GPU, split add a dimension 1467 | if len(start_positions.size()) > 1: 1468 | start_positions = start_positions.squeeze(-1) 1469 | if len(end_positions.size()) > 1: 1470 | end_positions = end_positions.squeeze(-1) 1471 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 1472 | ignored_index = start_logits.size(1) 1473 | start_positions.clamp_(0, ignored_index) 1474 | end_positions.clamp_(0, ignored_index) 1475 | 1476 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 1477 | start_loss = loss_fct(start_logits, start_positions) 1478 | end_loss = loss_fct(end_logits, end_positions) 1479 | total_loss = (start_loss + end_loss) / 2 1480 | 1481 | if not return_dict: 1482 | output = (start_logits, end_logits) + outputs[2:] 1483 | return ((total_loss,) + output) if total_loss is not None else output 1484 | 1485 | return QuestionAnsweringModelOutput( 1486 | loss=total_loss, 1487 | start_logits=start_logits, 1488 | end_logits=end_logits, 1489 | hidden_states=outputs.hidden_states, 1490 | attentions=outputs.attentions, 1491 | ) 1492 | 1493 | 1494 | def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): 1495 | """ 1496 | Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols 1497 | are ignored. This is modified from fairseq's `utils.make_positions`. 1498 | 1499 | Args: 1500 | x: torch.Tensor x: 1501 | 1502 | Returns: torch.Tensor 1503 | """ 1504 | # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. 
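# Worked example (illustration only, assuming past_key_values_length = 0 and padding_idx = 1
# as in RoBERTa):
#   input_ids           = [[5, 7, 9, 1, 1]]
#   mask                = [[1, 1, 1, 0, 0]]    # input_ids != padding_idx
#   cumsum(mask) * mask = [[1, 2, 3, 0, 0]]
#   + padding_idx       = [[2, 3, 4, 1, 1]]    # positions start at padding_idx + 1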
1505 | mask = input_ids.ne(padding_idx).int() 1506 | incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask 1507 | return incremental_indices.long() + padding_idx 1508 | -------------------------------------------------------------------------------- /prepare_csk.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | import numpy as np 4 | import pandas as pd 5 | from tqdm import tqdm 6 | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer 7 | from comet_utils import use_task_specific_params, trim_batch 8 | from pathlib import Path 9 | 10 | 11 | if __name__ == "__main__": 12 | 13 | model_path = "comet/comet-atomic_2020_BART/" 14 | tokenizer = AutoTokenizer.from_pretrained(model_path) 15 | model = AutoModelForSeq2SeqLM.from_pretrained(model_path).cuda() 16 | device = str(model.device) 17 | use_task_specific_params(model, "summarization") 18 | model.zero_grad() 19 | model.eval() 20 | 21 | batch_size = 8 22 | relations = ["isAfter", "isBefore"] 23 | 24 | for dataset in ["roc", "nips", "aan", "nsf", "sind"]: 25 | # Path("data/" + dataset + "/csk/").mkdir(parents=True, exist_ok=True) 26 | print ("Dataset: {}".format(dataset)) 27 | for split in ["train", "test", "valid"]: 28 | print ("\tSplit: {}".format(split)) 29 | 30 | for rel in tqdm(relations, position=0, leave=True): 31 | comet_activations = {} 32 | x = open("data/" + dataset + "/" + split + ".tsv").readlines() 33 | for k, line in tqdm(enumerate(x), position=0, leave=True, total=len(x)): 34 | sents = line.strip().split('\t') 35 | queries = [] 36 | for head in sents: 37 | queries.append("{} {} [GEN]".format(head, rel)) 38 | 39 | with torch.no_grad(): 40 | batch = tokenizer(queries, return_tensors="pt", truncation=True, padding="max_length").to(device) 41 | input_ids, attention_mask = trim_batch(**batch, pad_token_id=tokenizer.pad_token_id) 42 | out = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True) 43 | activations = out['decoder_hidden_states'][-1][:, 0, :].detach().cpu().numpy() 44 | 45 | comet_activations[str(k)] = activations 46 | pickle.dump(comet_activations, open("data/" + dataset + "/csk/" + split + '_' + rel + ".pkl", "wb")) 47 | 48 | -------------------------------------------------------------------------------- /prepare_data.py: -------------------------------------------------------------------------------- 1 | import csv, json, pandas as pd 2 | 3 | def get_filenames(split): 4 | with open(split, "r") as inp: 5 | filenames = inp.read() 6 | filenames = filenames.split('\n')[:-1] 7 | return filenames 8 | 9 | def get_story_text(data): 10 | story_sentences = {} 11 | annotations = data['annotations'] 12 | for annotation in annotations: 13 | story_id = annotation[0]['story_id'] 14 | story_sentences.setdefault(story_id, []) 15 | story_sentences[story_id].append(annotation[0]['original_text']) 16 | return story_sentences 17 | 18 | # For NIPS, AAN, NSF 19 | def write_data(task, split, split_name): 20 | directory = 'data/' + task + '/' 21 | dpath = directory + 'split/' + split 22 | files = get_filenames(dpath) 23 | 24 | outname = directory + split_name + '.tsv' 25 | 26 | with open(outname, "w") as out: 27 | tsv_writer = csv.writer(out, delimiter='\t') 28 | 29 | for file in files: 30 | if task == 'nips': 31 | with open(directory + 'txt_tokenized/' + 'a' + file + '.txt', 'r') as inp: 32 | lines = inp.readlines() 33 | else: 34 | with open(directory + 'txt_tokenized/' + file, 'r') as inp: 
35 | lines = inp.readlines() 36 | 37 | lines = [line.strip() for line in lines] 38 | tsv_writer.writerow(lines) 39 | 40 | 41 | # For SIND 42 | def write_data_sind(split): 43 | data = json.load(open('data/sind/' + split + '.story-in-sequence.json','r')) 44 | story_sentences = get_story_text(data) 45 | 46 | if split == 'val': 47 | split_name = 'valid' 48 | else: 49 | split_name = split 50 | 51 | outname = 'data/sind/' + split_name + '.tsv' 52 | 53 | with open(outname, "w") as out: 54 | tsv_writer = csv.writer(out, delimiter='\t') 55 | for story_id in story_sentences.keys(): 56 | story = story_sentences[story_id] 57 | tsv_writer.writerow(story) 58 | 59 | 60 | # For ROC 61 | def write_data_roc(split): 62 | df = pd.read_csv('data/roc/' + split + '.csv') 63 | outname = 'data/roc/' + split + '.tsv' 64 | 65 | with open(outname, "w") as out: 66 | tsv_writer = csv.writer(out, delimiter='\t') 67 | for i in range(len(df)): 68 | row = df.iloc[i] 69 | story = [row['sentence'+str(j)] for j in range(1, 6)] 70 | tsv_writer.writerow(story) 71 | 72 | 73 | if __name__ == "__main__": 74 | write_data('nips', '2013le_papers', 'train') 75 | write_data('nips', '2014_papers', 'valid') 76 | write_data('nips', '2015_papers', 'test') 77 | 78 | for task in ['nsf', 'aan']: 79 | write_data(task, 'train', 'train') 80 | write_data(task, 'valid', 'valid') 81 | write_data(task, 'test', 'test') 82 | 83 | write_data_sind('train') 84 | write_data_sind('val') 85 | write_data_sind('test') 86 | 87 | write_data_roc('train') 88 | write_data_roc('valid') 89 | write_data_roc('test') 90 | -------------------------------------------------------------------------------- /results/aan/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/declare-lab/sentence-ordering/c6524778f60b7093166b3cf5f003b655becef498/results/aan/.gitkeep -------------------------------------------------------------------------------- /results/nips/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/declare-lab/sentence-ordering/c6524778f60b7093166b3cf5f003b655becef498/results/nips/.gitkeep -------------------------------------------------------------------------------- /results/nsf/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/declare-lab/sentence-ordering/c6524778f60b7093166b3cf5f003b655becef498/results/nsf/.gitkeep -------------------------------------------------------------------------------- /results/roc/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/declare-lab/sentence-ordering/c6524778f60b7093166b3cf5f003b655becef498/results/roc/.gitkeep -------------------------------------------------------------------------------- /results/sind/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/declare-lab/sentence-ordering/c6524778f60b7093166b3cf5f003b655becef498/results/sind/.gitkeep -------------------------------------------------------------------------------- /saved/aan/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/declare-lab/sentence-ordering/c6524778f60b7093166b3cf5f003b655becef498/saved/aan/.gitkeep -------------------------------------------------------------------------------- /saved/nips/.gitkeep: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/declare-lab/sentence-ordering/c6524778f60b7093166b3cf5f003b655becef498/saved/nips/.gitkeep -------------------------------------------------------------------------------- /saved/nsf/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/declare-lab/sentence-ordering/c6524778f60b7093166b3cf5f003b655becef498/saved/nsf/.gitkeep -------------------------------------------------------------------------------- /saved/roc/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/declare-lab/sentence-ordering/c6524778f60b7093166b3cf5f003b655becef498/saved/roc/.gitkeep -------------------------------------------------------------------------------- /saved/sind/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/declare-lab/sentence-ordering/c6524778f60b7093166b3cf5f003b655becef498/saved/sind/.gitkeep -------------------------------------------------------------------------------- /stack.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/declare-lab/sentence-ordering/c6524778f60b7093166b3cf5f003b655becef498/stack.png -------------------------------------------------------------------------------- /topological_sort.py: -------------------------------------------------------------------------------- 1 | # Credits: The code for this file is based on https://github.com/shrimai/Topological-Sort-for-Sentence-Ordering 2 | 3 | from collections import defaultdict 4 | import csv 5 | import ast 6 | import argparse 7 | 8 | class Graph: 9 | ''' 10 | The code for this class is based on geeksforgeeks.com 11 | ''' 12 | def __init__(self,vertices): 13 | self.graph = defaultdict(list) 14 | self.V = vertices 15 | 16 | def addEdge(self, u, v, w): 17 | self.graph[u].append([v, w]) 18 | 19 | def topologicalSortUtil(self, v, visited, stack): 20 | 21 | visited[v] = True 22 | 23 | for i in self.graph[v]: 24 | if visited[i[0]] == False: 25 | self.topologicalSortUtil(i[0], visited, stack) 26 | 27 | stack.insert(0,v) 28 | 29 | def topologicalSort(self): 30 | visited = [False]*self.V 31 | stack =[] 32 | 33 | for i in range(self.V): 34 | if visited[i] == False: 35 | self.topologicalSortUtil(i, visited, stack) 36 | 37 | return stack 38 | 39 | def isCyclicUtil(self, v, visited, recStack): 40 | 41 | visited[v] = True 42 | recStack[v] = True 43 | 44 | for neighbour in self.graph[v]: 45 | if visited[neighbour[0]] == False: 46 | if self.isCyclicUtil( 47 | neighbour[0], visited, recStack) == True: 48 | return True 49 | elif recStack[neighbour[0]] == True: 50 | self.graph[v].remove(neighbour) 51 | return True 52 | 53 | recStack[v] = False 54 | return False 55 | 56 | def isCyclic(self): 57 | visited = [False] * self.V 58 | recStack = [False] * self.V 59 | for node in range(self.V): 60 | if visited[node] == False: 61 | if self.isCyclicUtil(node, visited, recStack) == True: 62 | return True 63 | return False 64 | 65 | class Stats(object): 66 | 67 | def __init__(self): 68 | self.n_samp = 0 69 | self.n_sent = 0 70 | self.n_pair = 0 71 | self.corr_samp = 0 72 | self.corr_sent = 0 73 | self.corr_pair = 0 74 | self.lcs_seq = 0 75 | self.tau = 0 76 | self.dist_window = [1, 2, 3] 77 | self.min_dist = [0]*len(self.dist_window) 78 | self.fm = 0 79 | self.lm = 0 
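# Toy example for the metric helpers defined below (illustration only): with gold order
# [0, 1, 2, 3] and predicted order [0, 2, 1, 3], five of the six ordered pairs agree, so
# kendall_tau returns 1 - 2 * (1/6) ~= 0.667; sample_match is False (not a perfect match),
# while first_match and last_match are both True.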
80 |  81 | def pairwise_metric(self, g): 82 | ''' 83 | This calculates the percentage of skip-bigrams for which the 84 | relative order is predicted correctly. Rouge-S metric. 85 | ''' 86 | common = 0 87 | for vert in range(g.V): 88 | to_nodes = g.graph[vert] 89 | to_nodes = [node[0] for node in to_nodes] 90 | gold_nodes = list(range(vert+1, g.V)) 91 | common += len(set(gold_nodes).intersection(set(to_nodes))) 92 | 93 | return common 94 | 95 | def kendall_tau(self, porder, gorder): 96 | ''' 97 | It calculates the number of inversions required by the predicted 98 | order to reach the correct order. 99 | ''' 100 | pred_pairs, gold_pairs = [], [] 101 | for i in range(len(porder)): 102 | for j in range(i+1, len(porder)): 103 | pred_pairs.append((porder[i], porder[j])) 104 | gold_pairs.append((gorder[i], gorder[j])) 105 | common = len(set(pred_pairs).intersection(set(gold_pairs))) 106 | uncommon = len(gold_pairs) - common 107 | tau = 1 - (2*(uncommon/len(gold_pairs))) 108 | 109 | return tau 110 | 111 | def min_dist_metric(self, porder, gorder): 112 | ''' 113 | It calculates the displacement of sentences within a given window. 114 | ''' 115 | count = [0]*len(self.dist_window) 116 | for i in range(len(porder)): 117 | pidx = i 118 | pval = porder[i] 119 | gidx = gorder.index(pval) 120 | for w, window in enumerate(self.dist_window): 121 | if abs(pidx-gidx) <= window: 122 | count[w] += 1 123 | return count 124 | 125 | def lcs(self, X , Y): 126 | m = len(X) 127 | n = len(Y) 128 | 129 | L = [[None]*(n+1) for i in range(m+1)] 130 | 131 | for i in range(m+1): 132 | for j in range(n+1): 133 | if i == 0 or j == 0 : 134 | L[i][j] = 0 135 | elif X[i-1] == Y[j-1]: 136 | L[i][j] = L[i-1][j-1]+1 137 | else: 138 | L[i][j] = max(L[i-1][j] , L[i][j-1]) 139 | 140 | return L[m][n] 141 | 142 | def sample_match(self, order, gold_order): 143 | ''' 144 | It calculates the percentage of samples for which the entire 145 | sequence was correctly predicted. (PMR) 146 | ''' 147 | return order == gold_order 148 | 149 | 150 | def first_match(self, order, gold_order): 151 | ''' 152 | It calculates the percentage of samples for which the first sentence 153 | was correctly predicted. 154 | ''' 155 | return order[0] == gold_order[0] 156 | 157 | def last_match(self, order, gold_order): 158 | ''' 159 | It calculates the percentage of samples for which the last sentence 160 | was correctly predicted. 161 | ''' 162 | return order[-1] == gold_order[-1] 163 | 164 | def sentence_match(self, order, gold_order): 165 | ''' 166 | It measures the percentage of sentences for which their absolute 167 | position was correctly predicted.
(Acc) 168 | ''' 169 | return sum([1 for x in range(len(order)) if order[x] == gold_order[x]]) 170 | 171 | def update_stats(self, nvert, npairs, order, gold_order, g): 172 | self.n_samp += 1 173 | self.n_sent += nvert 174 | self.n_pair += npairs 175 | 176 | if self.sample_match(order, gold_order): 177 | self.corr_samp += 1 178 | if self.first_match(order, gold_order): 179 | self.fm += 1 180 | if self.last_match(order, gold_order): 181 | self.lm += 1 182 | self.corr_sent += self.sentence_match(order, gold_order) 183 | self.corr_pair += self.pairwise_metric(g) 184 | self.lcs_seq += self.lcs(order, gold_order) 185 | self.tau += self.kendall_tau(order, gold_order) 186 | window_counts = self.min_dist_metric(order, gold_order) 187 | for w, wc in enumerate(window_counts): 188 | self.min_dist[w] += wc 189 | 190 | def print_stats(self): 191 | print("Perfect Match: " + str(self.corr_samp*100/self.n_samp)) 192 | print("First Sentence Match: " + str(self.fm*100/self.n_samp)) 193 | print("Last Sentence Match: " + str(self.lm*100/self.n_samp)) 194 | print("Sentence Accuracy: " + str(self.corr_sent*100/self.n_sent)) 195 | print("Rouge-S: " + str(self.corr_pair*100/self.n_pair)) 196 | print("LCS: " + str(self.lcs_seq*100/self.n_sent)) 197 | print("Kendall Tau Ratio: " + str(self.tau/self.n_samp)) 198 | for w, window in enumerate(self.dist_window): 199 | print("Min Dist Metric for window " + str(window) + ": " + \ 200 | str(self.min_dist[w]*100/self.n_sent)) 201 | 202 | def metric(self): 203 | PMR = str(round(self.corr_samp*100/self.n_samp, 4)) 204 | tau = str(round(self.tau/self.n_samp, 4)) 205 | return {"PMR": PMR, "tau": tau} 206 | 207 | def convert_to_graph(data): 208 | 209 | stats = Stats() 210 | i = 0 211 | no_docs, no_sents = 0, 0 212 | 213 | lengths = [] 214 | while i < len(data): 215 | ids = data[i][0] 216 | 217 | # get no vertices 218 | docid, nvert, npairs = ids.split('-') 219 | docid, nvert, npairs = int(docid), int(nvert), int(npairs) 220 | 221 | lengths.append(nvert) 222 | # create graph obj 223 | g = Graph(nvert) 224 | 225 | #read pred label 226 | for j in range(i, i+npairs): 227 | pred = int(data[j][8]) 228 | log0, log1 = float(data[j][6]), float(data[j][7]) 229 | pos_s1, pos_s2 = int(data[j][4]), int(data[j][5]) 230 | 231 | if pred == 0: 232 | g.addEdge(pos_s2, pos_s1, log0) 233 | elif pred == 1: 234 | g.addEdge(pos_s1, pos_s2, log1) 235 | 236 | i += npairs 237 | 238 | while g.isCyclic(): 239 | g.isCyclic() 240 | 241 | order = g.topologicalSort() 242 | no_sents += nvert 243 | no_docs += 1 244 | gold_order = list(range(nvert)) 245 | stats.update_stats(nvert, npairs, order, gold_order, g) 246 | 247 | if len(order) != len(gold_order): 248 | print("yes") 249 | 250 | return stats 251 | 252 | def readf(filename): 253 | data = [] 254 | with open(filename, "r") as inp: 255 | content = csv.reader(inp, delimiter='\t') 256 | for row in content: 257 | data.append(row) 258 | return data 259 | 260 | 261 | def readf_long(filename): 262 | data = [] 263 | with open(filename, "r") as inp: 264 | content = csv.reader(inp, delimiter='\t') 265 | for row in content: 266 | if int(row[0].split('-')[1]) > 10: 267 | data.append(row) 268 | return data 269 | 270 | def main(): 271 | parser = argparse.ArgumentParser() 272 | ## Required parameters 273 | parser.add_argument("--file_path", default=None, type=str, 274 | required=True, help="The input data dir.") 275 | args = parser.parse_args() 276 | 277 | data = readf(args.file_path) 278 | stats = convert_to_graph(data) 279 | stats.print_stats() 280 | 281 | print ("\nLonger 
than 10 sents:") 282 | data2 = readf_long(args.file_path) 283 | stats2 = convert_to_graph(data2) 284 | stats2.print_stats() 285 | 286 | if __name__ == "__main__": 287 | main() 288 | -------------------------------------------------------------------------------- /train_csk.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import pickle 5 | import argparse 6 | import numpy as np 7 | from tqdm import tqdm 8 | import torch 9 | import torch.optim as optim 10 | from csk_models import GraphNetwork 11 | from dataloader import SentenceOrderingLoader 12 | from topological_sort import convert_to_graph 13 | from transformers.optimization import AdamW, get_scheduler 14 | from transformers.trainer_pt_utils import get_parameter_names 15 | 16 | class MultipleOptimizer(object): 17 | def __init__(self, *op): 18 | self.optimizers = op 19 | 20 | def zero_grad(self): 21 | for op in self.optimizers: 22 | op.zero_grad() 23 | 24 | def step(self): 25 | for op in self.optimizers: 26 | op.step() 27 | 28 | def configure_dataloaders(dataset, batch_size): 29 | "Prepare dataloaders" 30 | train_loader = SentenceOrderingLoader('data/' + dataset + '/train.tsv', batch_size, shuffle=False) 31 | valid_loader = SentenceOrderingLoader('data/' + dataset + '/valid.tsv', batch_size, shuffle=False) 32 | test_loader = SentenceOrderingLoader('data/' + dataset + '/test.tsv', batch_size, shuffle=False) 33 | return train_loader, valid_loader, test_loader 34 | 35 | def configure_transformer_optimizer(model, args): 36 | "Prepare AdamW optimizer for transformer encoders" 37 | decay_parameters = get_parameter_names(model, [torch.nn.LayerNorm]) 38 | # decay_parameters = [name for name in decay_parameters if "bias" not in name] 39 | decay_parameters = [name for name in decay_parameters if ("bias" not in name and 'gcn' not in name and 'scorer' not in name)] 40 | optimizer_grouped_parameters = [ 41 | { 42 | "params": [p for n, p in model.named_parameters() if n in decay_parameters], 43 | "weight_decay": args.wd, 44 | }, 45 | { 46 | "params": [p for n, p in model.named_parameters() if n not in decay_parameters], 47 | "weight_decay": 0.0, 48 | }, 49 | ] 50 | optimizer_kwargs = { 51 | "betas": (args.adam_beta1, args.adam_beta2), 52 | "eps": args.adam_epsilon, 53 | "lr": args.lr 54 | } 55 | optimizer = AdamW(optimizer_grouped_parameters, **optimizer_kwargs) 56 | return optimizer 57 | 58 | def configure_gcn_optimizer(model, args): 59 | "Prepare Adam optimizer for GCN decoders" 60 | optimizer = optim.Adam([ 61 | {'params': model.gcn1.parameters()}, 62 | {'params': model.gcn2.parameters()}, 63 | {'params': model.scorer.parameters()} 64 | ], lr=args.lr0, weight_decay=args.wd0) 65 | return optimizer 66 | 67 | 68 | def configure_scheduler(optimizer, num_training_steps, args): 69 | "Prepare scheduler" 70 | warmup_steps = ( 71 | args.warmup_steps 72 | if args.warmup_steps > 0 73 | else math.ceil(num_training_steps * args.warmup_ratio) 74 | ) 75 | lr_scheduler = get_scheduler( 76 | args.lr_scheduler_type, 77 | optimizer, 78 | num_warmup_steps=warmup_steps, 79 | num_training_steps=num_training_steps, 80 | ) 81 | return lr_scheduler 82 | 83 | def initiate_graph_edges(lengths, past_future_diff): 84 | 85 | csk_node = sum(lengths) + len(lengths) 86 | 87 | # n1: node1, n2: node2, r: relation 88 | # relations: 89 | # 0: edges between different sentences 90 | # 1: self edges or edges between same sentences (to ensure self dependent feature propagation) 91 | # 2: edges between 
sentences and document 92 | # 3: self edge of the document (to ensure self dependent feature propagation) 93 | # 4, 5: edges between sentences and their commonsense feature nodes 94 | 95 | n1, n2, r = [], [], [] 96 | sentence_nodes, document_nodes, csk_nodes, node_count = set(), [], [], 0 97 | 98 | for k, l in enumerate(lengths): 99 | # sentence - sentence node 100 | # 0: different sentence, 1: same sentence 101 | for i in range(l): 102 | for j in range(i, l): 103 | n1.append(node_count+j); n2.append(node_count+i) 104 | if i != j: 105 | r.append(0) 106 | n1.append(node_count+i); n2.append(node_count+j); r.append(0) 107 | else: 108 | r.append(1) 109 | sentence_nodes.add(node_count+j) 110 | sentence_nodes.add(node_count+i) 111 | 112 | n1.append(csk_node); n2.append(i); r.append(4) 113 | n1.append(csk_node+1); n2.append(i) 114 | if past_future_diff: 115 | r.append(5) 116 | else: 117 | r.append(4) 118 | 119 | csk_nodes.append(csk_node); csk_nodes.append(csk_node+1) 120 | csk_node += 2 121 | 122 | # document - sentence node : 2 123 | for i in range(l): 124 | n1.append(node_count+l); n2.append(node_count+i); r.append(2) 125 | 126 | # document - document node : 3 127 | n1.append(node_count+l); n2.append(node_count+l); r.append(3) 128 | document_nodes.append(node_count+l) 129 | 130 | # increment node count 131 | node_count += l+1 132 | 133 | x = np.array([n1, n2, r]) 134 | return x, list(sentence_nodes), document_nodes, csk_nodes 135 | 136 | def csk_vectors(id_, cska, cskb): 137 | a = np.concatenate([cska[str(k)] for k in id_]) 138 | b = np.concatenate([cskb[str(k)] for k in id_]) 139 | c = np.zeros((len(a)+len(b),1024)) 140 | c[0::2, :] = a 141 | c[1::2, :] = b 142 | return c 143 | 144 | def predictions(x, log_prob, id_, indices, sentences, document_nodes): 145 | lp = log_prob.detach().cpu().numpy() 146 | edges = x[:2, indices].transpose(1, 0) 147 | predictions = log_prob.argmax(1).cpu().numpy() 148 | 149 | final_preds = [] 150 | for j in range(1, len(edges), 2): 151 | ind = (j-1)//2 152 | final_preds.append((1, edges[j][0], edges[j][1], lp[ind][0], lp[ind][1], predictions[ind])) 153 | 154 | new_final_preds, groups, k = [], [], 0 155 | for item in final_preds: 156 | if max(item[1], item[2]) < document_nodes[k]: 157 | groups.append(item) 158 | else: 159 | k += 1 160 | new_final_preds.append(groups) 161 | groups = [item] 162 | 163 | new_final_preds.append(groups) 164 | out = [] 165 | for count, fp, s in zip(id_, new_final_preds, sentences): 166 | min_index = fp[0][1] 167 | num_sents = len(s) 168 | sent_id = str(count) + '-' + str(num_sents) + '-' + str(num_sents*(num_sents-1)//2) 169 | 170 | for item in fp: 171 | out.append([sent_id, s[item[1]-min_index], s[item[2]-min_index], 1, 172 | item[1]-min_index, item[2]-min_index, item[3], item[4], item[5]]) 173 | 174 | return out 175 | 176 | def train_or_eval_model(model, dataloader, optimizer=None, train=False): 177 | losses = [] 178 | assert not train or optimizer!=None 179 | 180 | if train: 181 | model.train() 182 | else: 183 | model.eval() 184 | 185 | for id_, sentences, _ in tqdm(dataloader, leave=False): 186 | if train: 187 | optimizer.zero_grad() 188 | 189 | lengths = [len(item) for item in sentences] 190 | x, sentence_nodes, document_nodes, csk_nodes = initiate_graph_edges(lengths, pfd) 191 | 192 | if train: 193 | csk = csk_vectors(id_, train_csk_after, train_csk_before) 194 | else: 195 | csk = csk_vectors(id_, valid_csk_after, valid_csk_before) 196 | 197 | out, _, _, _ = model(x, sentence_nodes, document_nodes, csk_nodes, sentences, csk) 198 | 
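# Note on the step below: initiate_graph_edges emits the relation-0 (sentence-sentence)
# edges in consecutive (j -> i, i -> j) pairs for every sentence pair i < j, so reshaping
# their scores to (-1, 2) and applying a softmax yields, per pair, the probabilities of the
# two possible relative orders. Because the .tsv files store sentences in their gold order,
# the second edge of each pair corresponds to the correct relative order, hence the
# all-ones label vector used for the NLL loss.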
indices = x[2] == 0 199 | prob = torch.softmax(out[indices].reshape(-1, 2), 1) 200 | log_prob = torch.log(prob) 201 | labels = torch.ones(len(prob), dtype=torch.long).cuda() 202 | loss = loss_function(log_prob, labels) 203 | 204 | if train: 205 | loss.backward() 206 | optimizer.step() 207 | 208 | losses.append(loss.item()) 209 | 210 | avg_loss = round(np.mean(losses), 4) 211 | return avg_loss 212 | 213 | def test_model(model, dataloader): 214 | losses, results = [], [] 215 | model.eval() 216 | 217 | for id_, sentences, _ in tqdm(dataloader, leave=False): 218 | lengths = [len(item) for item in sentences] 219 | x, sentence_nodes, document_nodes, csk_nodes = initiate_graph_edges(lengths, pfd) 220 | csk = csk_vectors(id_, test_csk_after, test_csk_before) 221 | 222 | with torch.no_grad(): 223 | out, _, _, _ = model(x, sentence_nodes, document_nodes, csk_nodes, sentences, csk) 224 | indices = x[2] == 0 225 | prob = torch.softmax(out[indices].reshape(-1, 2), 1) 226 | log_prob = torch.log(prob) 227 | labels = torch.ones(len(prob), dtype=torch.long).cuda() 228 | loss = loss_function(log_prob, labels) 229 | losses.append(loss.item()) 230 | batch_results = predictions(x, log_prob, id_, indices, sentences, document_nodes) 231 | results += batch_results 232 | 233 | avg_loss = round(np.mean(losses), 4) 234 | return avg_loss, results 235 | 236 | if __name__ == "__main__": 237 | 238 | parser = argparse.ArgumentParser() 239 | parser.add_argument("--lr", type=float, default=1e-6, help="Learning rate for transformers.") 240 | parser.add_argument("--lr0", type=float, default=1e-4, help="Learning rate for GCN.") 241 | parser.add_argument("--wd", default=0.0, type=float, help="Weight decay for transformers.") 242 | parser.add_argument("--wd0", default=1e-6, type=float, help="Weight decay for GCN.") 243 | parser.add_argument("--adam-epsilon", default=1e-8, type=float, help="Epsilon for AdamW optimizer.") 244 | parser.add_argument("--adam-beta1", default=0.9, type=float, help="beta1 for AdamW optimizer.") 245 | parser.add_argument("--adam-beta2", default=0.999, type=float, help="beta2 for AdamW optimizer.") 246 | parser.add_argument("--lr-scheduler-type", default="linear") 247 | parser.add_argument("--warmup-steps", type=int, default=0, help="Steps used for a linear warmup from 0 to lr.") 248 | parser.add_argument("--warmup-ratio", type=float, default=0.0, help="Ratio of total training steps used for a linear warmup from 0 to lr.") 249 | parser.add_argument("--dataset", default="roc", help="Which dataset: roc, nips, nsf, sind, aan") 250 | parser.add_argument("--batch-size", type=int, default=8, help="Batch size.") 251 | parser.add_argument("--epochs", type=int, default=6, help="Number of epochs.") 252 | parser.add_argument("--encoder", default="microsoft/deberta-base", help="Which sentence encoder") 253 | parser.add_argument("--hdim", type=int, default=100, help="Hidden dim GCN.") 254 | parser.add_argument('--pfd', action='store_true', default=False, help='Different relations for past future commonsense nodes.') 255 | 256 | 257 | args = parser.parse_args() 258 | print(args) 259 | 260 | global pfd 261 | global loss_function 262 | global train_csk_after, train_csk_before, valid_csk_after, valid_csk_before, test_csk_after, test_csk_before 263 | 264 | dataset = args.dataset 265 | batch_size = args.batch_size 266 | n_epochs = args.epochs 267 | encoder = args.encoder 268 | hdim = args.hdim 269 | pfd = args.pfd 270 | 271 | 272 | if args.pfd: 273 | num_rels = 6 274 | else: 275 | num_rels = 5 276 | 277 | run_ID = 
int(time.time()) 278 | print ('run id:', run_ID) 279 | 280 | model = GraphNetwork(encoder, hdim, hdim, rel_types=num_rels).cuda() 281 | loss_function = torch.nn.NLLLoss().cuda() 282 | optimizer1 = configure_transformer_optimizer(model, args) 283 | optimizer2 = configure_gcn_optimizer(model, args) 284 | optimizer = MultipleOptimizer(optimizer1, optimizer2) 285 | 286 | train_loader, valid_loader, test_loader = configure_dataloaders(dataset, batch_size) 287 | 288 | train_csk_after = pickle.load(open('data/' + dataset + '/csk/train_isAfter.pkl', 'rb')) 289 | train_csk_before = pickle.load(open('data/' + dataset + '/csk/train_isBefore.pkl', 'rb')) 290 | valid_csk_after = pickle.load(open('data/' + dataset + '/csk/valid_isAfter.pkl', 'rb')) 291 | valid_csk_before = pickle.load(open('data/' + dataset + '/csk/valid_isBefore.pkl', 'rb')) 292 | test_csk_after = pickle.load(open('data/' + dataset + '/csk/test_isAfter.pkl', 'rb')) 293 | test_csk_before = pickle.load(open('data/' + dataset + '/csk/test_isBefore.pkl', 'rb')) 294 | 295 | lf = open('results/'+ dataset + '/logs_csk_final.tsv', 'a') 296 | lf.write(str(run_ID) + '\t' + str(args) + '\n') 297 | 298 | best_loss = None 299 | for e in range(n_epochs): 300 | train_loss = train_or_eval_model(model, train_loader, optimizer, True) 301 | 302 | valid_loss = train_or_eval_model(model, valid_loader) 303 | # valid_loss, valid_results = test_model(model, valid_loader) 304 | # valid_stats = convert_to_graph(valid_results) 305 | 306 | test_loss, test_results = test_model(model, test_loader) 307 | test_stats = convert_to_graph(test_results) 308 | 309 | x = 'Epoch {}: train loss: {}, valid loss: {}; test loss: {} metrics: {}'.format(e+1, train_loss, valid_loss, test_loss, test_stats.metric()) 310 | print (x) 311 | lf.write(x + '\n') 312 | 313 | if best_loss == None or best_loss > valid_loss: 314 | if not os.path.exists('saved/'+ dataset + '/' + str(run_ID) + '/'): 315 | os.makedirs('saved/'+ dataset + '/' + str(run_ID) + '/') 316 | torch.save(model.state_dict(), 'saved/'+ dataset + '/' + str(run_ID) + '/model.pt') 317 | best_loss = valid_loss 318 | 319 | lf.write('\n\n') 320 | lf.close() 321 | 322 | model.load_state_dict(torch.load('saved/'+ dataset + '/' + str(run_ID) + '/model.pt')) 323 | model.eval() 324 | 325 | test_loss, results = test_model(model, test_loader) 326 | stats = convert_to_graph(results) 327 | print ('Test loss, metrics at best valid loss: {} {}'.format(test_loss, stats.metric())) 328 | 329 | content = [str(test_loss), str(stats.metric()), str(run_ID), str(args)] 330 | with open('results/' + dataset + '/results_csk_final.txt', 'a') as f: 331 | f.write('\t'.join(content) + '\n') 332 | 333 | with open('results/'+ dataset + '/results_csk_' + str(run_ID) + '.tsv', 'w') as f: 334 | for line in results: 335 | content = '\t'.join([str(s) for s in line]) 336 | f.write(content + '\n') 337 | --------------------------------------------------------------------------------
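The per-pair prediction files that `train_csk.py` writes to `results/<dataset>/results_csk_<run_ID>.tsv` can also be re-scored offline with `topological_sort.py`, which prints the full set of metrics, including the subset of documents longer than 10 sentences. A usage sketch; the dataset and run ID below are placeholders for an actual run:

```
python topological_sort.py --file_path results/roc/results_csk_<run_ID>.tsv
```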