├── .gitignore ├── LICENSE ├── README.md ├── classify ├── __init__.py ├── compute │ ├── __init__.py │ └── trainer.py ├── data │ ├── __init__.py │ ├── dataset.py │ ├── load.py │ ├── loaders │ │ ├── __init__.py │ │ ├── loader.py │ │ ├── multirc.py │ │ └── snli.py │ ├── multirc_sent_sampler.py │ ├── sampler.py │ ├── snli_sampler.py │ ├── text.py │ └── utils.py ├── metric │ ├── __init__.py │ ├── abstract.py │ ├── dev │ │ ├── __init__.py │ │ ├── accuracy.py │ │ ├── micro_rationacc.py │ │ └── rationacc.py │ ├── load.py │ ├── loss │ │ ├── __init__.py │ │ ├── bce.py │ │ ├── ce.py │ │ ├── f1loss.py │ │ ├── rationaleloss.py │ │ └── regularizor.py │ └── train │ │ ├── __init__.py │ │ ├── alignment.py │ │ └── cost.py └── models │ ├── __init__.py │ ├── attention.py │ ├── encoder.py │ ├── ot_atten.py │ ├── ot_atten_sent.py │ └── pooling_attention.py ├── requirements.txt ├── setup.py ├── similarity ├── __init__.py ├── compute │ ├── __init__.py │ └── trainer.py ├── data │ ├── __init__.py │ ├── dataset.py │ ├── load.py │ ├── loaders │ │ ├── __init__.py │ │ ├── askubuntu.py │ │ ├── loader.py │ │ └── multinews.py │ ├── sampler.py │ ├── text.py │ └── utils.py ├── metric │ ├── __init__.py │ ├── abstract.py │ ├── dev │ │ ├── __init__.py │ │ ├── auc.py │ │ └── similarity.py │ ├── load.py │ ├── loss │ │ ├── __init__.py │ │ └── hinge.py │ └── train │ │ ├── __init__.py │ │ ├── alignment.py │ │ └── cost.py └── models │ ├── __init__.py │ ├── alignment.py │ ├── attention.py │ └── encoder.py ├── sinkhorn ├── __init__.py ├── cost_and_marginals.py ├── sinkhorn.py └── utils.py ├── train.py └── utils ├── __init__.py ├── berttokenizer.py ├── parsing.py ├── utils.py └── version.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.egg-info/ 3 | *.DS_store 4 | data/* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Rationale-alignment 2 | PyTorch library for the ACL 2020 paper: 3 | [Rationalizing Text Matching: Learning Sparse Alignments via Optimal Transport](https://arxiv.org/pdf/2005.13111.pdf) 4 | 5 | ## Usage 6 | [Data](https://drive.google.com/file/d/1prhH-tZZ-gGHj2-PdikGosHZT_SDBRyr/view?usp=sharing): please download the data and unzip it into the `data` folder 7 | 8 | Embedding: we use the fastText embeddings [cc.en.300.bin](https://fasttext.cc/docs/en/crawl-vectors.html); please download them and set the `embedding_path` in `utils/parsing.py` 9 | 10 | ## Replicate results from the paper 11 | 12 | 13 | ## Example usage 14 | 15 | 16 | ## Cite 17 | ```sh 18 | @inproceedings{swanson-etal-2020-rationalizing, 19 | title = "Rationalizing Text Matching: {L}earning Sparse Alignments via Optimal Transport", 20 | author = "Swanson, Kyle and Yu, Lili and Lei, Tao", 21 | booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics", 22 | year = "2020", 23 | publisher = "Association for Computational Linguistics", 24 | url = "https://www.aclweb.org/anthology/2020.acl-main.496", 25 | } 26 | ``` 27 | 28 | -------------------------------------------------------------------------------- /classify/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asappresearch/rationale-alignment/8d2bf06ba4c121863833094d5d4896bf34a9a73e/classify/__init__.py -------------------------------------------------------------------------------- /classify/compute/__init__.py: -------------------------------------------------------------------------------- 1 | from classify.compute.trainer import AlignmentTrainer 2 | 3 | __all__ = ["AlignmentTrainer"] 4 | -------------------------------------------------------------------------------- /classify/data/__init__.py: -------------------------------------------------------------------------------- 1 | from classify.data.load import load_data 2 | from classify.data.sampler import Sampler 3 | 4 | __all__ = ["load_data", "Sampler"] 5 | -------------------------------------------------------------------------------- /classify/data/dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Set 2 | 3 | import torch 4 | 5 | 6 | class Dataset: 7 | def __init__(self, 8 | ids: Set[str], 9 | id_to_document: Dict[str, List[torch.LongTensor]], 10 | id_mapping: Dict[str, Dict[str, Set[str]]], 11 | negative_ids: Set[str] = None, 12 | label_map: Dict[str, int] = None, 13 | evidence: Dict[str, list] = None, 14 | id_to_sentlength: Dict[str, list] = None): 15 | """ 16 | Holds an alignment dataset. 17 | 18 | :param ids: A set of ids from which to sample during training. 19 | Note: May not contain all ids since some ids should not be sampled. 20 | :param id_to_document: A dictionary mapping ids to the list of 21 | (tokenized) sentences in the document. 22 | :param id_mapping: A dictionary mapping ids to a dictionary which maps 23 | "similar" to similar ids and "dissimilar" to dissimilar ids. 24 | :param negative_ids: The set of ids which can be sampled as negatives. 25 | If None, any id can be sampled as a negative. 26 | :param id_to_sentlength: Stores the length of each sentence in the document. Only used for MultiRC with the BERT model.
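        For illustration (all ids here are made up): the similarity-style loaders build
        `id_mapping` in the nested form {'q1': {'similar': {'q2'}, 'dissimilar': {'q3'}}},
        which is what the negative-sampling `Sampler` expects, while the classification
        loaders store a flat pairing such as {'42_premise': '42_hypothesis'} (SNLI) or
        {'Doc1.txt:0:1': 'Doc1.txt'} (MultiRC) that their dedicated samplers consume.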
27 | """ 28 | self.id_set = ids 29 | self.id_list = sorted(self.id_set) 30 | self.id_to_document = id_to_document 31 | self.id_mapping = id_mapping 32 | self.negative_ids = negative_ids or self.id_set 33 | self.label_map = label_map 34 | self.evidence = evidence 35 | self.id_to_sentlength = id_to_sentlength 36 | 37 | def __len__(self) -> int: 38 | return len(self.id_set) 39 | -------------------------------------------------------------------------------- /classify/data/load.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | 5 | # from classify.data.loaders import ( 6 | # AskUbuntuDataLoader, 7 | # MultiNewsDataLoader, 8 | # SummaryDataLoader, 9 | # ) 10 | # from classify.data.loaders import PubmedDataLoader, PubmedSummaryDataLoader 11 | from classify.data.loaders.snli import SNLIDataLoader 12 | from classify.data.loaders.multirc import MultircDataLoader 13 | 14 | from classify.data.text import TextField 15 | from utils.parsing import Arguments 16 | from classify.data.sampler import Sampler 17 | 18 | 19 | def load_data( 20 | args: Arguments, device: torch.device 21 | ) -> Tuple[TextField, Sampler, Sampler, Sampler]: 22 | """Loads data and returns a TextField and train, dev, and test Samplers.""" 23 | # Default to sampling negatives 24 | resample_negatives = True 25 | 26 | print("initializing dataloader") 27 | # Get DataLoader 28 | if args.dataset == "snli": 29 | from classify.data.snli_sampler import SNLISampler as Sampler 30 | 31 | data_loader = SNLIDataLoader(args) 32 | # elif args.dataset == "multirc" and args.word_to_word: 33 | # from classify.data.multirc_word_sampler import MultircWSampler as Sampler 34 | 35 | # data_loader = MultircDataLoader(args) 36 | elif args.dataset == "multirc": 37 | from classify.data.multirc_sent_sampler import MultircSentSampler as Sampler 38 | 39 | data_loader = MultircDataLoader(args) 40 | else: 41 | raise ValueError(f'Dataset "{args.dataset}" not supported') 42 | 43 | print("initializing sampler") 44 | # Create Samplers 45 | train_sampler = Sampler( 46 | data=data_loader.train, 47 | text_field=data_loader.text_field, 48 | batch_size=args.batch_size, 49 | shuffle=True, 50 | num_positives=args.num_positives, 51 | num_negatives=args.num_negatives, 52 | resample_negatives=resample_negatives, 53 | device=device, 54 | ) 55 | 56 | dev_sampler = Sampler( 57 | data=data_loader.dev, 58 | text_field=data_loader.text_field, 59 | batch_size=args.batch_size, 60 | num_positives=args.num_eval_positives, 61 | num_negatives=args.num_eval_negatives, 62 | device=device, 63 | ) 64 | 65 | test_sampler = Sampler( 66 | data=data_loader.test, 67 | text_field=data_loader.text_field, 68 | batch_size=args.batch_size, 69 | num_positives=args.num_eval_positives, 70 | num_negatives=args.num_eval_negatives, 71 | device=device, 72 | ) 73 | 74 | return data_loader.text_field, train_sampler, dev_sampler, test_sampler 75 | -------------------------------------------------------------------------------- /classify/data/loaders/__init__.py: -------------------------------------------------------------------------------- 1 | # from classify.data.loaders.snli import SNLIDataLoader 2 | 3 | # __all__ = ["SNLIDataLoader"] 4 | -------------------------------------------------------------------------------- /classify/data/loaders/loader.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | from classify.data.dataset import Dataset 4 | from 
classify.data.text import TextField 5 | 6 | 7 | class DataLoader: 8 | @property 9 | @abstractmethod 10 | def train(self) -> Dataset: 11 | """Returns the training data.""" 12 | pass 13 | 14 | @property 15 | @abstractmethod 16 | def dev(self) -> Dataset: 17 | """Returns the validation data.""" 18 | pass 19 | 20 | @property 21 | @abstractmethod 22 | def test(self) -> Dataset: 23 | """Returns the test data.""" 24 | pass 25 | 26 | @property 27 | @abstractmethod 28 | def text_field(self) -> TextField: 29 | """Returns the text field.""" 30 | pass 31 | 32 | def print_stats(self) -> None: 33 | """Prints statistics about the data.""" 34 | print() 35 | print(f'Total size = {len(self.train) + len(self.dev) + len(self.test):,}') 36 | print() 37 | print(f'Train size = {len(self.train):,}') 38 | print(f'Dev size = {len(self.dev):,}') 39 | print(f'Test size = {len(self.test):,}') 40 | print() 41 | # print(f'Vocabulary size = {len(self.text_field.vocabulary):,}') 42 | print() 43 | -------------------------------------------------------------------------------- /classify/data/loaders/multirc.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import Counter, defaultdict 3 | from copy import deepcopy 4 | from tqdm import tqdm 5 | from typing import Dict, List, Set 6 | import json 7 | import torch 8 | 9 | from classify.data.dataset import Dataset 10 | from classify.data.loaders.loader import DataLoader 11 | from utils.parsing import MultircArguments 12 | from classify.data.text import TextField 13 | from utils.berttokenizer import BTTokenizer 14 | from classify.data.utils import split_data, text_to_sentences 15 | 16 | import jsonlines 17 | from collections import defaultdict 18 | import random 19 | 20 | 21 | class MultircDataLoader(DataLoader): 22 | def __init__(self, args: MultircArguments): 23 | """Loads the MultiRC dataset.""" 24 | 25 | # Determine word to index mapping 26 | # self.small_data = args.small_data #dataset small enough 27 | self.args = args 28 | id_to_document = self.load_text(args.multirc_path) 29 | self.id_to_document = id_to_document 30 | 31 | # Load data 32 | train_label, train_evidences = self.load_label( 33 | args.multirc_path, "train", args.small_data 34 | ) 35 | dev_label, dev_evidences = self.load_label( 36 | args.multirc_path, "val", args.small_data 37 | ) 38 | test_label, test_evidences = self.load_label( 39 | args.multirc_path, "test", args.small_data 40 | ) 41 | 42 | train_id_mapping = {idx: idx.split(":")[0] for idx in train_label.keys()} 43 | dev_id_mapping = {idx: idx.split(":")[0] for idx in dev_label.keys()} 44 | test_id_mapping = {idx: idx.split(":")[0] for idx in test_label.keys()} 45 | train_ids = set(train_label.keys()) 46 | dev_ids = set(dev_label.keys()) 47 | test_ids = set(test_label.keys()) 48 | 49 | if args.bert: 50 | self._text_field = BTTokenizer(args) 51 | else: 52 | if self.args.word_to_word: 53 | texts = list(id_to_document.values()) 54 | else: 55 | texts = [x for doc in list(id_to_document.values()) for x in doc] 56 | self._text_field = TextField() 57 | self._text_field.build_vocab(texts) 58 | 59 | sampled = { 60 | k: id_to_document[k] for k in random.sample(list(id_to_document.keys()), 10) 61 | } 62 | print(sampled) 63 | 64 | # Convert sentences to indices 65 | if self.args.word_to_word: 66 | id_to_doctoken: Dict[str, List[torch.LongTensor]] = { 67 | id: self._text_field.process(text)[: args.max_sentence_length] 68 | for id, text in tqdm(id_to_document.items()) 69 | } 70 | id_to_sentlength = {}  # defined in both branches so the Dataset constructors below never see an undefined name 71 
| else: 72 | id_to_doctoken: Dict[str, List[torch.LongTensor]] = { 73 | id: [self._text_field.process(sentence) for sentence in document] 74 | for id, document in id_to_document.items() 75 | } 76 | id_to_sentlength = { 77 | id: [len(sentence.split()) for sentence in document] 78 | for id, document in id_to_document.items() 79 | } 80 | 81 | print(len(id_to_document)) 82 | print(len(id_to_doctoken)) 83 | sampled = { 84 | k: id_to_doctoken[k] for k in random.sample(list(id_to_doctoken.keys()), 10) 85 | } 86 | print(sampled) 87 | # import sys; sys.exit() 88 | # Define train, dev, test datasets 89 | 90 | self._train = Dataset( 91 | ids=train_ids, 92 | id_to_document=id_to_doctoken, 93 | id_mapping=train_id_mapping, 94 | label_map=train_label, 95 | evidence=train_evidences, 96 | id_to_sentlength=id_to_sentlength, 97 | ) 98 | self._dev = Dataset( 99 | ids=dev_ids, 100 | id_to_document=id_to_doctoken, 101 | id_mapping=dev_id_mapping, 102 | label_map=dev_label, 103 | evidence=dev_evidences, 104 | id_to_sentlength=id_to_sentlength, 105 | ) 106 | self._test = Dataset( 107 | ids=test_ids, 108 | id_to_document=id_to_doctoken, 109 | id_mapping=test_id_mapping, 110 | label_map=test_label, 111 | evidence=test_evidences, 112 | id_to_sentlength=id_to_sentlength, 113 | ) 114 | 115 | self.print_stats() 116 | 117 | def load_text( 118 | self, path: str, small_data: bool = False 119 | ) -> List[List[List[List[str]]]]: 120 | # return mapping from id to text, text list is splited as sentences. 121 | data = defaultdict(dict) 122 | print(f"reading text from {path}") 123 | # Reading the documents 124 | files = os.listdir(os.path.join(path, "docs")) 125 | for f in files: 126 | docid = f 127 | text = open(os.path.join(path, "docs", f), "r").readlines() 128 | text = [x.strip() for x in text] 129 | data[docid] = " ".join(text) if self.args.word_to_word else text 130 | for flavor in ["train", "test", "val"]: 131 | label_path = os.path.join(path, flavor + ".jsonl") 132 | print(f"reading questions from {label_path}") 133 | reader = jsonlines.Reader(open(label_path)) 134 | for line in reader: 135 | docid = line["annotation_id"] 136 | text = [x.strip() for x in line["query"].split("||")] 137 | assert len(text) == 2 138 | if len(text[0].split()) == 0 or len(text[1].split()) == 0: 139 | print(text) 140 | print("bad queries") 141 | else: 142 | data[docid] = " ".join(text) if self.args.word_to_word else text 143 | return data 144 | 145 | def load_label( 146 | self, path: str, flavor: str, small_data: bool = False 147 | ) -> List[List[List[List[str]]]]: 148 | label_path = os.path.join(path, flavor + ".jsonl") 149 | print(f"reading labels from {label_path}") 150 | labels = defaultdict(dict) 151 | evidences = {"token": defaultdict(list), "sentence": defaultdict(list)} 152 | label_toi = {"False": 0, "True": 1} 153 | reader = jsonlines.Reader(open(label_path)) 154 | for line in reader: 155 | label = line["classification"] 156 | idx = line["annotation_id"] 157 | if idx in self.id_to_document: 158 | docid0 = idx.split(":")[0] 159 | labels[idx] = label_toi[label] 160 | assert len(line["evidences"]) == 1 161 | for evi in line["evidences"][0]: 162 | assert evi["docid"] == docid0 163 | assert evi["end_sentence"] - evi["start_sentence"] == 1 164 | evidences["token"][idx].append( 165 | (evi["start_token"], evi["end_token"]) 166 | ) 167 | evidences["sentence"][idx].append( 168 | (evi["start_sentence"], evi["end_sentence"]) 169 | ) 170 | if small_data and len(labels) > 1500: 171 | break 172 | return labels, evidences 173 | 174 | @property 175 | 
def train(self) -> Dataset: 176 | return self._train 177 | 178 | @property 179 | def dev(self) -> Dataset: 180 | return self._dev 181 | 182 | @property 183 | def test(self) -> Dataset: 184 | return self._test 185 | 186 | @property 187 | def text_field(self) -> TextField: 188 | return self._text_field 189 | 190 | -------------------------------------------------------------------------------- /classify/data/loaders/snli.py: -------------------------------------------------------------------------------- 1 | from collections import Counter, defaultdict 2 | from copy import deepcopy 3 | from tqdm import tqdm 4 | from typing import Dict, List, Set 5 | import json 6 | import torch 7 | 8 | from classify.data.dataset import Dataset 9 | from classify.data.loaders.loader import DataLoader 10 | from utils.parsing import SNLIArguments 11 | from classify.data.text import TextField 12 | from classify.data.utils import split_data, text_to_sentences 13 | 14 | import jsonlines 15 | from collections import defaultdict 16 | import random 17 | 18 | 19 | class SNLIDataLoader(DataLoader): 20 | def __init__(self, args: SNLIArguments): 21 | """Loads the pubmed dataset.""" 22 | 23 | # Determine word to index mapping 24 | self.small_data = args.small_data 25 | 26 | # Load data 27 | train_label, train_evidences = self.load_label( 28 | args.snli_path, "train", args.small_data 29 | ) 30 | dev_label, dev_evidences = self.load_label( 31 | args.snli_path, "val", args.small_data 32 | ) 33 | test_label, test_evidences = self.load_label( 34 | args.snli_path, "test", args.small_data 35 | ) 36 | 37 | train_id_mapping = { 38 | k + "_premise": k + "_hypothesis" for k in train_label.keys() 39 | } 40 | dev_id_mapping = {k + "_premise": k + "_hypothesis" for k in dev_label.keys()} 41 | test_id_mapping = {k + "_premise": k + "_hypothesis" for k in test_label.keys()} 42 | train_ids = set(train_label.keys()) 43 | dev_ids = set(dev_label.keys()) 44 | test_ids = set(test_label.keys()) 45 | 46 | if self.small_data: 47 | allids = ( 48 | list(train_id_mapping.keys()) 49 | + list(train_id_mapping.values()) 50 | + list(dev_id_mapping.keys()) 51 | + list(dev_id_mapping.values()) 52 | + list(test_id_mapping.keys()) 53 | + list(test_id_mapping.values()) 54 | ) 55 | id_to_text = self.load_text(args.snli_path) 56 | id_to_text = {k: v for k, v in id_to_text.items() if k in allids} 57 | else: 58 | id_to_text = self.load_text(args.snli_path) 59 | 60 | texts = list(id_to_text.values()) 61 | self._text_field = TextField() 62 | self._text_field.build_vocab(texts) 63 | 64 | sampled = {k: id_to_text[k] for k in random.sample(list(id_to_text.keys()), 10)} 65 | print(sampled) 66 | 67 | # Convert sentences to indices 68 | id_to_doctoken: Dict[str, List[torch.LongTensor]] = { 69 | idx: self._text_field.process(text) 70 | # for id, sentence in tqdm(id_to_doctoken.items()) 71 | for idx, text in tqdm(id_to_text.items()) 72 | } 73 | 74 | print(len(id_to_text)) 75 | print(len(id_to_doctoken)) 76 | sampled = { 77 | k: id_to_doctoken[k] for k in random.sample(list(id_to_doctoken.keys()), 10) 78 | } 79 | print(sampled) 80 | # import sys; sys.exit() 81 | # Define train, dev, test datasets 82 | self._train = Dataset( 83 | ids=train_ids, 84 | id_to_document=id_to_doctoken, 85 | id_mapping=train_id_mapping, 86 | label_map=train_label, 87 | evidence=train_evidences, 88 | ) 89 | self._dev = Dataset( 90 | ids=dev_ids, 91 | id_to_document=id_to_doctoken, 92 | id_mapping=dev_id_mapping, 93 | label_map=dev_label, 94 | evidence=dev_evidences, 95 | ) 96 | self._test = 
Dataset( 97 | ids=test_ids, 98 | id_to_document=id_to_doctoken, 99 | id_mapping=test_id_mapping, 100 | label_map=test_label, 101 | evidence=test_evidences, 102 | ) 103 | 104 | self.print_stats() 105 | 106 | @staticmethod 107 | def load_text(path: str, small_data: bool = False) -> List[List[List[List[str]]]]: 108 | data = defaultdict(dict) 109 | print(f"reading text from {path}") 110 | reader = jsonlines.Reader(open(path)) 111 | for line in reader: 112 | # doc, side = line['docid'].split('_') 113 | # data[doc][side] = line['document'] 114 | data[line["docid"]] = line["document"] 115 | return data 116 | 117 | @staticmethod 118 | def load_label( 119 | path: str, flavor: str, small_data: bool = False 120 | ) -> List[List[List[List[str]]]]: 121 | label_path = path.replace("docs", flavor) 122 | print(f"reading labels from {label_path}") 123 | labels = defaultdict(dict) 124 | evidences = defaultdict(list) 125 | label_toi = {"entailment": 0, "contradiction": 1, "neutral": 2} 126 | 127 | reader = jsonlines.Reader(open(label_path)) 128 | for line in reader: 129 | label = line["classification"] 130 | idx = line["annotation_id"] 131 | labels[idx] = label_toi[label] 132 | evidences[label + "_hypothesis"] = [] 133 | evidences[label + "_premise"] = [] 134 | for evi in line["evidences"][0]: 135 | evidences[evi["docid"]].append((evi["start_token"], evi["end_token"])) 136 | if small_data and len(labels) > 1500: 137 | break 138 | return labels, evidences 139 | 140 | @property 141 | def train(self) -> Dataset: 142 | return self._train 143 | 144 | @property 145 | def dev(self) -> Dataset: 146 | return self._dev 147 | 148 | @property 149 | def test(self) -> Dataset: 150 | return self._test 151 | 152 | @property 153 | def text_field(self) -> TextField: 154 | return self._text_field 155 | -------------------------------------------------------------------------------- /classify/data/multirc_sent_sampler.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Dict, Iterator, List, Optional, Set, Tuple 3 | 4 | import torch 5 | from torch.nn.utils.rnn import pad_sequence 6 | from tqdm import trange 7 | 8 | from classify.data.text import TextField 9 | from classify.data.dataset import Dataset 10 | 11 | import numpy as np 12 | 13 | 14 | class MultircSentSampler: 15 | def __init__( 16 | self, 17 | data: Dataset, 18 | text_field: TextField, 19 | batch_size: int, 20 | shuffle: bool = False, 21 | num_positives: Optional[int] = None, 22 | num_negatives: Optional[int] = None, 23 | resample_negatives: bool = False, 24 | seed: int = 0, 25 | device: torch.device = torch.device("cpu"), 26 | ): 27 | """ 28 | Constructs a SimilarityDataSampler. 29 | 30 | :param data: A Dataset. 31 | :param text_field: The TextField object initialized with all text data. 32 | :param batch_size: Batch size. 33 | :param shuffle: Whether to shuffle the data. 34 | :param num_positives: Number of positives per example. Defaults to all of them. 35 | :param num_negatives: Number of negatives per example. Defaults to all of them. 36 | :param resample_negatives: Whether to resample negatives after each epoch. 37 | :param seed: Initial random seed. 38 | :param device: The torch device to broadcast to. 
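        For reference, each entry of the `batch_targets` list yielded by `sample()` below
        is a dict of roughly this shape (ids and values are illustrative, not real MultiRC data):

            {
                'annotationid': 'Doc1.txt:3:1',       # query (question/answer) id
                'docid': 'Doc1.txt',                  # passage id
                'lengths': [12, 9, 17],               # tokens per passage sentence
                'scope': LongTensor([k]),             # index of this pair within the batch
                'row_evidence': LongTensor([...]),    # query-side mask (all zeros in the current code)
                'column_evidence': LongTensor([...]), # per-sentence rationale mask for the passage
                'targets': LongTensor([1]),           # 0 = False, 1 = True
            }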
39 | """ 40 | self.data = data 41 | self.text_field = text_field 42 | self.batch_size = batch_size 43 | self.shuffle = shuffle 44 | self.num_positives = num_positives 45 | self.num_negatives = num_negatives 46 | self.resample_negatives = resample_negatives 47 | self.seed = seed 48 | self.device = device 49 | self.id_to_sentlength = data.id_to_sentlength 50 | 51 | # print('initialize special sampler for multirc') 52 | 53 | # self.pad_index = self.text_field.vocabulary[self.text_field.pad] 54 | self.pad_index = self.text_field.pad_index() 55 | 56 | def sample( 57 | self, 58 | ) -> Iterator[ 59 | Tuple[ 60 | torch.LongTensor, 61 | List[Tuple[torch.LongTensor, torch.LongTensor]], 62 | List[Dict[str, torch.LongTensor]], 63 | ] 64 | ]: 65 | """ 66 | Samples pairs of similar/dissimilar documents. 67 | 68 | :return: A tuple consisting of: 69 | 1) batch_sentences: A tensor with all the sentences that need to be encoded (num_sentences x sentence_length). 70 | 2) batch_scope: A list of tuples of tensors indicating the indices in batch_sentences 71 | corresponding to each of the two documents being compared. 72 | 3) batch_targets: A dictionary mapping to the binary targets for each document pair 73 | and mapping to the indices of all pairs, positive pairs, and negative pairs. 74 | """ 75 | # Seed 76 | self.seed += 1 77 | random.seed(self.seed) 78 | 79 | # Shuffle 80 | if self.shuffle: 81 | random.shuffle(self.data.id_list) 82 | 83 | # Iterate through batcches of data 84 | for i in trange(0, len(self.data), self.batch_size): 85 | # Get batch ids 86 | batch_document_ids = self.data.id_list[i : i + self.batch_size] 87 | 88 | # Initialize batch variables 89 | sentence_index = label_scope = 0 90 | batch_sentences, batch_scope, batch_targets = [], [], [] 91 | scope, positives, negatives, targets = [], [], [], [] 92 | id_to_scope = {} 93 | # id_to_lengths = {} 94 | 95 | # print(f'\n the {i} samples') 96 | # if i == 250: 97 | # import pdb; pdb.set_trace() 98 | 99 | for query_id in batch_document_ids: 100 | docid = query_id.split(":")[0] 101 | 102 | id_to_scope[query_id] = [] 103 | for sentence in self.data.id_to_document[query_id]: 104 | batch_sentences.append(sentence) 105 | id_to_scope[query_id].append(sentence_index) 106 | sentence_index += 1 107 | if docid not in id_to_scope: 108 | id_to_scope[docid] = [] 109 | # id_to_lengths[docid] = [] 110 | for sentence in self.data.id_to_document[docid]: 111 | batch_sentences.append(sentence) 112 | # id_to_lengths[docid].append(len(sentence)) ## Need to change how this is computed. 
113 | id_to_scope[docid].append(sentence_index) 114 | sentence_index += 1 115 | 116 | batch_scope.append( 117 | ( 118 | torch.LongTensor(id_to_scope[query_id]).to(self.device), 119 | torch.LongTensor(id_to_scope[docid]).to(self.device), 120 | ) 121 | ) 122 | 123 | row_r = np.zeros(len(self.data.id_to_document[query_id])) 124 | # for s,e in self.data.evidence[id+'_premise']: 125 | # row_r[s:e] = 1 126 | column_r = np.zeros(len(self.data.id_to_document[docid])) 127 | for s, e in self.data.evidence["sentence"][query_id]: 128 | column_r[s:e] = 1 129 | 130 | # if self.data.label_map[query_id]==0: 131 | # positives.append(label_scope) 132 | # else: 133 | # negatives.append(label_scope) 134 | 135 | lengths = self.id_to_sentlength[docid] 136 | w_column_r = np.zeros(sum(lengths)) 137 | for s, e in self.data.evidence["token"][query_id]: 138 | w_column_r[s:e] = 1 139 | 140 | new_column_r = rationale_sent_to_token(lengths, column_r) 141 | assert np.array_equal(w_column_r, new_column_r) 142 | 143 | batch_targets.append( 144 | { 145 | "annotationid": query_id, 146 | "docid": docid, 147 | "lengths": self.id_to_sentlength[docid], 148 | "scope": torch.LongTensor([label_scope]).to(self.device), 149 | "row_evidence": torch.LongTensor(row_r).to(self.device), 150 | "column_evidence": torch.LongTensor(column_r).to(self.device), 151 | # 'positives': torch.LongTensor(positives).to(self.device), 152 | # 'negatives': torch.LongTensor(negatives).to(self.device), 153 | "targets": torch.LongTensor([self.data.label_map[query_id]]).to( 154 | self.device 155 | ), 156 | } 157 | ) 158 | label_scope += 1 159 | # batch_targets.append(self.data.label_map[id]) 160 | 161 | # batch_targets = torch.LongTensor(batch_targets).to(self.device) 162 | 163 | # Pad sentences 164 | batch_sentences = pad_sequence( 165 | batch_sentences, batch_first=True, padding_value=self.pad_index 166 | ) 167 | # Convert sentences to tensors 168 | batch_sentences = torch.LongTensor(batch_sentences).to(self.device) 169 | 170 | assert len(batch_scope) == len(batch_targets) 171 | # assert len(batch_sentences) == 2*len(batch_targets) 172 | 173 | yield batch_sentences, batch_scope, batch_targets 174 | 175 | def __len__(self) -> int: 176 | """Return the number of batches in the sampler.""" 177 | return len(self.data) // self.batch_size 178 | 179 | def __call__( 180 | self, 181 | ) -> Iterator[ 182 | Tuple[ 183 | torch.LongTensor, 184 | List[Tuple[torch.LongTensor, torch.LongTensor]], 185 | List[Dict[str, torch.LongTensor]], 186 | ] 187 | ]: 188 | return self.sample() 189 | 190 | 191 | def rationale_sent_to_token(lengths, rationales): 192 | total_l = sum(lengths) 193 | r_tk = np.zeros(total_l) 194 | for i, s in enumerate(list(rationales)): 195 | if s != 0: 196 | r_tk[sum(lengths[:i]) : sum(lengths[: i + 1])] = 1 197 | return r_tk 198 | -------------------------------------------------------------------------------- /classify/data/sampler.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Dict, Iterator, List, Optional, Set, Tuple 3 | 4 | import torch 5 | from torch.nn.utils.rnn import pad_sequence 6 | from tqdm import trange 7 | 8 | from classify.data.text import TextField 9 | from classify.data.dataset import Dataset 10 | 11 | 12 | class Sampler: 13 | def __init__(self, 14 | data: Dataset, 15 | text_field: TextField, 16 | batch_size: int, 17 | shuffle: bool = False, 18 | num_positives: Optional[int] = None, 19 | num_negatives: Optional[int] = None, 20 | resample_negatives: bool = False, 
21 | seed: int = 0, 22 | device: torch.device = torch.device('cpu')): 23 | """ 24 | Constructs a SimilarityDataSampler. 25 | 26 | :param data: A Dataset. 27 | :param text_field: The TextField object initialized with all text data. 28 | :param batch_size: Batch size. 29 | :param shuffle: Whether to shuffle the data. 30 | :param num_positives: Number of positives per example. Defaults to all of them. 31 | :param num_negatives: Number of negatives per example. Defaults to all of them. 32 | :param resample_negatives: Whether to resample negatives after each epoch. 33 | :param seed: Initial random seed. 34 | :param device: The torch device to broadcast to. 35 | """ 36 | self.data = data 37 | self.text_field = text_field 38 | self.batch_size = batch_size 39 | self.shuffle = shuffle 40 | self.num_positives = num_positives 41 | self.num_negatives = num_negatives 42 | self.resample_negatives = resample_negatives 43 | self.seed = seed 44 | self.device = device 45 | 46 | # self.pad_index = self.text_field.vocabulary[self.text_field.pad] 47 | self.pad_index = self.text_field.pad_index() 48 | 49 | def sample_negatives(self, 50 | id: str, 51 | available_ids: Set[str]) -> Set[str]: 52 | """Samples negative ids for a document from a list of available of ids. 53 | 54 | :param id: The id of the document for which negatives should be sampled. 55 | :param available_ids: A list of ids available to sample from. 56 | :return: A (sorted) list of negative ids sampled from available ids. 57 | """ 58 | # IDs that can't be negatives 59 | non_negative_ids = self.data.id_mapping[id]['similar'] | {id} 60 | valid_negative_ids = self.data.negative_ids - non_negative_ids 61 | 62 | # Determine negative ids 63 | negative_ids = available_ids & valid_negative_ids 64 | if len(negative_ids) == 0: 65 | negative_ids = available_ids 66 | print('zero size') 67 | print(non_negative_ids) 68 | print(valid_negative_ids) 69 | print(available_ids) 70 | 71 | # Add more negatives from outside of available_ids if necessary 72 | if len(negative_ids) < self.num_negatives: 73 | candidate_ids = [candidate_id for candidate_id in self.data.id_list if candidate_id in valid_negative_ids] 74 | random.shuffle(candidate_ids) 75 | 76 | for candidate_id in candidate_ids: 77 | negative_ids.add(candidate_id) 78 | 79 | if len(negative_ids) >= self.num_negatives: 80 | break 81 | 82 | # Subsample negatives if too many 83 | if len(negative_ids) > self.num_negatives: 84 | negative_ids = sorted(negative_ids) 85 | random.shuffle(negative_ids) 86 | negative_ids = set(negative_ids[:self.num_negatives]) 87 | 88 | return negative_ids 89 | 90 | def sample(self) -> Iterator[Tuple[torch.LongTensor, 91 | List[Tuple[torch.LongTensor, torch.LongTensor]], 92 | List[Dict[str, torch.LongTensor]]]]: 93 | """ 94 | Samples pairs of similar/dissimilar documents. 95 | 96 | :return: A tuple consisting of: 97 | 1) batch_sentences: A tensor with all the sentences that need to be encoded (num_sentences x sentence_length). 98 | 2) batch_scope: A list of tuples of tensors indicating the indices in batch_sentences 99 | corresponding to each of the two documents being compared. 100 | 3) batch_targets: A dictionary mapping to the binary targets for each document pair 101 | and mapping to the indices of all pairs, positive pairs, and negative pairs. 
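        A minimal consumption sketch (the `sampler` and `encoder` names here are
        illustrative, not part of this class):

            for sentences, scopes, targets in sampler.sample():
                encoded = encoder(sentences)          # encode every sentence once
                for left, right in scopes:            # one (query, candidate) document pair per entry
                    left_doc = encoded[left]          # sentence encodings of the query document
                    right_doc = encoded[right]        # sentence encodings of the candidate document
                    # ... align left_doc with right_doc ...
                # each entry of targets then indexes its own pairs via target['scope']
                # and provides the corresponding 0/1 labels in target['targets']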
102 | """ 103 | # Seed 104 | self.seed += 1 105 | random.seed(self.seed) 106 | 107 | # Shuffle 108 | if self.shuffle: 109 | random.shuffle(self.data.id_list) 110 | 111 | # Iterate through batches of data 112 | for i in trange(0, len(self.data), self.batch_size): 113 | # Get batch ids 114 | batch_document_ids = self.data.id_list[i:i + self.batch_size] 115 | 116 | if len(batch_document_ids) <4: 117 | continue 118 | 119 | # Get ids of all documents which will be encoded in this batch 120 | # (i.e. batch_document_ids plus all similar ids) 121 | batch_available_ids: Set[str] = set.union( 122 | set(batch_document_ids), 123 | *[self.data.id_mapping[document_id]['similar'] for document_id in batch_document_ids] 124 | ) 125 | 126 | # Initialize batch variables 127 | sentence_index = scope_index = 0 128 | id_to_scope = {} 129 | batch_sentences, batch_scope, batch_targets = [], [], [] 130 | 131 | # Loop through document ids and add sentences, scope, and targets 132 | for document_id in batch_document_ids: 133 | # Add scope and targets 134 | scope, positives, negatives, targets = [], [], [], [] 135 | 136 | # Get similar and dissimilar ids 137 | similar_ids = self.data.id_mapping[document_id]['similar'] 138 | dissimilar_ids = self.data.id_mapping[document_id]['dissimilar'] 139 | 140 | # Sample dissimilar ids if necessary 141 | if self.resample_negatives or len(dissimilar_ids) == 0: 142 | dissimilar_ids = self.sample_negatives( 143 | id=document_id, 144 | available_ids=batch_available_ids 145 | ) 146 | self.data.id_mapping[document_id]['dissimilar'] = dissimilar_ids 147 | 148 | # Subsample positives if too many 149 | if self.num_positives is not None and len(similar_ids) > self.num_positives: 150 | similar_ids = sorted(similar_ids) 151 | random.shuffle(similar_ids) 152 | similar_ids = set(similar_ids[:self.num_positives]) 153 | 154 | # Subsample negatives if too many 155 | if self.num_negatives is not None and len(dissimilar_ids) > self.num_negatives: 156 | dissimilar_ids = sorted(dissimilar_ids) 157 | random.shuffle(dissimilar_ids) 158 | dissimilar_ids = set(dissimilar_ids[:self.num_negatives]) 159 | 160 | # Add all sentences related to this document 161 | related_ids = set.union({document_id}, similar_ids, dissimilar_ids) 162 | new_ids = related_ids - set(id_to_scope.keys()) 163 | for id in new_ids: 164 | # Initialize scope for id 165 | id_to_scope[id] = [] 166 | 167 | # Add sentences 168 | for sentence in self.data.id_to_document[id]: 169 | batch_sentences.append(sentence) 170 | id_to_scope[id].append(sentence_index) 171 | # if len(id_to_scope[id]) > 1: 172 | # print(f'sentence longer than 1. 
{id} ') 173 | sentence_index += 1 174 | 175 | # Add similar document scope/target 176 | for similar_id in similar_ids: 177 | batch_scope.append((torch.LongTensor(id_to_scope[document_id]).to(self.device), 178 | torch.LongTensor(id_to_scope[similar_id]).to(self.device))) 179 | scope.append(scope_index) 180 | positives.append(scope_index) 181 | scope_index += 1 182 | targets.append(1) 183 | 184 | # Add dissimilar document scope/target 185 | for dissimilar_id in dissimilar_ids: 186 | batch_scope.append((torch.LongTensor(id_to_scope[document_id]).to(self.device), 187 | torch.LongTensor(id_to_scope[dissimilar_id]).to(self.device))) 188 | scope.append(scope_index) 189 | negatives.append(scope_index) 190 | scope_index += 1 191 | targets.append(0) 192 | 193 | batch_targets.append({ 194 | 'scope': torch.LongTensor(scope).to(self.device), 195 | 'positives': torch.LongTensor(positives).to(self.device), 196 | 'negatives': torch.LongTensor(negatives).to(self.device), 197 | 'targets': torch.LongTensor(targets).to(self.device) 198 | }) 199 | 200 | # Pad sentences 201 | batch_sentences = pad_sequence(batch_sentences, batch_first=True, padding_value=self.pad_index) 202 | 203 | # Convert sentences to tensors 204 | batch_sentences = torch.LongTensor(batch_sentences).to(self.device) 205 | 206 | yield batch_sentences, batch_scope, batch_targets 207 | 208 | def __len__(self) -> int: 209 | """Return the number of batches in the sampler.""" 210 | return len(self.data) // self.batch_size 211 | 212 | def __call__(self) -> Iterator[Tuple[torch.LongTensor, 213 | List[Tuple[torch.LongTensor, torch.LongTensor]], 214 | List[Dict[str, torch.LongTensor]]]]: 215 | return self.sample() 216 | -------------------------------------------------------------------------------- /classify/data/snli_sampler.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Dict, Iterator, List, Optional, Set, Tuple 3 | 4 | import torch 5 | from torch.nn.utils.rnn import pad_sequence 6 | from tqdm import trange 7 | 8 | from classify.data.text import TextField 9 | from classify.data.dataset import Dataset 10 | 11 | import numpy as np 12 | class SNLISampler: 13 | def __init__(self, 14 | data: Dataset, 15 | text_field: TextField, 16 | batch_size: int, 17 | shuffle: bool = False, 18 | num_positives: Optional[int] = None, 19 | num_negatives: Optional[int] = None, 20 | resample_negatives: bool = False, 21 | seed: int = 0, 22 | device: torch.device = torch.device('cpu')): 23 | """ 24 | Constructs a SimilarityDataSampler. 25 | 26 | :param data: A Dataset. 27 | :param text_field: The TextField object initialized with all text data. 28 | :param batch_size: Batch size. 29 | :param shuffle: Whether to shuffle the data. 30 | :param num_positives: Number of positives per example. Defaults to all of them. 31 | :param num_negatives: Number of negatives per example. Defaults to all of them. 32 | :param resample_negatives: Whether to resample negatives after each epoch. 33 | :param seed: Initial random seed. 34 | :param device: The torch device to broadcast to. 
35 | """ 36 | self.data = data 37 | self.text_field = text_field 38 | self.batch_size = batch_size 39 | self.shuffle = shuffle 40 | self.num_positives = num_positives 41 | self.num_negatives = num_negatives 42 | self.resample_negatives = resample_negatives 43 | self.seed = seed 44 | self.device = device 45 | 46 | # self.pad_index = self.text_field.vocabulary[self.text_field.pad] 47 | self.pad_index = self.text_field.pad_index() 48 | 49 | def sample(self) -> Iterator[Tuple[torch.LongTensor, 50 | List[Tuple[torch.LongTensor, torch.LongTensor]], 51 | List[Dict[str, torch.LongTensor]]]]: 52 | """ 53 | Samples pairs of similar/dissimilar documents. 54 | 55 | :return: A tuple consisting of: 56 | 1) batch_sentences: A tensor with all the sentences that need to be encoded (num_sentences x sentence_length). 57 | 2) batch_scope: A list of tuples of tensors indicating the indices in batch_sentences 58 | corresponding to each of the two documents being compared. 59 | 3) batch_targets: A dictionary mapping to the binary targets for each document pair 60 | and mapping to the indices of all pairs, positive pairs, and negative pairs. 61 | """ 62 | # Seed 63 | self.seed += 1 64 | random.seed(self.seed) 65 | 66 | # Shuffle 67 | if self.shuffle: 68 | random.shuffle(self.data.id_list) 69 | 70 | # Iterate through batcches of data 71 | for i in trange(0, len(self.data), self.batch_size): 72 | # Get batch ids 73 | batch_document_ids = self.data.id_list[i:i + self.batch_size] 74 | 75 | # Initialize batch variables 76 | sentence_index = scope_index = 0 77 | batch_sentences, batch_scope, batch_targets = [], [], [] 78 | scope, positives, negatives, targets = [], [], [], [] 79 | 80 | for i, id in enumerate(batch_document_ids): 81 | batch_sentences.append(self.data.id_to_document[id+'_premise']) 82 | batch_sentences.append(self.data.id_to_document[id+'_hypothesis']) 83 | 84 | batch_scope.append((torch.LongTensor([2*i]).to(self.device), 85 | torch.LongTensor([2*i+1]).to(self.device))) 86 | positive = [i] if self.data.label_map[id]==0 else [] 87 | negative = [i] if self.data.label_map[id]==1 else [] 88 | 89 | row_r = np.zeros(len(self.data.id_to_document[id+'_premise'])) 90 | for s,e in self.data.evidence[id+'_premise']: 91 | row_r[s:e] = 1 92 | column_r = np.zeros(len(self.data.id_to_document[id+'_hypothesis'])) 93 | for s,e in self.data.evidence[id+'_hypothesis']: 94 | column_r[s:e] = 1 95 | 96 | batch_targets.append({ 97 | 'annotationid': id, 98 | 'scope': torch.LongTensor([scope_index]).to(self.device), 99 | 'row_evidence': torch.LongTensor(row_r).to(self.device), 100 | 'column_evidence': torch.LongTensor(column_r).to(self.device), 101 | 'targets': torch.LongTensor([self.data.label_map[id]]).to(self.device) 102 | }) 103 | scope_index += 1 104 | # batch_targets.append(self.data.label_map[id]) 105 | 106 | # batch_targets = torch.LongTensor(batch_targets).to(self.device) 107 | 108 | # Pad sentences 109 | batch_sentences = pad_sequence(batch_sentences, batch_first=True, padding_value=self.pad_index) 110 | # Convert sentences to tensors 111 | batch_sentences = torch.LongTensor(batch_sentences).to(self.device) 112 | 113 | assert len(batch_scope) == len(batch_targets) 114 | assert len(batch_sentences) == 2*len(batch_targets) 115 | 116 | yield batch_sentences, batch_scope, batch_targets 117 | 118 | def __len__(self) -> int: 119 | """Return the number of batches in the sampler.""" 120 | return len(self.data) // self.batch_size 121 | 122 | def __call__(self) -> Iterator[Tuple[torch.LongTensor, 123 | 
List[Tuple[torch.LongTensor, torch.LongTensor]], 124 | List[Dict[str, torch.LongTensor]]]]: 125 | return self.sample() 126 | -------------------------------------------------------------------------------- /classify/data/text.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import csv 3 | from itertools import chain 4 | import os 5 | from typing import Callable, Dict, Iterable, Optional 6 | 7 | import numpy as np 8 | import pandas as pd 9 | from sklearn.feature_extraction.text import TfidfVectorizer 10 | import torch 11 | from tqdm import tqdm 12 | 13 | 14 | class TextField: 15 | def __init__(self, 16 | skip_oov: bool = False, 17 | lower: bool = False, 18 | tokenizer: Callable[[str], Iterable[str]] = lambda text: text.split(), 19 | pad_token: Optional[str] = '<pad>', 20 | unk_token: Optional[str] = '<unk>', 21 | sos_token: Optional[str] = None, 22 | eos_token: Optional[str] = None, 23 | vocabulary: Optional[Dict[str, int]] = None): 24 | if vocabulary: 25 | self.vocabulary = OrderedDict((tok, i) for i, tok in enumerate(vocabulary)) 26 | else: 27 | # Add specials 28 | self.vocabulary = OrderedDict() 29 | 30 | specials = [pad_token, unk_token, sos_token, eos_token] 31 | 32 | for token in filter(lambda x: x is not None, specials): 33 | self.vocabulary[token] = len(self.vocabulary) 34 | 35 | self.pad = pad_token 36 | self.unk = unk_token 37 | self.sos = sos_token 38 | self.eos = eos_token 39 | 40 | self.lower = lower 41 | self.tokenizer = tokenizer 42 | 43 | self._embeddings = None 44 | 45 | self._build_reverse_vocab() 46 | self.skip_oov = skip_oov 47 | self.weights = {} 48 | self.avg_weight = 1.0 49 | 50 | def build_idf_weights(self, *data: Iterable[str]) -> torch.FloatTensor: 51 | """Build IDF weights. 52 | 53 | Parameters 54 | ---------- 55 | data : Iterable[str] 56 | List of input strings. 57 | 58 | """ 59 | data = chain.from_iterable(data) 60 | if self.lower: 61 | data = [text.lower() for text in data] 62 | 63 | vectorizer = TfidfVectorizer(min_df=1, ngram_range=(1, 1), binary=False) 64 | vectorizer.fit(data) 65 | 66 | self.weights = { 67 | word: idf 68 | for word, idf in zip(vectorizer.get_feature_names(), vectorizer.idf_) 69 | if word in self.vocabulary 70 | } 71 | self.avg_weight = np.mean(list(self.weights.values())) 72 | 73 | def _init_vocabulary(self) -> None: 74 | """Initializes vocabulary with special tokens.""" 75 | # Add specials 76 | self.vocabulary = OrderedDict() 77 | 78 | specials = [self.pad, self.unk, self.sos, self.eos] 79 | 80 | for token in filter(lambda x: x is not None, specials): 81 | self.vocabulary[token] = len(self.vocabulary) # type: ignore 82 | 83 | def load_vocab(self, path: str) -> Dict[str, int]: 84 | """Loads a vocabulary from a .txt file. 85 | 86 | Returns 87 | ------- 88 | Dict[str, int] 89 | A vocabulary dictionary mapping from string to int. 90 | 91 | """ 92 | self._init_vocabulary() 93 | 94 | with open(path) as f: 95 | words = [word for line in f for word in line.strip().split()] 96 | 97 | for word in words: 98 | self.vocabulary[word] = len(self.vocabulary) 99 | 100 | return self.vocabulary 101 | 102 | def build_vocab(self, data: Iterable[str], *args) -> Dict[str, int]: 103 | """Build the vocabulary. 104 | Parameters 105 | ---------- 106 | data : Iterable[str] 107 | List of input strings.
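        Example
        -------
        An illustrative round trip (exact indices depend on the special tokens
        configured in ``__init__``)::

            field = TextField()
            field.build_vocab(["a man rides a horse"])  # e.g. 'a' -> 2, 'man' -> 3, ...
            field.process("a man flies")                # -> tensor([2, 3, 1]); 'flies' falls back to the unk token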
108 | """ 109 | datasets = [data] + list(args) 110 | for dataset in datasets: 111 | for example in tqdm(dataset): 112 | # Lowercase if requested 113 | example = example.lower() if self.lower else example 114 | # Tokenize and add to vocabulary 115 | for token in self.tokenizer(example): 116 | self.vocabulary.setdefault(token, len(self.vocabulary)) 117 | 118 | self._build_reverse_vocab() 119 | 120 | return self.vocabulary 121 | 122 | def load_embeddings(self, path: str) -> torch.FloatTensor: 123 | """Load pretrained word embeddings. 124 | 125 | Parameters 126 | ---------- 127 | path : str 128 | The path to the pretrained embeddings 129 | Returns 130 | ------- 131 | torch.FloatTensor 132 | The matrix of pretrained word embeddings 133 | 134 | """ 135 | ext = os.path.splitext(path)[-1] 136 | 137 | if ext == '.bin': # fasttext 138 | try: 139 | import fasttext 140 | except Exception: 141 | try: 142 | import fastText as fasttext 143 | except Exception: 144 | raise ValueError("fasttext not installed.") 145 | model = fasttext.load_model(path) 146 | vectors = [model.get_word_vector(token) * self.weights.get(token, self.avg_weight) for token in tqdm(self.vocabulary)] 147 | else: 148 | # Load any .txt or word2vec kind of format 149 | model = dict() 150 | data = pd.read_csv(path, sep=" ", index_col=0, header=None, quoting=csv.QUOTE_NONE) 151 | embedding_size = len(data.columns) 152 | for word, vector in data.iterrows(): 153 | if word in self.vocabulary: 154 | model[word] = np.array(vector.values) * self.weights.get(word, self.avg_weight) 155 | 156 | # Reorder according to self._vocab 157 | vectors = [model.get(token, np.zeros(embedding_size)) for token in self.vocabulary] 158 | 159 | self.embeddings = torch.FloatTensor(np.array(vectors)) 160 | return self.embeddings 161 | 162 | def process(self, example: str) -> torch.LongTensor: # type: ignore 163 | """Process an example, and create a Tensor. 
164 | Parameters 165 | ---------- 166 | example: str 167 | The example to process, as a single string 168 | Returns 169 | ------- 170 | torch.LongTensor 171 | The processed example, tokenized and numericalized 172 | """ 173 | # Lowercase and tokenize 174 | example = example.lower() if self.lower else example 175 | tokens = self.tokenizer(example) 176 | 177 | # Add extra tokens 178 | if self.sos is not None: 179 | tokens = [self.sos] + list(tokens) 180 | if self.eos is not None: 181 | tokens = list(tokens) + [self.eos] 182 | 183 | # Numericalize 184 | numericals = [] 185 | for token in tokens: 186 | if token not in self.vocabulary: 187 | if self.unk is None or self.unk not in self.vocabulary: 188 | raise ValueError("Encounterd out-of-vocabulary token \ 189 | but the unk_token is either missing \ 190 | or not defined in the vocabulary.") 191 | else: 192 | token = self.unk 193 | 194 | numerical = self.vocabulary[token] # type: ignore 195 | numericals.append(numerical) 196 | 197 | processed = torch.LongTensor(numericals) 198 | return processed 199 | 200 | def _build_reverse_vocab(self) -> None: 201 | """Builds reverse vocabulary.""" 202 | self._reverse_vocab = {index: token for token, index in self.vocabulary.items()} 203 | 204 | def deprocess(self, indices: torch.LongTensor) -> str: 205 | """Converts indices to string.""" 206 | pad_index = self.vocabulary[self.pad] 207 | return ' '.join(self._reverse_vocab[index.item()] for index in indices if index != pad_index) 208 | 209 | def pad_index(self) -> int: 210 | return self.vocabulary[self.pad] 211 | -------------------------------------------------------------------------------- /classify/data/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Any, List, Optional, Tuple 3 | 4 | 5 | 6 | sentence_tokenizer = None 7 | 8 | 9 | def split_data(data: List[Any], 10 | sizes: Tuple[float, float, float] = (0.8, 0.1, 0.1), 11 | seed: int = 0) -> Tuple[List[Any], List[Any], List[Any]]: 12 | """ 13 | Randomly splits data into train, val, and test sets according to the provided sizes. 14 | 15 | :param data: The data to split into train, val, and test. 16 | :param sizes: The sizes of the train, val, and test sets (as a proportion of total size). 17 | :param seed: Random seed. 18 | :return: Train, val, and test sets. 19 | """ 20 | # Checks 21 | assert len(sizes) == 3 22 | assert all(0 <= size <= 1 for size in sizes) 23 | assert sum(sizes) == 1 24 | 25 | # Shuffle 26 | random.seed(seed) 27 | random.shuffle(data) 28 | 29 | # Determine split sizes 30 | train_size = int(sizes[0] * len(data)) 31 | train_val_size = int((sizes[0] + sizes[1]) * len(data)) 32 | 33 | # Split 34 | train = data[:train_size] 35 | val = data[train_size:train_val_size] 36 | test = data[train_val_size:] 37 | 38 | return train, val, test 39 | 40 | 41 | def tokenize_sentence(text: str) -> List[str]: 42 | """ 43 | Tokenizes text into sentences. 44 | 45 | :param text: A string. 46 | :return: A list of sentences. 47 | """ 48 | global sentence_tokenizer 49 | 50 | if sentence_tokenizer is None: 51 | import nltk 52 | sentence_tokenizer = nltk.load('tokenizers/punkt/english.pickle') 53 | 54 | return sentence_tokenizer.tokenize(text) 55 | 56 | 57 | def text_to_sentences(text: str, 58 | tokenizer: str='sentence', 59 | sentence_tokenize: bool = True, 60 | max_num_sentences: Optional[int] = None, 61 | max_sentence_length: Optional[int] = None) -> List[str]: 62 | """ 63 | Splits text into sentences (if desired). 
64 | 65 | Also enforces a maximum sentence length 66 | and maximum number of sentences. 67 | 68 | :param text: The text to split. 69 | :param sentence_tokenize: Whether to split into sentences. 70 | :param max_num_sentences: Maximum number of sentences. 71 | :param max_sentence_length: Maximum length of a sentence (in tokens). 72 | :return: The text split into sentences (if desired) 73 | or as just a single sentence. 74 | """ 75 | # Sentence tokenize 76 | if sentence_tokenize: 77 | sentences = tokenize_sentence(text)[:max_num_sentences] 78 | else: 79 | sentences = [text] 80 | 81 | # Enforce maximum sentence length 82 | sentences = [' '.join(sentence.split()[:max_sentence_length]) for sentence in sentences] 83 | 84 | return sentences 85 | 86 | def pubmed_tokenizer(text: str, 87 | tokenizer: str='sentence', 88 | # predictor: Predictor=None, 89 | max_num_sentences: Optional[int] = None, 90 | max_sentence_length: Optional[int] = None) -> List[str]: 91 | pass 92 | 93 | ''' 94 | def pubmed_tokenizer(text: str, 95 | tokenizer: str='sentence', 96 | predictor: Predictor=None, 97 | max_num_sentences: Optional[int] = None, 98 | max_sentence_length: Optional[int] = None) -> List[str]: 99 | """ 100 | # from allennlp.predictors import Predictor 101 | Splits text into sentences (if desired). 102 | 103 | Also enforces a maximum sentence length 104 | and maximum number of sentences. 105 | 106 | :param text: The text to split. 107 | :param sentence_tokenize: Whether to split into sentences. 108 | :param max_num_sentences: Maximum number of sentences. 109 | :param max_sentence_length: Maximum length of a sentence (in tokens). 110 | :return: The text split into sentences (if desired) 111 | or as just a single sentence. 112 | """ 113 | # Sentence tokenize 114 | if tokenizer =='sentence': 115 | sentences = tokenize_sentence(text)[:max_num_sentences] 116 | elif tokenizer == 'word': 117 | sentences = [text] 118 | elif tokenizer == 'phrase': 119 | print('tokenizing phrase') 120 | full_sentences = tokenize_sentence(text)[:max_num_sentences] 121 | sentences = [] 122 | for sent in full_sentences: 123 | sentences.extend(phrase_tokenizer(sent, predictor, phrase_len=5)) 124 | sentences = sentences[:max_num_sentences*3] 125 | else: 126 | print('unknow tokenizer') 127 | # Enforce maximum sentence length 128 | sentences = [' '.join(sentence.split()[:max_sentence_length]) for sentence in sentences] 129 | 130 | return sentences 131 | ''' 132 | 133 | 134 | def process_pubmed_sentences(text: List[List[str]], 135 | sentence_tokenize: bool = True, 136 | max_num_sentences: Optional[int] = None, 137 | max_sentence_length: Optional[int] = None) -> List[str]: 138 | """ 139 | Splits text into sentences (if desired). 140 | Also enforces a maximum sentence length 141 | and maximum number of sentences. 142 | :param text: The text to split. 143 | :param sentence_tokenize: Whether to split into sentences. 144 | :param max_num_sentences: Maximum number of sentences. 145 | :param max_sentence_length: Maximum length of a sentence (in tokens). 146 | :return: The text split into sentences (if desired) 147 | or as just a single sentence. 
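    Example (illustrative)::

        >>> process_pubmed_sentences([['the', 'cat', 'sat'], ['a', 'dog']],
        ...                          max_sentence_length=2)
        ['the cat', 'a dog']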
148 | """ 149 | # Sentence tokenize 150 | if sentence_tokenize: 151 | sentences = text[:max_num_sentences] 152 | else: 153 | sentences = [word for sent in text for word in sent] 154 | 155 | # Enforce maximum sentence length 156 | sentences = [' '.join(sentence[:max_sentence_length]) for sentence in sentences] 157 | 158 | return sentences -------------------------------------------------------------------------------- /classify/metric/__init__.py: -------------------------------------------------------------------------------- 1 | from classify.metric.abstract import Metric 2 | from classify.metric.load import load_loss_and_metrics 3 | 4 | __all__ = ["load_loss_and_metrics", "Metric"] 5 | -------------------------------------------------------------------------------- /classify/metric/abstract.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Dict, List, Optional, Tuple, Union 3 | 4 | import torch 5 | 6 | from sinkhorn import compute_alignment_cost, compute_entropy 7 | 8 | 9 | class Metric: 10 | @abstractmethod 11 | def compute(self, preds, targets, *argv) -> torch.float: 12 | pass 13 | 14 | def __call__(self, preds, targets, *argv) -> torch.float: 15 | return self.compute(preds, targets, *argv) 16 | 17 | def __str__(self) -> str: 18 | return self.__class__.__name__ 19 | 20 | 21 | class AlignmentMetric(Metric): 22 | """Computes the metric for saying one document is aligned with another.""" 23 | 24 | @staticmethod 25 | def _compute_entropy(preds: List[Tuple[torch.FloatTensor, torch.FloatTensor]]) -> torch.FloatTensor: 26 | """Computes the entropy term (epislon * H(P)) of each (cost, alignment) tuple in preds.""" 27 | return torch.stack([compute_entropy(alignment) for cost, alignment in preds], dim=0) 28 | 29 | @staticmethod 30 | def _compute_cost(preds: List[Tuple[torch.FloatTensor, torch.FloatTensor]]) -> torch.FloatTensor: 31 | """Computes the alignment cost of each (cost, alignment) tuple in preds.""" 32 | return torch.stack([compute_alignment_cost(C=cost, P=alignment) for cost, alignment in preds], dim=0) 33 | 34 | @staticmethod 35 | def _compute_similarities(preds: List[Tuple[torch.FloatTensor, torch.FloatTensor]]) -> torch.FloatTensor: 36 | """Computes the alignment similarities (i.e. -cost) of each (cost, alignment) tuple in preds.""" 37 | return -AlignmentMetric._compute_cost(preds) 38 | 39 | @abstractmethod 40 | def compute(self, 41 | preds: List[Tuple[torch.FloatTensor, torch.FloatTensor]], 42 | targets: Union[List[torch.LongTensor], List[Dict[str, torch.LongTensor]]], 43 | step: Optional[int]) -> torch.float: 44 | pass 45 | 46 | 47 | class AlignmentAverageMetric(AlignmentMetric): 48 | """Computes a metric and averages it across documents.""" 49 | 50 | def __init__(self, similar: Optional[bool] = None): 51 | self.similar = similar # Whether to only include similar or only dissimilar examples 52 | 53 | @abstractmethod 54 | def _compute_one(self, 55 | cost: torch.FloatTensor, 56 | alignment: torch.FloatTensor, 57 | target: int) -> torch.float: 58 | """ 59 | Computes the metric and count of aligning two documents. 60 | 61 | :param cost: The cost of aligning sentence i with sentence j (matrix is n x m). 62 | :param alignment: The probability of aligning sentence i with sentence j (matrix is n x m). 63 | :param target: Whether the documents are similar or not. 64 | :return: The value. 
65 | """ 66 | pass 67 | 68 | def _compute_count(self, 69 | cost: torch.FloatTensor, 70 | alignment: torch.FloatTensor, 71 | target: int) -> int: 72 | """ 73 | Computes the count of items associated with the documents for the purpose of averaging. 74 | 75 | :param cost: The cost of aligning sentence i with sentence j (matrix is n x m). 76 | :param alignment: The probability of aligning sentence i with sentence j (matrix is n x m). 77 | :param target: Whether the documents are similar or not. 78 | :return: The count (typically either # of sentences or 1). 79 | """ 80 | if self.similar is None: 81 | return 1 82 | 83 | return target == self.similar 84 | 85 | def compute(self, 86 | preds: List[Tuple[torch.FloatTensor, torch.FloatTensor]], 87 | targets: List[Dict[str, torch.LongTensor]]) -> torch.float: 88 | """ 89 | Computes metric across a list of instances of two sets of objects. 90 | 91 | :param preds: A list of (cost, alignment) tuples (each is n x m). 92 | :param targets: A list of LongTensors indicating the correct alignment. 93 | :return: The metric of the alignments. 94 | """ 95 | # Initialize 96 | metric, count = 0, 0 97 | 98 | # Extract targets 99 | targets = [t.item() for target in targets for t in target['targets']] 100 | 101 | # Check lengths 102 | assert len(preds) == len(targets) 103 | 104 | # Loop over alignments and add metric and count 105 | for (cost, alignment), target in zip(preds, targets): 106 | new_count = self._compute_count(cost, alignment, target) 107 | 108 | if new_count == 0: 109 | continue 110 | 111 | count += new_count 112 | metric += self._compute_one(cost, alignment, target) 113 | 114 | # Average metric 115 | metric = metric / count if count != 0 else 0 116 | 117 | return metric 118 | 119 | def __str__(self) -> str: 120 | super_str = super(AlignmentAverageMetric, self).__str__() 121 | 122 | if self.similar is None: 123 | return super_str 124 | 125 | return ('Similar' if self.similar else 'Dissimilar') + super_str 126 | -------------------------------------------------------------------------------- /classify/metric/dev/__init__.py: -------------------------------------------------------------------------------- 1 | from classify.metric.dev.accuracy import Accuracy, F1 2 | 3 | 4 | __all__ = [ 5 | "Accuracy", 6 | "F1", 7 | ] 8 | -------------------------------------------------------------------------------- /classify/metric/dev/accuracy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from classify.metric.abstract import Metric 4 | from sklearn.metrics import f1_score 5 | 6 | 7 | class Accuracy(Metric): 8 | def compute(self, logit: torch.Tensor, targets: dict) -> torch.Tensor: 9 | """Computes the loss. 
10 | 11 | Parameters 12 | ---------- 13 | pred: Tensor 14 | input logits of shape (B x N) 15 | target: LongTensor 16 | target tensor of shape (B) or (B x N) 17 | 18 | Returns 19 | ------- 20 | accuracy: torch.Tensor 21 | single label accuracy, of shape (B) 22 | 23 | """ 24 | # If 2-dimensional, select the highest score in each row 25 | target = [t for target in targets for t in target["targets"]] 26 | target = torch.stack(target) 27 | pred = logit 28 | 29 | if len(target.size()) == 2: 30 | target = target.argmax(dim=1) 31 | acc = pred.argmax(dim=1) == target 32 | return acc.float().mean() 33 | 34 | 35 | class F1(Metric): 36 | def compute(self, logit: torch.Tensor, targets: dict) -> torch.Tensor: 37 | target = [t for target in targets for t in target["targets"]] 38 | target = torch.stack(target) 39 | 40 | pred = logit.argmax(dim=1) 41 | 42 | if len(target.size()) == 2: 43 | target = target.argmax(dim=1) 44 | 45 | target = target.detach().cpu().numpy() 46 | pred = pred.detach().cpu().numpy() 47 | f1 = f1_score(target, pred, average="macro") 48 | return f1 49 | 50 | -------------------------------------------------------------------------------- /classify/metric/dev/micro_rationacc.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Tuple, Dict 2 | 3 | from sinkhorn import compute_alignment_cost, compute_entropy 4 | import torch 5 | import numpy as np 6 | from classify.metric.abstract import AlignmentAverageMetric 7 | from utils.utils import prod 8 | from sklearn.metrics import f1_score, precision_score, recall_score 9 | 10 | 11 | def compute_raionale_metrics( 12 | preds: List[Tuple[torch.FloatTensor, torch.FloatTensor]], 13 | targets: List[Dict[str, torch.LongTensor]], 14 | threshold: float = 0.1, 15 | absolute_threshold: bool = False, 16 | ) -> torch.float: 17 | """ 18 | Computes metric across a list of instances of two sets of objects. 19 | 20 | :param preds: A list of (cost, alignment) tuples (each is n x m). 21 | :param targets: A list of LongTensors indicating the correct alignment. 22 | :return: The metric of the alignments.
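    Illustrative example of the relative thresholding rule used below (the numbers are hypothetical):
    with threshold=0.1 and a 2 x 3 alignment

        [[0.30, 0.01, 0.00],
         [0.00, 0.68, 0.01]]

    the cutoff is 0.1 / (2 * 3) ~= 0.0167, so the thresholded matrix is

        [[1, 0, 0],
         [0, 1, 0]]

    giving predicted row rationales [True, True] and predicted column rationales
    [True, True, False], which are then scored against the gold "row_evidence" /
    "column_evidence" annotations.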
23 | """ 24 | # Initialize 25 | # Check lengths 26 | epsilon = 1e-10 27 | with torch.no_grad(): 28 | assert len(preds) == len(targets) 29 | 30 | target_true = predicted_true = correct_true = 0 31 | target_true_c = predicted_true_c = correct_true_c = 0 32 | rationale_count = total_count = 0 33 | 34 | # Loop over alignments and add metric and count 35 | all_column_r = [] 36 | all_predict_column_r = [] 37 | ps = [] 38 | rs = [] 39 | fs = [] 40 | 41 | for (cost, alignment), target in zip(preds, targets): 42 | if len(alignment.shape) == 3: 43 | # attention: 44 | # row_alignment = (-cost).softmax(dim=1); column_alignment = (-cost).softmax(dim=0) 45 | row_alignment, column_alignment = alignment[0], alignment[1] 46 | if absolute_threshold: 47 | column_alignment = column_alignment >= threshold # n 48 | row_alignment = row_alignment >= threshold # m 49 | else: 50 | column_alignment = ( 51 | column_alignment 52 | >= threshold / prod(column_alignment.shape[-2:]) 53 | ).float() # n 54 | row_alignment = ( 55 | row_alignment >= threshold / prod(column_alignment.shape[-2:]) 56 | ).float() # m 57 | predict_row_r = column_alignment.sum(1).cpu().numpy() >= 1 # n 58 | predict_column_r = row_alignment.sum(0).cpu().numpy() >= 1 # m 59 | else: 60 | if absolute_threshold: 61 | alignment = alignment >= threshold 62 | else: 63 | alignment = ( 64 | alignment >= threshold / prod(alignment.shape[-2:]) 65 | ).float() # n 66 | 67 | predict_row_r = alignment.sum(1).cpu().numpy() >= 1 # n 68 | predict_column_r = alignment.sum(0).cpu().numpy() >= 1 # m 69 | 70 | row_r = target["row_evidence"].cpu().numpy() # n 71 | column_r = target["column_evidence"].cpu().numpy() # m 72 | 73 | # print(len(column_r)) 74 | # print(len(row_r)) 75 | 76 | # For multirc, needs to chagne from sentence annotation to token annotation 77 | if "lengths" in target: 78 | # print('converting sent rationale to word rationale') 79 | # print(column_r) 80 | # print(sum(column_r)) 81 | # print(predict_column_r) 82 | # print(sum(predict_column_r)) 83 | column_r = rationale_sent_to_token(target["lengths"], column_r) 84 | predict_column_r = rationale_sent_to_token( 85 | target["lengths"], predict_column_r 86 | ) 87 | # print('after converting') 88 | # print(column_r) 89 | # print(predict_column_r) 90 | # print(sum(column_r)) 91 | # print(sum(predict_column_r)) 92 | # import sys; sys.exit() 93 | assert len(row_r) == len(predict_row_r) 94 | assert len(column_r) == len(predict_column_r) 95 | # print(f'predicted row: {predict_row_r }') 96 | # print(f'real row: {row_r }') 97 | f1 = f1_score(column_r, predict_column_r) 98 | fs.append(f1) 99 | 100 | all_column_r.extend(column_r) 101 | all_predict_column_r.extend(predict_column_r) 102 | 103 | if sum(column_r): 104 | rationale_count += sum(predict_column_r) 105 | total_count += len(predict_column_r) 106 | if sum(row_r): 107 | rationale_count += sum(predict_row_r) 108 | total_count += len(predict_row_r) 109 | target_true += np.sum(row_r == 1) # .float() 110 | predicted_true += np.sum(predict_row_r == 1) # .float() 111 | correct_true += np.sum((row_r == 1) * (predict_row_r == 1)) # .float() 112 | 113 | target_true_c += np.sum(column_r == 1) # .float() 114 | predicted_true_c += np.sum(predict_column_r == 1) # .float() 115 | correct_true_c += np.sum( 116 | (column_r == 1) * (predict_column_r == 1) 117 | ) # .float() 118 | 119 | precision = correct_true / (predicted_true + epsilon) 120 | recall = correct_true / (target_true + epsilon) 121 | f1 = 2 * precision * recall / (precision + recall + epsilon) 122 | 123 | 
precision_c = correct_true_c / (predicted_true_c + epsilon) 124 | recall_c = correct_true_c / (target_true_c + epsilon) 125 | # f1_score_c = 2 * precision_c * recall_c / (precision_c + recall_c + epsilon)s 126 | f1_score_c = sum(fs) / len(fs) 127 | 128 | p_all = (correct_true + correct_true_c) / ( 129 | predicted_true + predicted_true_c + epsilon 130 | ) 131 | r_all = (correct_true + correct_true_c) / ( 132 | target_true + target_true_c + epsilon 133 | ) 134 | f1_all = 2 * p_all * r_all / (p_all + r_all + epsilon) 135 | 136 | rationale_ratio = rationale_count / total_count 137 | # print(f'p:',precision_c) 138 | # print(f'r:',recall_c) 139 | # print(f'f1:',f1_score_c) 140 | # for av in ['micro', 'macro', 'weighted' ]: 141 | # print(av) 142 | # print(f'p:',precision_score(all_column_r, all_predict_column_r)) #, average=av)) 143 | # print(f'r:',recall_score(all_column_r, all_predict_column_r)) #, average=av)) 144 | # print(f'f1:',f1_score(all_column_r, all_predict_column_r)) #, average=av)) 145 | # import sys; sys.exit() 146 | 147 | precision_c = precision_score( 148 | all_column_r, all_predict_column_r, average="macro" 149 | ) 150 | recall_c = recall_score(all_column_r, all_predict_column_r, average="macro") 151 | f1_score_c = f1_score(all_column_r, all_predict_column_r, average="macro") 152 | 153 | return ( 154 | precision, 155 | recall, 156 | f1, 157 | precision_c, 158 | recall_c, 159 | f1_score_c, 160 | p_all, 161 | r_all, 162 | f1_all, 163 | rationale_ratio, 164 | ) 165 | 166 | 167 | def rationale_sent_to_token(lengths, rationales): 168 | total_l = sum(lengths) 169 | r_tk = np.zeros(total_l) 170 | for i, s in enumerate(list(rationales)): 171 | if s != 0: 172 | r_tk[sum(lengths[:i]) : sum(lengths[: i + 1])] = 1 173 | return r_tk 174 | -------------------------------------------------------------------------------- /classify/metric/dev/rationacc.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Tuple, Dict 2 | 3 | from sinkhorn import compute_alignment_cost, compute_entropy 4 | import torch 5 | import numpy as np 6 | from classify.metric.abstract import AlignmentAverageMetric 7 | from utils.utils import prod 8 | from sklearn.metrics import f1_score, precision_score, recall_score 9 | 10 | 11 | def compute_raionale_metrics( 12 | preds: List[Tuple[torch.FloatTensor, torch.FloatTensor]], 13 | targets: List[Dict[str, torch.LongTensor]], 14 | threshold: float = 0.1, 15 | absolute_threshold: bool = False, 16 | ) -> torch.float: 17 | """ 18 | Computes metric across a list of instances of two sets of objects. 19 | 20 | :param preds: A list of (cost, alignment) tuples (each is n x m). 21 | :param targets: A list of LongTensors indicating the correct alignment. 22 | :return: The metric of the alignments. 
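    The sibling implementation in micro_rationacc.py pools the token-level labels
    across all pairs before scoring; this variant scores each document pair
    separately and then averages the per-pair precision/recall/F1. Illustrative
    example (hypothetical numbers): if two pairs have column-rationale F1 scores
    of 0.5 and 1.0, the reported fc is (0.5 + 1.0) / 2 = 0.75 (ignoring the
    epsilon term in the denominator).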
23 | """ 24 | # Initialize 25 | # Check lengths 26 | epsilon = 1e-10 27 | with torch.no_grad(): 28 | assert len(preds) == len(targets) 29 | 30 | # target_true = predicted_true = correct_true = 0 31 | # target_true_c = predicted_true_c = correct_true_c = 0 32 | rationale_count = total_count = 0 33 | 34 | # Initialize the result list for row and columns 35 | p_r = [] 36 | r_r = [] 37 | f_r = [] 38 | p_c = [] 39 | r_c = [] 40 | f_c = [] 41 | 42 | for (cost, alignment), target in zip(preds, targets): 43 | if len(alignment.shape) == 3: 44 | # attention: 45 | # row_alignment = (-cost).softmax(dim=1); column_alignment = (-cost).softmax(dim=0) 46 | row_alignment, column_alignment = alignment[0], alignment[1] 47 | if absolute_threshold: 48 | column_alignment = column_alignment >= threshold # n 49 | row_alignment = row_alignment >= threshold # m 50 | else: 51 | column_alignment = ( 52 | column_alignment 53 | >= threshold / prod(column_alignment.shape[-2:]) 54 | ).float() # n 55 | row_alignment = ( 56 | row_alignment >= threshold / prod(column_alignment.shape[-2:]) 57 | ).float() # m 58 | predict_row_r = column_alignment.sum(1).cpu().numpy() >= 1 # n 59 | predict_column_r = row_alignment.sum(0).cpu().numpy() >= 1 # m 60 | else: 61 | if absolute_threshold: 62 | alignment = alignment >= threshold 63 | else: 64 | alignment = ( 65 | alignment >= threshold / prod(alignment.shape[-2:]) 66 | ).float() # n 67 | 68 | predict_row_r = alignment.sum(1).cpu().numpy() >= 1 # n 69 | predict_column_r = alignment.sum(0).cpu().numpy() >= 1 # m 70 | 71 | row_r = target["row_evidence"].cpu().numpy() # n 72 | column_r = target["column_evidence"].cpu().numpy() # m 73 | 74 | # print(len(column_r)) 75 | # print(len(row_r)) 76 | 77 | # For multirc, needs to chagne from sentence annotation to token annotation 78 | if "lengths" in target: 79 | column_r = rationale_sent_to_token(target["lengths"], column_r) 80 | predict_column_r = rationale_sent_to_token( 81 | target["lengths"], predict_column_r 82 | ) 83 | 84 | assert len(row_r) == len(predict_row_r) 85 | assert len(column_r) == len(predict_column_r) 86 | 87 | if (sum(column_r) + sum(predict_column_r)) != 0: 88 | p_instance = precision_score( 89 | column_r, predict_column_r 90 | ) # , average=av)) 91 | r_instance = recall_score(column_r, predict_column_r) # , average=av)) 92 | f_instance = f1_score(column_r, predict_column_r) 93 | p_c.append(p_instance) 94 | r_c.append(r_instance) 95 | f_c.append(f_instance) 96 | rationale_count += sum(predict_column_r) 97 | total_count += len(predict_column_r) 98 | # if (p_instance+r_instance) !=0: 99 | # assert p_instance!=r_instance 100 | else: 101 | print("zero rationale annotatino") 102 | if not "lengths" in target: 103 | # if sum(row_r): # + sum(predict_row_r): 104 | if "lengths" in target: 105 | print("multirc has no ratioanle anotation on QA pairs") 106 | import sys 107 | 108 | sys.exit() 109 | p_instance = precision_score(row_r, predict_row_r) # , average=av)) 110 | r_instance = recall_score(row_r, predict_row_r) # , average=av)) 111 | f_instance = f1_score(row_r, predict_row_r) 112 | p_r.append(p_instance) 113 | r_r.append(r_instance) 114 | f_r.append(f_instance) 115 | rationale_count += sum(predict_row_r) 116 | total_count += len(predict_row_r) 117 | 118 | rationale_ratio = rationale_count / total_count 119 | 120 | pc = sum(p_c) / (len(p_c) + epsilon) 121 | rc = sum(r_c) / (len(r_c) + epsilon) 122 | fc = sum(f_c) / (len(f_c) + epsilon) 123 | 124 | pr = sum(p_r) / (len(p_r) + epsilon) 125 | rr = sum(r_r) / (len(r_r) + epsilon) 126 | 
fr = sum(f_r) / (len(f_r) + epsilon) 127 | 128 | p_all = p_c + p_r 129 | r_all = r_c + r_r 130 | f_all = f_c + f_r 131 | 132 | p = sum(p_all) / (len(p_all) + epsilon) 133 | r = sum(r_all) / (len(r_all) + epsilon) 134 | f = sum(f_all) / (len(f_all) + epsilon) 135 | 136 | if "lengths" in targets[0]: 137 | # For multirc, there is no annotation for row, aks q+a 138 | assert len(p_r) == 0 139 | assert p == pc 140 | assert f == fc 141 | assert r == rc 142 | if rc == pc: 143 | # print(p_c) 144 | # print(r_c) 145 | print(rc) 146 | print(pc) 147 | print() 148 | # return precision, recall, f1, precision_c, recall_c, f1_score_c, p_all, r_all, f1_all, rationale_ratio 149 | return pr, rr, fr, pc, rc, fc, p, r, f, rationale_ratio 150 | 151 | 152 | def rationale_sent_to_token(lengths, rationales): 153 | total_l = sum(lengths) 154 | r_tk = np.zeros(total_l) 155 | for i, s in enumerate(list(rationales)): 156 | if s != 0: 157 | r_tk[sum(lengths[:i]) : sum(lengths[: i + 1])] = 1 158 | return r_tk 159 | -------------------------------------------------------------------------------- /classify/metric/load.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | from classify.metric import Metric 4 | from classify.metric.dev import * 5 | from classify.metric.loss import * 6 | from classify.metric.train import * 7 | from utils.parsing import Arguments 8 | 9 | 10 | def load_loss(args): 11 | print(f"using loss {args.loss_fn}") 12 | if args.loss_fn == "hinge": 13 | loss_fn = HingeLoss( 14 | margin=args.margin, pooling=args.hinge_pooling, alpha=args.hinge_alpha 15 | ) 16 | elif args.loss_fn == "cross_entropy": 17 | loss_fn = CrossEntropyLoss() 18 | elif args.loss_fn == "bce": 19 | loss_fn = BinaryCrossEntropyLoss() 20 | elif args.loss_fn == "marginloss": 21 | loss_fn = MultiMarginLoss(args) 22 | elif args.loss_fn == "f1loss": 23 | loss_fn = F1Loss() 24 | return loss_fn 25 | 26 | 27 | def load_loss_and_metrics( 28 | args: Arguments, 29 | ) -> Tuple[Metric, Metric, List[Metric], List[Metric]]: 30 | """ 31 | Defines the loss and metric functions that will be used during AskUbuntu training. 32 | 33 | :param args: Arguments. 
34 | :return: A tuple consisting of: 35 | 1) Training loss function 36 | 2) Dev metric function 37 | 3) A list of additional training metrics 38 | 4) A list of additional validation metrics 39 | """ 40 | # Loss 41 | loss_fn = load_loss(args) 42 | if args.dataset in ["snli", "multirc"]: 43 | metric_fn = F1() 44 | extra_validation_metrics = [Accuracy()] 45 | extra_training_metrics = [Accuracy()] 46 | else: 47 | metric_fn = AUC() 48 | extra_validation_metrics = [ 49 | AUC(max_fpr=0.1), 50 | AUC(max_fpr=0.05), 51 | MAP(), 52 | MRR(), 53 | Precision(n=1), 54 | Precision(n=5), 55 | ] 56 | extra_training_metrics = [] 57 | 58 | if args.alignment != "average": 59 | extra_training_metrics += [ 60 | CostRange(), 61 | CostMin(), 62 | CostMax(), 63 | CostMean(), 64 | CostMedian(), 65 | AlignmentCount(), 66 | AlignmentCount(normalize="min"), 67 | AlignmentCount(normalize="full"), 68 | AlignmentSum(), 69 | AlignmentRowMarginalError(), 70 | AlignmentColumnMarginalError(), 71 | # AlignmentCost(similar=True), 72 | # AlignmentCost(similar=False), 73 | # AlignmentEntropy(similar=True), 74 | # AlignmentEntropy(similar=False) 75 | ] 76 | 77 | # if args.cost_fn in ['dot_product', 'scaled_dot_product', 'cosine_similarity']: 78 | # extra_training_metrics += [ 79 | # CostSign(positive=True, similar=True), 80 | # CostSign(positive=True, similar=False), 81 | # CostSign(positive=False, similar=True), 82 | # CostSign(positive=False, similar=False) 83 | # ] 84 | 85 | if args.alignment != "average": 86 | extra_validation_metrics += [ 87 | AlignmentCount(), 88 | AlignmentCount(normalize="min"), 89 | AlignmentCount(normalize="full"), 90 | ] 91 | 92 | return loss_fn, metric_fn, extra_training_metrics, extra_validation_metrics 93 | -------------------------------------------------------------------------------- /classify/metric/loss/__init__.py: -------------------------------------------------------------------------------- 1 | # from classify.metric.loss.hinge import HingeLoss 2 | from classify.metric.loss.ce import CrossEntropyLoss 3 | from classify.metric.loss.bce import BinaryCrossEntropyLoss 4 | 5 | # from classify.metric.loss.multimarginloss import MultiMarginLoss 6 | from classify.metric.loss.f1loss import F1Loss 7 | 8 | __all__ = [ 9 | # "HingeLoss", 10 | "CrossEntropyLoss", 11 | "BinaryCrossEntropyLoss", 12 | # "MultiMarginLoss", 13 | "F1Loss", 14 | ] 15 | -------------------------------------------------------------------------------- /classify/metric/loss/bce.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple, Optional 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from classify.metric.abstract import AlignmentMetric 7 | from classify.metric.abstract import Metric 8 | 9 | from utils.utils import prod 10 | 11 | 12 | class BinaryCrossEntropyLoss(Metric): 13 | """Computes the hinge loss between aligned and un-aligned document pairs (for AskUbuntu). 14 | 15 | For each document, the loss is sum_ij |negative_similarity_i - positive_similarity_j + margin| 16 | i.e. sum over all positive/negative pairs 17 | """ 18 | 19 | def __init__( 20 | self, 21 | weight: Optional[torch.Tensor] = None, 22 | ignore_index: Optional[int] = None, 23 | reduction: str = "mean", 24 | ) -> None: 25 | """Initialize the MultiLabelNLLLoss. 26 | 27 | Parameters 28 | ---------- 29 | weight : Optional[torch.Tensor] 30 | A manual rescaling weight given to each class. 31 | If given, has to be a Tensor of size N, where N is the 32 | number of classes. 
33 | ignore_index : Optional[int], optional 34 | Specifies a target value that is ignored and does not 35 | contribute to the input gradient. When size_average is 36 | True, the loss is averaged over non-ignored targets. 37 | reduction : str, optional 38 | Specifies the reduction to apply to the output: 39 | 'none' | 'mean' | 'sum'. 40 | 'none': no reduction will be applied, 41 | 'mean': the output will be averaged 42 | 'sum': the output will be summed. 43 | 44 | """ 45 | super(BinaryCrossEntropyLoss, self).__init__() 46 | self.weight = weight 47 | self.ignore_index = ignore_index 48 | self.reduction = reduction 49 | 50 | def compute( 51 | self, logits: torch.Tensor, targets: torch.Tensor, step: int = 4 52 | ) -> torch.Tensor: 53 | """Computes the Negative log likelihood loss for multilabel. 54 | 55 | Parameters 56 | ---------- 57 | pred: torch.Tensor 58 | input logits of shape (B x N) 59 | target: torch.LontTensor 60 | target tensor of shape (B x N) 61 | 62 | Returns 63 | ------- 64 | loss: torch.float 65 | Multi label negative log likelihood loss, of shape (B) 66 | 67 | """ 68 | 69 | targets = [t for target in targets for t in target["targets"]] 70 | targets = torch.stack(targets).float() 71 | 72 | logits = torch.stack( 73 | [torch.sum(cost * alignment) for cost, alignment in logits] 74 | ) 75 | 76 | if self.ignore_index is not None: 77 | targets[:, self.ignore_index] = 0 78 | 79 | # if self.weight is None: 80 | # self.weight = torch.ones(logits.size(1)).to(logits) 81 | 82 | loss = F.binary_cross_entropy_with_logits( 83 | logits, targets 84 | ) # , weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction) 85 | return loss 86 | -------------------------------------------------------------------------------- /classify/metric/loss/ce.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple, Optional 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from classify.metric.abstract import AlignmentMetric 7 | from classify.metric.abstract import Metric 8 | 9 | 10 | class CrossEntropyLoss(Metric): 11 | """Computes the hinge loss between aligned and un-aligned document pairs (for AskUbuntu). 12 | 13 | For each document, the loss is sum_ij |negative_similarity_i - positive_similarity_j + margin| 14 | i.e. sum over all positive/negative pairs 15 | """ 16 | 17 | def __init__( 18 | self, 19 | weight: Optional[torch.Tensor] = None, 20 | ignore_index: Optional[int] = None, 21 | reduction: str = "mean", 22 | ) -> None: 23 | """Initialize the MultiLabelNLLLoss. 24 | 25 | Parameters 26 | ---------- 27 | weight : Optional[torch.Tensor] 28 | A manual rescaling weight given to each class. 29 | If given, has to be a Tensor of size N, where N is the 30 | number of classes. 31 | ignore_index : Optional[int], optional 32 | Specifies a target value that is ignored and does not 33 | contribute to the input gradient. When size_average is 34 | True, the loss is averaged over non-ignored targets. 35 | reduction : str, optional 36 | Specifies the reduction to apply to the output: 37 | 'none' | 'mean' | 'sum'. 38 | 'none': no reduction will be applied, 39 | 'mean': the output will be averaged 40 | 'sum': the output will be summed. 
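        Minimal usage sketch (hypothetical tensors; targets follows the batch
        format used elsewhere in this package, a list of dicts with a "targets"
        entry):

            loss_fn = CrossEntropyLoss()
            logits = torch.randn(2, 3)                                 # B x N class scores
            targets = [{"targets": [torch.tensor(0), torch.tensor(2)]}]
            loss = loss_fn(logits, targets)                            # scalar loss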
41 | 42 | """ 43 | super(CrossEntropyLoss, self).__init__() 44 | self.weight = weight 45 | self.ignore_index = ignore_index 46 | self.reduction = reduction 47 | 48 | def compute( 49 | self, logits: torch.Tensor, targets: torch.Tensor, step: int = 4 50 | ) -> torch.Tensor: 51 | """Computes the Negative log likelihood loss for multilabel. 52 | 53 | Parameters 54 | ---------- 55 | pred: torch.Tensor 56 | input logits of shape (B x N) 57 | target: torch.LontTensor 58 | target tensor of shape (B x N) 59 | 60 | Returns 61 | ------- 62 | loss: torch.float 63 | Multi label negative log likelihood loss, of shape (B) 64 | 65 | """ 66 | targets = [t for target in targets for t in target["targets"]] 67 | targets = torch.stack(targets) 68 | if self.ignore_index is not None: 69 | targets[:, self.ignore_index] = 0 70 | 71 | if self.weight is None: 72 | self.weight = torch.ones(logits.size(1)).to(logits) 73 | 74 | loss = F.cross_entropy( 75 | logits, targets 76 | ) # , weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction) 77 | return loss 78 | -------------------------------------------------------------------------------- /classify/metric/loss/f1loss.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple 2 | 3 | import torch 4 | 5 | from classify.metric.abstract import Metric 6 | import torch.nn.functional as F 7 | 8 | 9 | class F1Loss(Metric): 10 | '''Calculate F1 score. Can work with gpu tensors 11 | 12 | The original implmentation is written by Michal Haltuf on Kaggle. 13 | 14 | Returns 15 | ------- 16 | torch.Tensor 17 | `ndim` == 1. epsilon <= val <= 1 18 | 19 | Reference 20 | --------- 21 | - https://www.kaggle.com/rejpalcz/best-loss-function-for-f1-score-metric 22 | - https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html#sklearn.metrics.f1_score 23 | - https://discuss.pytorch.org/t/calculating-precision-recall-and-f1-score-in-case-of-multi-label-classification/28265/6 24 | - http://www.ryanzhang.info/python/writing-your-own-loss-function-module-for-pytorch/ 25 | ''' 26 | def __init__(self, 27 | epsilon=1e-7): 28 | super(F1Loss, self).__init__() 29 | self.epsilon = epsilon 30 | 31 | def forward(self, y_pred, y_true,): 32 | assert y_pred.ndim == 2 33 | assert y_true.ndim == 1 34 | y_true = F.one_hot(y_true, 2).to(torch.float32) 35 | y_pred = F.softmax(y_pred, dim=1) 36 | 37 | tp = (y_true * y_pred).sum(dim=0).to(torch.float32) 38 | tn = ((1 - y_true) * (1 - y_pred)).sum(dim=0).to(torch.float32) 39 | fp = ((1 - y_true) * y_pred).sum(dim=0).to(torch.float32) 40 | fn = (y_true * (1 - y_pred)).sum(dim=0).to(torch.float32) 41 | 42 | precision = tp / (tp + fp + self.epsilon) 43 | recall = tp / (tp + fn + self.epsilon) 44 | 45 | f1 = 2* (precision*recall) / (precision + recall + self.epsilon) 46 | f1 = f1.clamp(min=self.epsilon, max=1-self.epsilon) 47 | return 1 - f1.mean() 48 | 49 | 50 | def compute(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: 51 | """Computes the Negative log likelihood loss for multilabel. 
52 | 53 | Parameters 54 | ---------- 55 | pred: torch.Tensor 56 | input logits of shape (B x N) 57 | target: torch.LontTensor 58 | target tensor of shape (B x N) 59 | 60 | Returns 61 | ------- 62 | loss: torch.float 63 | Multi label negative log likelihood loss, of shape (B) 64 | 65 | """ 66 | # TODO: Need to compute this logits 67 | # pred = torch.Tensor(pred).cuda() 68 | # target = torch.Tensor(target) 69 | targets = [t for target in targets for t in target['targets']] 70 | targets = torch.stack(targets) 71 | loss = self.forward(logits, targets) 72 | 73 | return loss 74 | 75 | -------------------------------------------------------------------------------- /classify/metric/loss/rationaleloss.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple, Optional 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from classify.metric.abstract import Metric 7 | 8 | 9 | class RationaleBCELoss(Metric): 10 | def __init__(self, domain): 11 | self.domain = domain 12 | 13 | def compute( 14 | self, 15 | preds: List[Tuple[torch.FloatTensor, torch.FloatTensor]], 16 | targets: List[Dict[str, torch.LongTensor]], 17 | ) -> torch.float: 18 | rationale_loss = 0 19 | for (c, alignment), target in zip(preds, targets): 20 | # gold_column_r = torch.stack([target['column_evidence'] for target in targets]).to(self.device) 21 | 22 | # rationale_pred = [] 23 | if len(alignment.shape) == 3: 24 | # attention: 25 | # row_alignment = (-cost).softmax(dim=1); column_alignment = (-cost).softmax(dim=0) 26 | row_alignment, column_alignment = alignment[0], alignment[1] 27 | predict_row_r = column_alignment.sum(1) 28 | predict_column_r = row_alignment.sum(0) 29 | else: 30 | predict_row_r = alignment.sum(1) 31 | predict_column_r = alignment.sum(0) 32 | 33 | gold_column_r = target["column_evidence"].float() 34 | rationale_loss += F.binary_cross_entropy_with_logits( 35 | predict_column_r, gold_column_r 36 | ) / len(targets) 37 | if self.domain == "snli": 38 | gold_row_r = target["row_evidence"].float() 39 | rationale_loss += F.binary_cross_entropy_with_logits( 40 | predict_row_r, gold_row_r 41 | ) / len(targets) 42 | 43 | return rationale_loss 44 | -------------------------------------------------------------------------------- /classify/metric/loss/regularizor.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple 2 | 3 | import torch 4 | 5 | from classify.metric.abstract import AlignmentMetric 6 | from utils.utils import prod 7 | 8 | 9 | class ReguCost(AlignmentMetric): 10 | """Computes the hinge loss between aligned and un-aligned document pairs (for AskUbuntu). 11 | 12 | For each document, the loss is sum_ij |negative_similarity_i - positive_similarity_j + margin| 13 | i.e. 
sum over all positive/negative pairs 14 | """ 15 | 16 | def __init__(self, cost_lambda, ltype, device): 17 | super(ReguCost, self).__init__() 18 | 19 | self.lmbd = cost_lambda 20 | self.type = ltype 21 | self.device = device 22 | 23 | def compute( 24 | self, 25 | preds: List[Tuple[torch.FloatTensor, torch.FloatTensor]], 26 | targets: List[Dict[str, torch.LongTensor]], 27 | step: int = 4, 28 | ) -> torch.float: 29 | reg = 0 30 | for (cost, alignment) in preds: 31 | # reg += torch.sum(cost<0) /prod(cost.shape[-2:]) 32 | reg += self.cost_reg(cost) # ideally, cost needs to be larger, 33 | 34 | return reg * self.lmbd 35 | 36 | def cost_reg(self, cost): 37 | if self.type == "l0": 38 | reg = torch.mean((cost < 0).float().to(self.device)) 39 | if self.type == "l0.5": 40 | reg = -torch.mean(cost * (cost < 0).float().to(self.device)) 41 | if self.type == "l1": 42 | reg = -torch.mean(cost) 43 | # if self.type == 'l2': 44 | # reg = 0 45 | 46 | return reg 47 | -------------------------------------------------------------------------------- /classify/metric/train/__init__.py: -------------------------------------------------------------------------------- 1 | from classify.metric.train.alignment import ( 2 | AlignmentSum, 3 | AlignmentCount, 4 | AlignmentRowMarginalError, 5 | AlignmentColumnMarginalError, 6 | AlignmentCost, 7 | AlignmentEntropy, 8 | AlignmentEpsilonEntropy, 9 | ) 10 | from classify.metric.train.cost import ( 11 | CostRange, 12 | CostMin, 13 | CostMax, 14 | CostMean, 15 | CostMedian, 16 | CostSign, 17 | ) 18 | 19 | __all__ = [ 20 | "AlignmentSum", 21 | "AlignmentCount", 22 | "AlignmentRowMarginalError", 23 | "AlignmentColumnMarginalError", 24 | "AlignmentCost", 25 | "AlignmentEntropy", 26 | "AlignmentEpsilonEntropy", 27 | "CostRange", 28 | "CostMin", 29 | "CostMax", 30 | "CostMean", 31 | "CostMedian", 32 | "CostSign", 33 | ] 34 | -------------------------------------------------------------------------------- /classify/metric/train/alignment.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from sinkhorn import compute_alignment_cost, compute_entropy 4 | import torch 5 | 6 | from classify.metric.abstract import AlignmentAverageMetric 7 | 8 | 9 | class AlignmentSum(AlignmentAverageMetric): 10 | def _compute_one(self, 11 | cost: torch.FloatTensor, 12 | alignment: torch.FloatTensor, 13 | target: int) -> torch.float: 14 | return alignment.sum() 15 | 16 | 17 | class AlignmentCount(AlignmentAverageMetric): 18 | """Returns the number of non-zero entries (possibly normalized).""" 19 | 20 | def __init__(self, 21 | normalize: str = 'none', # 'none', 'min' to norm by min(n, m), 'full' to norm by (nm) 22 | threshold_scaling: float = 1.0, # how much to scale the 1 / (n * m) threshold 23 | similar: Optional[bool] = None): 24 | assert normalize in ['none', 'min', 'full'] 25 | 26 | super(AlignmentCount, self).__init__(similar=similar) 27 | self.normalize = normalize 28 | 29 | def _compute_one(self, 30 | cost: torch.FloatTensor, 31 | alignment: torch.FloatTensor, 32 | target: int) -> torch.float: 33 | n, m = alignment.shape[-2:] 34 | count = torch.sum(alignment != 0).float() 35 | 36 | if self.normalize == 'full': 37 | return count / (n * m) 38 | elif self.normalize == 'min': 39 | return count / min(n, m) 40 | 41 | return count 42 | 43 | def __str__(self) -> str: 44 | string = super(AlignmentCount, self).__str__() 45 | 46 | if self.normalize != 'none': 47 | string += '_normalized_by_' + ('min_nm' if self.normalize == 'min' else 
'nm') 48 | 49 | return string 50 | 51 | 52 | class AlignmentMarginalError(AlignmentAverageMetric): 53 | """Returns the average absolute error of either the row or column marginal, assuming uniform marginal.""" 54 | 55 | def __init__(self, side: int, similar: Optional[bool] = None): 56 | super(AlignmentMarginalError, self).__init__(similar=similar) 57 | assert side in {0, 1} 58 | self.side = side 59 | 60 | def _compute_one(self, 61 | cost: torch.FloatTensor, 62 | alignment: torch.FloatTensor, 63 | target: int) -> torch.float: 64 | device = alignment.device 65 | marginal_dim = -2 if self.side == 0 else -1 66 | sum_dim = -1 if self.side == 0 else -2 67 | marginal = torch.ones(alignment.size(marginal_dim), device=device) / alignment.size(marginal_dim) 68 | marginal_hat = alignment.sum(sum_dim) 69 | error = torch.abs(marginal - marginal_hat).mean() 70 | 71 | return error 72 | 73 | 74 | class AlignmentRowMarginalError(AlignmentMarginalError): 75 | def __init__(self, similar: Optional[bool] = None): 76 | super(AlignmentRowMarginalError, self).__init__(side=0, similar=similar) 77 | 78 | 79 | class AlignmentColumnMarginalError(AlignmentMarginalError): 80 | def __init__(self, similar: Optional[bool] = None): 81 | super(AlignmentColumnMarginalError, self).__init__(side=1, similar=similar) 82 | 83 | 84 | class AlignmentCost(AlignmentAverageMetric): 85 | """Computes the cost of aligning two objects.""" 86 | 87 | def _compute_one(self, 88 | cost: torch.FloatTensor, 89 | alignment: torch.FloatTensor, 90 | target: int) -> torch.float: 91 | return compute_alignment_cost(C=cost, P=alignment) 92 | 93 | 94 | class AlignmentEntropy(AlignmentAverageMetric): 95 | """Computes the entropy of aligning two objects.""" 96 | 97 | def _compute_one(self, 98 | cost: torch.FloatTensor, 99 | alignment: torch.FloatTensor, 100 | target: int) -> torch.float: 101 | return compute_entropy(P=alignment) 102 | 103 | 104 | class AlignmentEpsilonEntropy(AlignmentAverageMetric): 105 | """Computes epsilon times the entropy of aligning two objects.""" 106 | 107 | def __init__(self, epsilon: float, similar: Optional[bool] = None): 108 | super(AlignmentEpsilonEntropy, self).__init__(similar=similar) 109 | self.epsilon = epsilon 110 | 111 | def _compute_one(self, 112 | cost: torch.FloatTensor, 113 | alignment: torch.FloatTensor, 114 | target: int) -> torch.float: 115 | return self.epsilon * compute_entropy(P=alignment) 116 | -------------------------------------------------------------------------------- /classify/metric/train/cost.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from classify.metric.abstract import AlignmentAverageMetric 4 | 5 | 6 | class CostRange(AlignmentAverageMetric): 7 | def _compute_one( 8 | self, cost: torch.FloatTensor, alignment: torch.FloatTensor, target: int 9 | ) -> torch.float: 10 | return cost.max() - cost.min() 11 | 12 | 13 | class CostMin(AlignmentAverageMetric): 14 | def _compute_one( 15 | self, cost: torch.FloatTensor, alignment: torch.FloatTensor, target: int 16 | ) -> torch.float: 17 | return cost.min() 18 | 19 | 20 | class CostMax(AlignmentAverageMetric): 21 | def _compute_one( 22 | self, cost: torch.FloatTensor, alignment: torch.FloatTensor, target: int 23 | ) -> torch.float: 24 | return cost.max() 25 | 26 | 27 | class CostMean(AlignmentAverageMetric): 28 | def _compute_one( 29 | self, cost: torch.FloatTensor, alignment: torch.FloatTensor, target: int 30 | ) -> torch.float: 31 | return cost.mean() 32 | 33 | 34 | class 
CostMedian(AlignmentAverageMetric): 35 | def _compute_one( 36 | self, cost: torch.FloatTensor, alignment: torch.FloatTensor, target: int 37 | ) -> torch.float: 38 | return cost.median() 39 | 40 | 41 | class CostSign(AlignmentAverageMetric): 42 | def __init__(self, positive: bool, similar: bool): 43 | super(CostSign, self).__init__(similar=similar) 44 | self.positive = positive # Whether to count positive or negative costs 45 | 46 | def _compute_one( 47 | self, cost: torch.FloatTensor, alignment: torch.FloatTensor, target: int 48 | ) -> torch.float: 49 | if self.positive: 50 | return torch.sum(cost > 0).float() 51 | return torch.sum(cost < 0).float() 52 | 53 | def __str__(self) -> str: 54 | return f'Num{"Positive" if self.positive else "Negative"}CostWhen{"Similar" if self.similar else "Dissimilar"}' 55 | -------------------------------------------------------------------------------- /classify/models/__init__.py: -------------------------------------------------------------------------------- 1 | # # from classify.models.alignment import AlignmentModel 2 | # from classify.models.attention import SparsemaxFunction 3 | 4 | # __all__ = ["AlignmentModel", "SparsemaxFunction"] 5 | -------------------------------------------------------------------------------- /classify/models/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Tuple 4 | 5 | from utils.parsing import Arguments 6 | from utils.utils import compute_cost, prod 7 | 8 | 9 | def load_attention_layer(args: Arguments, bidirectional: bool) -> nn.Module: 10 | if args.attention_type == 0 or args.attention_type == 1: 11 | return attention0(args, bidirectional) 12 | if args.attention_type == 2 or args.attention_type == 3: 13 | return attention2(args, bidirectional) 14 | if args.attention_type == 4: 15 | return attention4(args, bidirectional) 16 | if args.attention_type == 5: 17 | return attention5(args, bidirectional) 18 | 19 | 20 | def build_ffn(input_size: int, output_size: int, args: Arguments) -> nn.Module: 21 | """Builds a 2-layer feed-forward network.""" 22 | return nn.Sequential( 23 | nn.Linear(input_size, output_size), nn.Dropout(args.dropout), nn.LeakyReLU(0.2), 24 | ) 25 | 26 | 27 | class attention5(nn.Module): 28 | def __init__(self, args: Arguments, input_size: int): 29 | super(attention5, self).__init__() 30 | self.args = args 31 | self.G = build_ffn( 32 | input_size=input_size, output_size=self.args.hidden_size, args=self.args 33 | ) 34 | if args.dataset == "snli": 35 | self.out = nn.Linear(2 * self.args.hidden_size, 3) 36 | elif args.dataset == "multirc": 37 | self.out = nn.Linear(2 * self.args.hidden_size, 2) 38 | self.sparsemax = SparsemaxFunction.apply 39 | 40 | def forward( 41 | self, 42 | row_vecs: torch.FloatTensor, 43 | column_vecs: torch.FloatTensor, 44 | cost: torch.FloatTensor, 45 | threshold: float = 0, 46 | ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: 47 | # https://arxiv.org/abs/1606.01933 48 | # Attend 49 | a, b = row_vecs, column_vecs # n x hidden_size, m x hidden_size 50 | 51 | if self.args.using_sparsemax: 52 | row_alignment = self.sparsemax(-cost, 1) 53 | column_alignment = self.sparsemax(-cost, 0) 54 | else: 55 | row_alignment = (-cost).softmax(dim=1) 56 | column_alignment = (-cost).softmax(dim=0) 57 | 58 | if threshold: 59 | if self.args.absolute_threshold: 60 | row_alignment = ( 61 | row_alignment * (row_alignment >= threshold).float().cuda() 62 | ) 63 | column_alignment = ( 64 | 
column_alignment * (column_alignment >= threshold).float().cuda() 65 | ) 66 | else: 67 | mask = ( 68 | ( 69 | torch.rand(row_alignment.size(0), row_alignment.size(1)) 70 | >= threshold 71 | ) 72 | .float() 73 | .cuda() 74 | ) 75 | # threshold_alignments = [alignment * ( torch.rand(alignment.size(-2), alignment.size(-1)) >=0.5 ).float() for alignment in alignments] 76 | row_alignment = ( 77 | row_alignment * mask 78 | ) # (row_alignment >= threshold/size).float() 79 | column_alignment = ( 80 | column_alignment * mask 81 | ) # (column_alignment >= threshold/size).float() 82 | 83 | beta = torch.sum( 84 | row_alignment.unsqueeze(dim=2) * b.unsqueeze(dim=0), dim=1 85 | ) # n x hidden_size 86 | alpha = torch.sum( 87 | column_alignment.unsqueeze(dim=2) * a.unsqueeze(dim=1), dim=0 88 | ) # m x hidden_size 89 | 90 | # Compare 91 | if self.args.force_attention_linear: 92 | v_1 = beta.mean(dim=0, keepdim=True) # n x hidden_size 93 | v_2 = alpha.mean(dim=0, keepdim=True) # m x hidden_size 94 | else: 95 | v_1 = self.G(beta).mean(dim=0, keepdim=True) # n x hidden_size 96 | v_2 = self.G(alpha).mean(dim=0, keepdim=True) # m x hidden_size 97 | 98 | y = compute_cost(cost_fn=self.args.cost_fn, x1=v_1, x2=v_2) 99 | 100 | # import pdb; pdb.set_trace() 101 | logit = self.out( 102 | torch.cat((self.G(beta).mean(dim=0), self.G(alpha).mean(dim=0)), dim=0) 103 | ) 104 | cost_matrix = y * torch.ones_like(cost) 105 | alignment = torch.stack( 106 | (row_alignment.detach(), column_alignment.detach()), dim=0 107 | ) 108 | return cost_matrix, alignment, logit 109 | 110 | 111 | """Sparsemax activation function. 112 | Pytorch implementation of Sparsemax function from: 113 | -- "From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification" 114 | -- André F. T. Martins, Ramón Fernandez Astudillo (http://arxiv.org/abs/1602.02068) 115 | """ 116 | """ 117 | An implementation of sparsemax (Martins & Astudillo, 2016). See 118 | :cite:`DBLP:journals/corr/MartinsA16` for detailed description. 
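Worked example (illustrative): for input z = [1.0, 0.8, 0.1] along dim=0, the support
is the first two entries, tau = (1.0 + 0.8 - 1) / 2 = 0.4, and
sparsemax(z) = [0.6, 0.4, 0.0] -- a valid probability vector with an exact zero,
where softmax would assign every entry positive mass.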
119 | By Ben Peters and Vlad Niculae 120 | """ 121 | # From: https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/modules/sparse_activations.py 122 | 123 | import torch 124 | from torch.autograd import Function 125 | import torch.nn as nn 126 | 127 | 128 | def _make_ix_like(input, dim=0): 129 | d = input.size(dim) 130 | rho = torch.arange(1, d + 1, device=input.device, dtype=input.dtype) 131 | view = [1] * input.dim() 132 | view[0] = -1 133 | return rho.view(view).transpose(0, dim) 134 | 135 | 136 | def _threshold_and_support(input, dim=0): 137 | """Sparsemax building block: compute the threshold 138 | Args: 139 | input: any dimension 140 | dim: dimension along which to apply the sparsemax 141 | Returns: 142 | the threshold value 143 | """ 144 | 145 | input_srt, _ = torch.sort(input, descending=True, dim=dim) 146 | input_cumsum = input_srt.cumsum(dim) - 1 147 | rhos = _make_ix_like(input, dim) 148 | support = rhos * input_srt > input_cumsum 149 | 150 | support_size = support.sum(dim=dim).unsqueeze(dim) 151 | tau = input_cumsum.gather(dim, support_size - 1) 152 | tau /= support_size.to(input.dtype) 153 | return tau, support_size 154 | 155 | 156 | class SparsemaxFunction(Function): 157 | @staticmethod 158 | def forward(ctx, input, dim=0): 159 | """sparsemax: normalizing sparse transform (a la softmax) 160 | Parameters: 161 | input (Tensor): any shape 162 | dim: dimension along which to apply sparsemax 163 | Returns: 164 | output (Tensor): same shape as input 165 | """ 166 | ctx.dim = dim 167 | max_val, _ = input.max(dim=dim, keepdim=True) 168 | input -= max_val # same numerical stability trick as for softmax 169 | tau, supp_size = _threshold_and_support(input, dim=dim) 170 | output = torch.clamp(input - tau, min=0) 171 | ctx.save_for_backward(supp_size, output) 172 | return output 173 | 174 | @staticmethod 175 | def backward(ctx, grad_output): 176 | supp_size, output = ctx.saved_tensors 177 | dim = ctx.dim 178 | grad_input = grad_output.clone() 179 | grad_input[output == 0] = 0 180 | 181 | v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze() 182 | v_hat = v_hat.unsqueeze(dim) 183 | grad_input = torch.where(output != 0, grad_input - v_hat, grad_input) 184 | return grad_input, None 185 | 186 | 187 | # sparsemax = SparsemaxFunction.apply 188 | 189 | 190 | # class Sparsemax(nn.Module): 191 | 192 | # def __init__(self, dim=0): 193 | # self.dim = dim 194 | # super(Sparsemax, self).__init__() 195 | 196 | # def forward(self, input): 197 | # return sparsemax(input, self.dim) 198 | -------------------------------------------------------------------------------- /classify/models/encoder.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import math 3 | from typing import Any, Iterator, List, Optional, Tuple 4 | 5 | from sinkhorn import batch_sinkhorn, construct_cost_and_marginals 6 | from sru import SRU 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from utils.parsing import Arguments 12 | from utils.utils import compute_cost, prod, unpad_tensors 13 | from classify.models.attention import load_attention_layer 14 | from classify.models.pooling_attention import ( 15 | SelfAttentionPooling, 16 | ReduceLayer, 17 | ) 18 | 19 | 20 | class Embedder(nn.Module): 21 | def __init__( 22 | self, 23 | args: Arguments, 24 | text_field, 25 | bidirectional: bool = True, 26 | layer_norm: bool = False, 27 | highway_bias: float = 0.0, 28 | pooling: str = "average", 29 | embedding_dropout: 
float = 0.1, 30 | rescale: bool = True, 31 | device: torch.device = torch.device("cpu"), 32 | ): 33 | """Constructs an model to compute embeddings.""" 34 | super(Embedder, self).__init__() 35 | 36 | # Save values 37 | self.args = args 38 | self.device = device 39 | pad_index = text_field.pad_index() 40 | self.pad_index = pad_index 41 | self.embdrop = nn.Dropout(embedding_dropout) 42 | self.pooling = pooling 43 | 44 | if self.args.bert: 45 | from transformers import AutoModel 46 | 47 | self.encoder = AutoModel.from_pretrained(args.bert_type) 48 | print("finish loading bert encoder") 49 | self.emb_size = self.encoder.config.hidden_size 50 | self.bidirectional = False 51 | self.bert_bs = args.bert_batch_size 52 | else: 53 | num_embeddings = len(text_field.vocabulary) 54 | self.num_embeddings = num_embeddings 55 | if args.small_data: 56 | self.embedding_size = 300 57 | print("random initializing for debugging") 58 | self.embedding = nn.Embedding(self.num_embeddings, self.embedding_size) 59 | else: 60 | print(f'Loading embeddings from "{args.embedding_path}"') 61 | embedding_matrix = text_field.load_embeddings(args.embedding_path) 62 | self.embedding_size = embedding_matrix.size(1) 63 | # Create models/parameters 64 | self.embedding = nn.Embedding( 65 | num_embeddings=self.num_embeddings, 66 | embedding_dim=self.embedding_size, 67 | padding_idx=self.pad_index, 68 | ) 69 | self.embedding.weight.data = embedding_matrix 70 | self.embedding.weight.requires_grad = False 71 | 72 | self.bidirectional = bidirectional 73 | self.layer_norm = layer_norm 74 | self.highway_bias = highway_bias 75 | self.rescale = rescale 76 | self.emb_size = self.args.hidden_size * (1 + self.bidirectional) 77 | 78 | if self.args.encoder == "sru": 79 | self.encoder = SRU( 80 | input_size=self.embedding_size, 81 | hidden_size=self.args.hidden_size, 82 | num_layers=self.args.num_layers, 83 | dropout=self.args.dropout, 84 | bidirectional=self.bidirectional, 85 | layer_norm=self.layer_norm, 86 | rescale=self.rescale, 87 | highway_bias=self.highway_bias, 88 | ) 89 | 90 | # if args.hidden_norm: 91 | # self.hiddennorm = nn.InstanceNorm1d(self.emb_size, affine=True) 92 | self.output_size = self.emb_size 93 | 94 | if self.pooling == "attention": 95 | self.poollayer = SelfAttentionPooling( 96 | input_dim=self.output_size, 97 | attention_heads=self.args.attention_heads, 98 | attention_units=[self.args.attention_units], 99 | input_dropout=self.args.dropout, 100 | ) 101 | else: 102 | self.poollayer = ReduceLayer(pool=pooling) 103 | 104 | # Move to device 105 | self.to(self.device) 106 | 107 | def rnn_encode( 108 | self, 109 | data: torch.LongTensor, # batch_size x seq_len 110 | return_sequence: bool = False, 111 | ) -> Tuple[List[Tuple[torch.FloatTensor, torch.FloatTensor]], List[Any]]: 112 | """ 113 | Aligns document pairs. 114 | 115 | :param data: Sentences represented as LongTensors of word indices. 116 | :param scope: A list of tuples of row_indices and column_indices indexing into data 117 | to extract the appropriate sentences for each document pair. 118 | :param data: A list of data for each document pair. 119 | :return: A tuple consisting of a list of (cost, alignment) tuples and a list of data. 
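        Illustrative shapes (assuming the bidirectional SRU encoder with hidden
        size H): data of shape (batch_size, seq_len) is embedded and encoded;
        with return_sequence=True the result is the masked per-token states of
        shape (batch_size, seq_len, 2 * H), otherwise the pooled representation
        of shape (batch_size, 2 * H) produced by the pooling layer.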
120 | """ 121 | # Transpose from batch first to sequence first 122 | data = data.transpose(0, 1) # seq_len x batch_size 123 | 124 | # Create mask 125 | mask = (data != self.pad_index).float() # seq_len x batch_size 126 | 127 | # Embed 128 | embedded = self.embdrop( 129 | self.embedding(data) 130 | ) # seq_len x batch_size x embedding_size 131 | 132 | # RNN encoder 133 | h_seq, _ = self.encoder( 134 | embedded, mask_pad=(1 - mask) 135 | ) # seq_len x batch_size x 2*hidden_size 136 | # output_states, c_states = sru(x) # forward pass 137 | # output_states is (length, batch size, number of directions * hidden size) 138 | # c_states is (layers, batch size, number of directions * hidden size) 139 | 140 | h_seq = h_seq.transpose(0, 1) # batch_size x seq_len x 2*hidden_size 141 | mask = mask.transpose(0, 1) 142 | 143 | # return masked_h, masked_h_seq 144 | if return_sequence: 145 | masked_h_seq = h_seq * mask.unsqueeze( 146 | 2 147 | ) # batch_size x seq_len x 2*hidden_size 148 | return masked_h_seq # self.project(masked_h_seq) 149 | else: 150 | # Average pooling 151 | # mask = mask.unsqueeze(2) 152 | # masked_h = masked_h_seq.sum(dim=1)/mask.sum(dim=1) # batch_size x 2*hidden_size 153 | output = self.poollayer(h_seq, mask) 154 | return output # self.project(masked_h) 155 | 156 | def bert_encode( 157 | self, 158 | data: torch.LongTensor, # batch_size x seq_len 159 | return_sequence: bool = False, 160 | token_type_ids: Optional[torch.Tensor] = None, 161 | attention_mask: Optional[torch.Tensor] = None, 162 | position_ids: Optional[torch.Tensor] = None, 163 | head_mask: Optional[torch.Tensor] = None, 164 | ) -> torch.Tensor: 165 | """ 166 | Uses an RNN and self-attention to encode a batch of sequences of word embeddings. 167 | :param batch: A FloatTensor of shape `(sequence_length, batch_size, embedding_size)` containing embedded text. 168 | :param lengths: A LongTensor of shape `(batch_size)` containing the lengths of the sequences, used for masking. 169 | :return: A FloatTensor of shape `(batch_size, output_size)` containing the encoding for each sequence 170 | in the batch. 
171 | """ 172 | 173 | if attention_mask is None and self.pad_index is not None: 174 | attention_mask = (data != self.pad_index).float() 175 | 176 | attention_mask = attention_mask.to(self.device) 177 | outputs = self.encoder(data, attention_mask=attention_mask) 178 | # if not 'distil' in self.args.bert_type: 179 | masked_h_seq = outputs[0] 180 | masked_h = outputs[1] 181 | # else: 182 | # masked_h = outputs[0] 183 | if return_sequence: 184 | return outputs[0] # self.project(outputs[0]) 185 | else: 186 | return outputs[1] # self.project(outputs[1]) 187 | 188 | def forward( 189 | self, 190 | data: torch.LongTensor, # batch_size x seq_len 191 | return_sequence: bool = False, 192 | ) -> Tuple[List[Tuple[torch.FloatTensor, torch.FloatTensor]], List[Any]]: 193 | if self.args.bert: 194 | # if len(data) > self.bert_bs: 195 | encodings = [] 196 | batch_size = self.bert_bs 197 | for batch_idx in range(len(data) // batch_size + 1): 198 | start_idx = batch_idx * batch_size 199 | end_idx = (batch_idx + 1) * batch_size 200 | batch = data[start_idx:end_idx] 201 | # print(data.shape) 202 | # print(batch.shape) 203 | if len(batch) == 0: 204 | break 205 | encoded = self.bert_encode(batch, return_sequence) 206 | # print(encoded.shape) 207 | encodings.extend(encoded) 208 | del encoded 209 | encodings = torch.stack(encodings) 210 | # print(encodings.shape) 211 | return encodings 212 | else: 213 | return self.rnn_encode(data, return_sequence) 214 | 215 | -------------------------------------------------------------------------------- /classify/models/pooling_attention.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Iterable 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | 8 | class SelfAttentionPooling(nn.Module): 9 | """Self attention pooling.""" 10 | def __init__(self, 11 | input_dim: int, 12 | attention_heads: int = 16, 13 | attention_units: Optional[Iterable[int]] = None, 14 | output_activation: Optional[torch.nn.Module] = None, 15 | hidden_activation: Optional[torch.nn.Module] = None, 16 | input_dropout: float = 0.1, 17 | attention_dropout: float = 0.1, 18 | ): 19 | """Initialize a self attention pooling layer 20 | 21 | Parameters 22 | ---------- 23 | input_dim : int 24 | The input data dim 25 | attention_heads: int 26 | the number of attn heads 27 | attention_units: Iterable[int] 28 | the list of hidden dimensions of the MLP computing the attn 29 | input_dropout: float 30 | dropout applied to the data argument of the forward method. 31 | attention_dropout: float 32 | dropout applied to the attention output before applying it 33 | to the input for reduction. decouples the attn dropout 34 | from the input dropout 35 | """ 36 | super().__init__() 37 | # creating the MLP 38 | dimensions = [input_dim, *attention_units, attention_heads] 39 | self.input_dim = input_dim 40 | self.in_drop = nn.Dropout(input_dropout) if input_dropout > 0. 
else nn.Identity() 41 | layers = [] 42 | for l in range(len(dimensions) - 2): 43 | layers.append(nn.Linear(dimensions[l], dimensions[l+1], bias=False)) 44 | layers.append(nn.Tanh() if hidden_activation is None else hidden_activation) 45 | layers.append(nn.Linear(dimensions[-2], dimensions[-1], bias=False)) 46 | if attention_dropout > 0.: 47 | layers.append(nn.Dropout(attention_dropout)) 48 | self.mlp = nn.Sequential(*layers) 49 | self.output_activation = nn.Softmax(dim=1) \ 50 | if output_activation is None else output_activation 51 | 52 | def forward(self, 53 | data: torch.Tensor, 54 | padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor: 55 | """Performs a forward pass. 56 | 57 | Parameters 58 | ---------- 59 | data : torch.Tensor 60 | The input data, as a tensor of shape [B x S x H] 61 | padding_mask: torch.Tensor 62 | The input padding_mask, as a tensor of shape [B X S] 63 | 64 | Returns 65 | ---------- 66 | torch.Tensor 67 | The output data, as a tensor of shape [B x H] 68 | 69 | """ 70 | # input_tensor is 3D float tensor, batchsize x num_encs x dim 71 | batch_size, num_encs, dim = data.shape 72 | # apply input droput 73 | data = self.in_drop(data) 74 | # apply projection and reshape to batchsize x num_encs x num_heads 75 | attention_logits = self.mlp(data.reshape(-1, dim)).reshape(batch_size, num_encs, -1) 76 | # apply padding_mask. dimension stays batchsize x num_encs x num_heads 77 | if padding_mask is not None: 78 | padding_mask = padding_mask.unsqueeze(2).float() 79 | attention_logits = attention_logits * padding_mask + (1. - padding_mask) * -1e20 80 | # apply softmax. dimension stays batchsize x num_encs x num_heads 81 | attention = self.output_activation(attention_logits) 82 | # attend. attention is batchsize x num_encs x num_heads. data is batchsize x num_encs x dim 83 | # resulting dim is batchsize x num_heads x dim 84 | attended = torch.bmm(attention.transpose(1, 2), data) 85 | # average over attention heads and return. dimension is batchsize x dim 86 | return attended.mean(dim=1) 87 | 88 | 89 | class ReduceLayer(nn.Module): 90 | """Implement an sigmoid module. 91 | 92 | Can be used to form a classifier out of any encoder. 93 | Note: by default takes the log_softmax so that it can be fed to 94 | the NLLLoss module. You can disable this behavior through the 95 | `take_log` argument. 96 | 97 | """ 98 | def __init__(self, 99 | pool: str='average', 100 | reduce_dim: int = 1, 101 | padding_idx: Optional[int] = 0) -> None: 102 | """Initialize the SoftmaxLayer. 103 | 104 | Parameters 105 | ---------- 106 | """ 107 | super().__init__() 108 | # output of nn.embedding: B X S X E 109 | # input and output of RNN: S X B X H 110 | # Padding mask: B X S 111 | self.reduce_dim = reduce_dim # Most of time, output is B x S x E, with seqlength on dimension 1 112 | self.pool = pool 113 | # self.padding_idx = padding_idx 114 | 115 | 116 | def forward(self, 117 | data: torch.Tensor, 118 | state: Optional[torch.Tensor] = None, 119 | padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor: 120 | """Perform a forward pass through the network. 
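        (For example, with pool='average' and data of shape [2 x 5 x 128], entries
        at padded positions are zeroed via the padding mask and the remaining
        positions are averaged, giving an output of shape [2 x 128]; pool='first'
        instead returns the encoding at position 0 of every sequence.)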
121 | 122 | Parameters 123 | ---------- 124 | data : torch.Tensor 125 | The input data, as a float tensor of shape [B x S x E] 126 | state: Tensor 127 | An optional previous state of shape [L x B x H] 128 | padding_mask: Tensor, optional 129 | The padding mask of shape [B x S] 130 | 131 | Returns 132 | ------- 133 | torch.Tensor 134 | The encoded output, as a float tensor of shape [B x H] 135 | 136 | """ 137 | output = data 138 | # print('input') 139 | # print(output.shape) 140 | if padding_mask is None: 141 | padding_mask = torch.ones(*output.shape[:2]).to(output) 142 | 143 | # print('mask') 144 | # print(padding_mask.shape) 145 | 146 | # cast(torch.Tensor, padding_mask) 147 | if self.pool == 'average': 148 | # print(padding_mask.shape) 149 | # print(data.shape) 150 | padding_mask = padding_mask.unsqueeze(2) 151 | output = (output * padding_mask).sum(dim=self.reduce_dim) #BXE 152 | output = output / padding_mask.sum(dim=self.reduce_dim) 153 | elif self.pool == 'sum': 154 | output = (output * padding_mask.unsqueeze(2)).sum(dim=self.reduce_dim) 155 | elif self.pool == 'last': 156 | lengths = padding_mask.long().sum(dim=self.reduce_dim) 157 | output = output[torch.arange(output.size(0)).long(), lengths - 1, :] 158 | elif self.pool == 'first': 159 | output = output[torch.arange(output.size(0)).long(), 0, :] 160 | elif self.pool == 'sqrt_reduction': 161 | '''original implementation can be found here 162 | https://github.asapp.dev/aganatra/nlp/blob/master/src/agnlp/utils/sqrt_n_reduction.py''' 163 | padding_mask = padding_mask.unsqueeze(2) 164 | output = (output * padding_mask).sum(dim=self.reduce_dim) #BXE 165 | output = output/sqrt(padding_mask.sum(dim=self.reduce_dim).float()) 166 | # elif self.pool == 'decay': 167 | # xxxx 168 | else: 169 | pool = self.pool 170 | print(pool) 171 | raise ValueError(f"Invalid pool type: {pool}") 172 | 173 | return output 174 | 175 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fasttext 2 | matplotlib 3 | nltk 4 | numpy 5 | pandas 6 | scikit-learn 7 | sru 8 | pyyaml 9 | tensorboardX 10 | tensorflow 11 | torch 12 | tqdm 13 | typed-argument-parser 14 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | setup.py 3 | """ 4 | 5 | from setuptools import setup, find_packages 6 | from typing import Dict 7 | import os 8 | 9 | 10 | NAME = "rationale-alignment" 11 | AUTHOR = "ASAPP Inc." 12 | EMAIL = "liliyu@asapp.com" 13 | DESCRIPTION = ( 14 | "Pytorch based library for ACL2020 paper about rationalizing text matching." 15 | ) 16 | 17 | 18 | def readme(): 19 | with open("README.md", encoding="utf-8") as f: 20 | return f.read() 21 | 22 | 23 | def required(): 24 | with open("requirements.txt") as f: 25 | return f.read().splitlines() 26 | 27 | 28 | setup( 29 | name=NAME, 30 | version="0.0.1", 31 | description=DESCRIPTION, 32 | # Author information 33 | author=AUTHOR, 34 | author_email=EMAIL, 35 | license="MIT", 36 | # What is packaged here. 
37 | packages=find_packages(),
38 | install_requires=required(),
39 | python_requires=">=3.6.1",
40 | zip_safe=True,
41 | )
42 | 
-------------------------------------------------------------------------------- /similarity/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/asappresearch/rationale-alignment/8d2bf06ba4c121863833094d5d4896bf34a9a73e/similarity/__init__.py
-------------------------------------------------------------------------------- /similarity/compute/__init__.py: --------------------------------------------------------------------------------
1 | from similarity.compute.trainer import AlignmentTrainer
2 | 
3 | __all__ = ["AlignmentTrainer"]
4 | 
-------------------------------------------------------------------------------- /similarity/data/__init__.py: --------------------------------------------------------------------------------
1 | from similarity.data.load import load_data
2 | from similarity.data.sampler import Sampler
3 | 
4 | __all__ = ["load_data", "Sampler"]
5 | 
-------------------------------------------------------------------------------- /similarity/data/dataset.py: --------------------------------------------------------------------------------
1 | from typing import Dict, List, Set
2 | 
3 | import torch
4 | 
5 | 
6 | class Dataset:
7 | def __init__(self,
8 | ids: Set[str],
9 | id_to_document: Dict[str, List[torch.LongTensor]],
10 | id_mapping: Dict[str, Dict[str, Set[str]]],
11 | negative_ids: Set[str] = None):
12 | """
13 | Holds an alignment dataset (e.g. AskUbuntu or MultiNews).
14 | 
15 | :param ids: A set of ids from which to sample during training.
16 | Note: May not contain all ids since some ids should not be sampled.
17 | :param id_to_document: A dictionary mapping ids to a dictionary
18 | which maps "sentences" to the sentences in the document.
19 | :param id_mapping: A dictionary mapping ids to a dictionary which maps
20 | "similar" to similar ids and "dissimilar" to dissimilar ids.
21 | :param negative_ids: The set of ids which can be sampled as negatives.
22 | If None, any id can be sampled as a negative.
23 | """ 24 | self.id_set = ids 25 | self.id_list = sorted(self.id_set) 26 | self.id_to_document = id_to_document 27 | self.id_mapping = id_mapping 28 | self.negative_ids = negative_ids or self.id_set 29 | 30 | def __len__(self) -> int: 31 | return len(self.id_set) 32 | -------------------------------------------------------------------------------- /similarity/data/load.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | 5 | from similarity.data.loaders import ( 6 | AskUbuntuDataLoader, 7 | MultiNewsDataLoader, 8 | ) 9 | from similarity.data.sampler import Sampler 10 | from similarity.data.text import TextField 11 | from utils.parsing import Arguments 12 | 13 | 14 | def load_data( 15 | args: Arguments, device: torch.device 16 | ) -> Tuple[TextField, Sampler, Sampler, Sampler]: 17 | """Loads data and returns a TextField and train, dev, and test Samplers.""" 18 | # Default to sampling negatives 19 | resample_negatives = True 20 | 21 | # Get DataLoader 22 | if args.dataset in ["askubuntu", "superuser_askubuntu"]: 23 | data_loader = AskUbuntuDataLoader(args) 24 | assert (args.dev_path is None) == (args.test_path is None) 25 | resample_negatives = args.dev_path is None 26 | elif args.dataset == "summary": 27 | data_loader = SummaryDataLoader(args) 28 | elif args.dataset == "multinews": 29 | data_loader = MultiNewsDataLoader(args) 30 | elif args.dataset == "pubmed": 31 | data_loader = PubmedDataLoader(args) 32 | elif args.dataset == "pubmedsummary": 33 | data_loader = PubmedSummaryDataLoader(args) 34 | else: 35 | raise ValueError(f'Dataset "{args.dataset}" not supported') 36 | 37 | # Create Samplers 38 | train_sampler = Sampler( 39 | data=data_loader.train, 40 | text_field=data_loader.text_field, 41 | batch_size=args.batch_size, 42 | shuffle=True, 43 | num_positives=args.num_positives, 44 | num_negatives=args.num_negatives, 45 | resample_negatives=resample_negatives, 46 | device=device, 47 | ) 48 | 49 | dev_sampler = Sampler( 50 | data=data_loader.dev, 51 | text_field=data_loader.text_field, 52 | batch_size=args.batch_size, 53 | num_positives=args.num_eval_positives, 54 | num_negatives=args.num_eval_negatives, 55 | device=device, 56 | ) 57 | 58 | test_sampler = Sampler( 59 | data=data_loader.test, 60 | text_field=data_loader.text_field, 61 | batch_size=args.batch_size, 62 | num_positives=args.num_eval_positives, 63 | num_negatives=args.num_eval_negatives, 64 | device=device, 65 | ) 66 | 67 | return data_loader.text_field, train_sampler, dev_sampler, test_sampler 68 | -------------------------------------------------------------------------------- /similarity/data/loaders/__init__.py: -------------------------------------------------------------------------------- 1 | from similarity.data.loaders.askubuntu import AskUbuntuDataLoader 2 | from similarity.data.loaders.multinews import MultiNewsDataLoader 3 | 4 | 5 | __all__ = [ 6 | "AskUbuntuDataLoader", 7 | "MultiNewsDataLoader", 8 | ] 9 | -------------------------------------------------------------------------------- /similarity/data/loaders/loader.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | from similarity.data.dataset import Dataset 4 | from similarity.data.text import TextField 5 | 6 | 7 | class DataLoader: 8 | @property 9 | @abstractmethod 10 | def train(self) -> Dataset: 11 | """Returns the training data.""" 12 | pass 13 | 14 | @property 15 | @abstractmethod 16 | def 
dev(self) -> Dataset: 17 | """Returns the validation data.""" 18 | pass 19 | 20 | @property 21 | @abstractmethod 22 | def test(self) -> Dataset: 23 | """Returns the test data.""" 24 | pass 25 | 26 | @property 27 | @abstractmethod 28 | def text_field(self) -> TextField: 29 | """Returns the text field.""" 30 | pass 31 | 32 | def print_stats(self) -> None: 33 | """Prints statistics about the data.""" 34 | print() 35 | print(f"Total size = {len(self.train) + len(self.dev) + len(self.test):,}") 36 | print() 37 | print(f"Train size = {len(self.train):,}") 38 | print(f"Dev size = {len(self.dev):,}") 39 | print(f"Test size = {len(self.test):,}") 40 | print() 41 | # print(f'Vocabulary size = {len(self.text_field.vocabulary):,}') 42 | print() 43 | -------------------------------------------------------------------------------- /similarity/data/loaders/multinews.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | from copy import deepcopy 3 | from tqdm import tqdm 4 | from typing import Dict, List, Set 5 | 6 | import torch 7 | 8 | from similarity.data.dataset import Dataset 9 | from similarity.data.loaders.loader import DataLoader 10 | from similarity.data.text import TextField 11 | from similarity.data.utils import split_data, text_to_sentences 12 | from utils.parsing import MultiNewsArguments 13 | 14 | 15 | class MultiNewsDataLoader(DataLoader): 16 | def __init__(self, args: MultiNewsArguments): 17 | """Loads the multinews dataset.""" 18 | # Load data 19 | article_groups = self.load_data(args.news_path, small_data=args.small_data) 20 | 21 | # Create ID mapping 22 | id_mapping = {} 23 | for i, articles in enumerate(article_groups): 24 | similar_ids = {f"{i}_{j}" for j in range(len(articles))} 25 | 26 | for j, article in enumerate(articles): 27 | id = f"{i}_{j}" 28 | similar_ids.remove(id) 29 | id_mapping[id] = {"similar": deepcopy(similar_ids), "dissimilar": set()} 30 | similar_ids.add(id) 31 | 32 | # Create ID to document 33 | id_to_document: Dict[str, List[str]] = { 34 | f"{i}_{j}": text_to_sentences( 35 | text=article, 36 | sentence_tokenize=not args.no_sentence_tokenize, 37 | max_num_sentences=args.max_num_sentences, 38 | max_sentence_length=args.max_sentence_length, 39 | ) 40 | for i, articles in enumerate(tqdm(article_groups)) 41 | for j, article in enumerate(articles) 42 | } 43 | 44 | # Create text field 45 | self._text_field = TextField() 46 | self._text_field.build_vocab( 47 | sentence for document in id_to_document.values() for sentence in document 48 | ) 49 | 50 | print_stats_for_paper = False 51 | if print_stats_for_paper: 52 | print("\n\n==count of simiular pairs:") 53 | # print(cnt_pospair) 54 | print(sum([len(mapping["similar"]) for mapping in id_mapping.values()]) / 2) 55 | 56 | # Create ID to document 57 | from similarity.data.utils import tokenize_sentence 58 | 59 | cnt_sent_per_doc = [ 60 | len(tokenize_sentence(article)) 61 | for articles in article_groups 62 | for article in articles 63 | ] 64 | print("\n==Average sentence count:") 65 | print(sum(cnt_sent_per_doc) / len(cnt_sent_per_doc)) 66 | print("==Max sentence count:") 67 | print(max(cnt_sent_per_doc)) 68 | 69 | cnt_words_per_doc = [ 70 | len(article.split()) 71 | for articles in article_groups 72 | for article in articles 73 | ] 74 | print("==Average words count:") 75 | print(sum(cnt_words_per_doc) / len(cnt_words_per_doc)) 76 | 77 | print("==Max words count:") 78 | print(max(cnt_words_per_doc)) 79 | 80 | print("==count of total documents:") 81 | 
print(len(cnt_sent_per_doc)) 82 | print("\n\n") 83 | print(f"\n==Vocabulary size = {len(self.text_field.vocabulary):,}") 84 | import sys 85 | 86 | sys.exit() 87 | 88 | # Convert sentences to indices 89 | id_to_document: Dict[str, List[torch.LongTensor]] = { 90 | id: [self._text_field.process(sentence) for sentence in document] 91 | for id, document in tqdm(id_to_document.items()) 92 | } 93 | 94 | # Split data 95 | train_groups, dev_groups, test_groups = split_data( 96 | list(range(len(article_groups))) 97 | ) 98 | 99 | train_ids = { 100 | f"{i}_{j}" for i in train_groups for j in range(len(article_groups[i])) 101 | } 102 | dev_ids = { 103 | f"{i}_{j}" for i in dev_groups for j in range(len(article_groups[i])) 104 | } 105 | test_ids = { 106 | f"{i}_{j}" for i in test_groups for j in range(len(article_groups[i])) 107 | } 108 | 109 | # Define train, dev, test datasets 110 | self._train = Dataset( 111 | ids=train_ids, id_to_document=id_to_document, id_mapping=id_mapping 112 | ) 113 | self._dev = Dataset( 114 | ids=dev_ids, id_to_document=id_to_document, id_mapping=id_mapping 115 | ) 116 | self._test = Dataset( 117 | ids=test_ids, id_to_document=id_to_document, id_mapping=id_mapping 118 | ) 119 | 120 | self.print_stats() 121 | 122 | @staticmethod 123 | def load_data(path: str, small_data: bool = False) -> List[List[str]]: 124 | num_examples = 100 if small_data else float("inf") 125 | 126 | article_groups = [] 127 | with open(path) as f: 128 | for line in tqdm(f): 129 | articles = [ 130 | article.strip() 131 | for article in line.replace("NEWLINE_CHAR", "\n").split("|||||") 132 | ] 133 | articles = [article for article in articles if article != ""] 134 | article_groups.append(articles) 135 | 136 | if len(article_groups) >= num_examples: 137 | break 138 | 139 | # Try to remove junk by only keeping articles that appear once (i.e. 
not a common error message) 140 | article_counts = Counter( 141 | article for articles in article_groups for article in articles 142 | ) 143 | article_groups = [ 144 | [article for article in articles if article_counts[article] == 1] 145 | for articles in article_groups 146 | ] 147 | 148 | # Require at least two articles per group so that there are similar articles 149 | article_groups = [articles for articles in article_groups if len(articles) >= 2] 150 | 151 | return article_groups 152 | 153 | @property 154 | def train(self) -> Dataset: 155 | return self._train 156 | 157 | @property 158 | def dev(self) -> Dataset: 159 | return self._dev 160 | 161 | @property 162 | def test(self) -> Dataset: 163 | return self._test 164 | 165 | @property 166 | def text_field(self) -> TextField: 167 | return self._text_field 168 | -------------------------------------------------------------------------------- /similarity/data/text.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import csv 3 | from itertools import chain 4 | import os 5 | from typing import Callable, Dict, Iterable, Optional 6 | 7 | import numpy as np 8 | import pandas as pd 9 | from sklearn.feature_extraction.text import TfidfVectorizer 10 | import torch 11 | from tqdm import tqdm 12 | 13 | 14 | class TextField: 15 | def __init__(self, 16 | skip_oov: bool = False, 17 | lower: bool = False, 18 | tokenizer: Callable[[str], Iterable[str]] = lambda text: text.split(), 19 | pad_token: Optional[str] = '', 20 | unk_token: Optional[str] = '', 21 | sos_token: Optional[str] = None, 22 | eos_token: Optional[str] = None, 23 | vocabulary: Optional[Dict[str, int]] = None): 24 | if vocabulary: 25 | self.vocabulary = OrderedDict((tok, i) for i, tok in enumerate(vocabulary)) 26 | else: 27 | # Add specials 28 | self.vocabulary = OrderedDict() 29 | 30 | specials = [pad_token, unk_token, sos_token, eos_token] 31 | 32 | for token in filter(lambda x: x is not None, specials): 33 | self.vocabulary[token] = len(self.vocabulary) 34 | 35 | self.pad = pad_token 36 | self.unk = unk_token 37 | self.sos = sos_token 38 | self.eos = eos_token 39 | 40 | self.lower = lower 41 | self.tokenizer = tokenizer 42 | 43 | self._embeddings = None 44 | 45 | self._build_reverse_vocab() 46 | self.skip_oov = skip_oov 47 | self.weights = {} 48 | self.avg_weight = 1.0 49 | 50 | def build_idf_weights(self, *data: Iterable[str]) -> torch.FloatTensor: 51 | """Build IDF weights. 52 | 53 | Parameters 54 | ---------- 55 | data : Iterable[str] 56 | List of input strings. 57 | 58 | """ 59 | data = chain.from_iterable(data) 60 | if self.lower: 61 | data = [text.lower() for text in data] 62 | 63 | vectorizer = TfidfVectorizer(min_df=1, ngram_range=(1, 1), binary=False) 64 | vectorizer.fit(data) 65 | 66 | self.weights = { 67 | word: idf 68 | for word, idf in zip(vectorizer.get_feature_names(), vectorizer.idf_) 69 | if word in self.vocabulary 70 | } 71 | self.avg_weight = np.mean(list(self.weights.values())) 72 | 73 | def _init_vocabulary(self) -> None: 74 | """Initializes vocabulary with special tokens.""" 75 | # Add specials 76 | self.vocabulary = OrderedDict() 77 | 78 | specials = [self.pad, self.unk, self.sos, self.eos] 79 | 80 | for token in filter(lambda x: x is not None, specials): 81 | self.vocabulary[token] = len(self.vocabulary) # type: ignore 82 | 83 | def load_vocab(self, path: str) -> Dict[str, int]: 84 | """Loads a vocabulary from a .txt file. 
85 | 86 | Returns 87 | ------- 88 | Dict[str, int] 89 | A vocabulary dictionary mapping from string to int. 90 | 91 | """ 92 | self._init_vocabulary() 93 | 94 | with open(path) as f: 95 | words = [word for line in f for word in line.strip().split()] 96 | 97 | for word in words: 98 | self.vocabulary[word] = len(self.vocabulary) 99 | 100 | return self.vocabulary 101 | 102 | def build_vocab(self, data: Iterable[str], *args) -> Dict[str, int]: 103 | """Build the vocabulary. 104 | Parameters 105 | ---------- 106 | data : Iterable[str] 107 | List of input strings. 108 | """ 109 | datasets = [data] + list(args) 110 | for dataset in datasets: 111 | for example in tqdm(dataset): 112 | # Lowercase if requested 113 | example = example.lower() if self.lower else example 114 | # Tokenize and add to vocabulary 115 | for token in self.tokenizer(example): 116 | self.vocabulary.setdefault(token, len(self.vocabulary)) 117 | 118 | self._build_reverse_vocab() 119 | 120 | return self.vocabulary 121 | 122 | def load_embeddings(self, path: str) -> torch.FloatTensor: 123 | """Load pretrained word embeddings. 124 | 125 | Parameters 126 | ---------- 127 | path : str 128 | The path to the pretrained embeddings 129 | Returns 130 | ------- 131 | torch.FloatTensor 132 | The matrix of pretrained word embeddings 133 | 134 | """ 135 | ext = os.path.splitext(path)[-1] 136 | 137 | if ext == '.bin': # fasttext 138 | try: 139 | import fasttext 140 | except Exception: 141 | try: 142 | import fastText as fasttext 143 | except Exception: 144 | raise ValueError("fasttext not installed.") 145 | model = fasttext.load_model(path) 146 | vectors = [model.get_word_vector(token) * self.weights.get(token, self.avg_weight) for token in tqdm(self.vocabulary)] 147 | else: 148 | # Load any .txt or word2vec kind of format 149 | model = dict() 150 | data = pd.read_csv(path, sep=" ", index_col=0, header=None, quoting=csv.QUOTE_NONE) 151 | embedding_size = len(data.columns) 152 | for word, vector in data.iterrows(): 153 | if word in self.vocabulary: 154 | model[word] = np.array(vector.values) * self.weights.get(word, self.avg_weight) 155 | 156 | # Reorder according to self._vocab 157 | vectors = [model.get(token, np.zeros(embedding_size)) for token in self.vocabulary] 158 | 159 | self.embeddings = torch.FloatTensor(np.array(vectors)) 160 | return self.embeddings 161 | 162 | def process(self, example: str) -> torch.LongTensor: # type: ignore 163 | """Process an example, and create a Tensor. 
164 | Parameters 165 | ---------- 166 | example: str 167 | The example to process, as a single string 168 | Returns 169 | ------- 170 | torch.LongTensor 171 | The processed example, tokenized and numericalized 172 | """ 173 | # Lowercase and tokenize 174 | example = example.lower() if self.lower else example 175 | tokens = self.tokenizer(example) 176 | 177 | # Add extra tokens 178 | if self.sos is not None: 179 | tokens = [self.sos] + list(tokens) 180 | if self.eos is not None: 181 | tokens = list(tokens) + [self.eos] 182 | 183 | # Numericalize 184 | numericals = [] 185 | for token in tokens: 186 | if token not in self.vocabulary: 187 | if self.unk is None or self.unk not in self.vocabulary: 188 | raise ValueError("Encounterd out-of-vocabulary token \ 189 | but the unk_token is either missing \ 190 | or not defined in the vocabulary.") 191 | else: 192 | token = self.unk 193 | 194 | numerical = self.vocabulary[token] # type: ignore 195 | numericals.append(numerical) 196 | 197 | processed = torch.LongTensor(numericals) 198 | return processed 199 | 200 | def _build_reverse_vocab(self) -> None: 201 | """Builds reverse vocabulary.""" 202 | self._reverse_vocab = {index: token for token, index in self.vocabulary.items()} 203 | 204 | def deprocess(self, indices: torch.LongTensor) -> str: 205 | """Converts indices to string.""" 206 | pad_index = self.vocabulary[self.pad] 207 | return ' '.join(self._reverse_vocab[index.item()] for index in indices if index != pad_index) 208 | 209 | def pad_index(self) -> int: 210 | return self.vocabulary[self.pad] 211 | -------------------------------------------------------------------------------- /similarity/data/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Any, List, Optional, Tuple 3 | 4 | 5 | 6 | sentence_tokenizer = None 7 | 8 | 9 | def split_data(data: List[Any], 10 | sizes: Tuple[float, float, float] = (0.8, 0.1, 0.1), 11 | seed: int = 0) -> Tuple[List[Any], List[Any], List[Any]]: 12 | """ 13 | Randomly splits data into train, val, and test sets according to the provided sizes. 14 | 15 | :param data: The data to split into train, val, and test. 16 | :param sizes: The sizes of the train, val, and test sets (as a proportion of total size). 17 | :param seed: Random seed. 18 | :return: Train, val, and test sets. 19 | """ 20 | # Checks 21 | assert len(sizes) == 3 22 | assert all(0 <= size <= 1 for size in sizes) 23 | assert sum(sizes) == 1 24 | 25 | # Shuffle 26 | random.seed(seed) 27 | random.shuffle(data) 28 | 29 | # Determine split sizes 30 | train_size = int(sizes[0] * len(data)) 31 | train_val_size = int((sizes[0] + sizes[1]) * len(data)) 32 | 33 | # Split 34 | train = data[:train_size] 35 | val = data[train_size:train_val_size] 36 | test = data[train_val_size:] 37 | 38 | return train, val, test 39 | 40 | 41 | def tokenize_sentence(text: str) -> List[str]: 42 | """ 43 | Tokenizes text into sentences. 44 | 45 | :param text: A string. 46 | :return: A list of sentences. 47 | """ 48 | global sentence_tokenizer 49 | 50 | if sentence_tokenizer is None: 51 | import nltk 52 | sentence_tokenizer = nltk.load('tokenizers/punkt/english.pickle') 53 | 54 | return sentence_tokenizer.tokenize(text) 55 | 56 | 57 | def text_to_sentences(text: str, 58 | tokenizer: str='sentence', 59 | sentence_tokenize: bool = True, 60 | max_num_sentences: Optional[int] = None, 61 | max_sentence_length: Optional[int] = None) -> List[str]: 62 | """ 63 | Splits text into sentences (if desired). 
64 | 65 | Also enforces a maximum sentence length 66 | and maximum number of sentences. 67 | 68 | :param text: The text to split. 69 | :param sentence_tokenize: Whether to split into sentences. 70 | :param max_num_sentences: Maximum number of sentences. 71 | :param max_sentence_length: Maximum length of a sentence (in tokens). 72 | :return: The text split into sentences (if desired) 73 | or as just a single sentence. 74 | """ 75 | # Sentence tokenize 76 | if sentence_tokenize: 77 | sentences = tokenize_sentence(text)[:max_num_sentences] 78 | else: 79 | sentences = [text] 80 | 81 | # Enforce maximum sentence length 82 | sentences = [' '.join(sentence.split()[:max_sentence_length]) for sentence in sentences] 83 | 84 | return sentences 85 | 86 | def pubmed_tokenizer(text: str, 87 | tokenizer: str='sentence', 88 | # predictor: Predictor=None, 89 | max_num_sentences: Optional[int] = None, 90 | max_sentence_length: Optional[int] = None) -> List[str]: 91 | pass 92 | 93 | ''' 94 | def pubmed_tokenizer(text: str, 95 | tokenizer: str='sentence', 96 | predictor: Predictor=None, 97 | max_num_sentences: Optional[int] = None, 98 | max_sentence_length: Optional[int] = None) -> List[str]: 99 | """ 100 | # from allennlp.predictors import Predictor 101 | Splits text into sentences (if desired). 102 | 103 | Also enforces a maximum sentence length 104 | and maximum number of sentences. 105 | 106 | :param text: The text to split. 107 | :param sentence_tokenize: Whether to split into sentences. 108 | :param max_num_sentences: Maximum number of sentences. 109 | :param max_sentence_length: Maximum length of a sentence (in tokens). 110 | :return: The text split into sentences (if desired) 111 | or as just a single sentence. 112 | """ 113 | # Sentence tokenize 114 | if tokenizer =='sentence': 115 | sentences = tokenize_sentence(text)[:max_num_sentences] 116 | elif tokenizer == 'word': 117 | sentences = [text] 118 | elif tokenizer == 'phrase': 119 | print('tokenizing phrase') 120 | full_sentences = tokenize_sentence(text)[:max_num_sentences] 121 | sentences = [] 122 | for sent in full_sentences: 123 | sentences.extend(phrase_tokenizer(sent, predictor, phrase_len=5)) 124 | sentences = sentences[:max_num_sentences*3] 125 | else: 126 | print('unknow tokenizer') 127 | # Enforce maximum sentence length 128 | sentences = [' '.join(sentence.split()[:max_sentence_length]) for sentence in sentences] 129 | 130 | return sentences 131 | ''' 132 | 133 | 134 | def process_pubmed_sentences(text: List[List[str]], 135 | sentence_tokenize: bool = True, 136 | max_num_sentences: Optional[int] = None, 137 | max_sentence_length: Optional[int] = None) -> List[str]: 138 | """ 139 | Splits text into sentences (if desired). 140 | Also enforces a maximum sentence length 141 | and maximum number of sentences. 142 | :param text: The text to split. 143 | :param sentence_tokenize: Whether to split into sentences. 144 | :param max_num_sentences: Maximum number of sentences. 145 | :param max_sentence_length: Maximum length of a sentence (in tokens). 146 | :return: The text split into sentences (if desired) 147 | or as just a single sentence. 
148 | """ 149 | # Sentence tokenize 150 | if sentence_tokenize: 151 | sentences = text[:max_num_sentences] 152 | else: 153 | sentences = [word for sent in text for word in sent] 154 | 155 | # Enforce maximum sentence length 156 | sentences = [' '.join(sentence[:max_sentence_length]) for sentence in sentences] 157 | 158 | return sentences -------------------------------------------------------------------------------- /similarity/metric/__init__.py: -------------------------------------------------------------------------------- 1 | from similarity.metric.abstract import Metric 2 | from similarity.metric.load import load_loss_and_metrics 3 | 4 | __all__ = ["load_loss_and_metrics", "Metric"] 5 | -------------------------------------------------------------------------------- /similarity/metric/abstract.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Dict, List, Optional, Tuple, Union 3 | 4 | import torch 5 | 6 | from sinkhorn import compute_alignment_cost, compute_entropy 7 | 8 | 9 | class Metric: 10 | @abstractmethod 11 | def compute(self, preds, targets, *argv) -> torch.float: 12 | pass 13 | 14 | def __call__(self, preds, targets, *argv) -> torch.float: 15 | return self.compute(preds, targets, *argv) 16 | 17 | def __str__(self) -> str: 18 | return self.__class__.__name__ 19 | 20 | 21 | class AlignmentMetric(Metric): 22 | """Computes the metric for saying one document is aligned with another.""" 23 | 24 | @staticmethod 25 | def _compute_entropy(preds: List[Tuple[torch.FloatTensor, torch.FloatTensor]]) -> torch.FloatTensor: 26 | """Computes the entropy term (epislon * H(P)) of each (cost, alignment) tuple in preds.""" 27 | return torch.stack([compute_entropy(alignment) for cost, alignment in preds], dim=0) 28 | 29 | @staticmethod 30 | def _compute_cost(preds: List[Tuple[torch.FloatTensor, torch.FloatTensor]]) -> torch.FloatTensor: 31 | """Computes the alignment cost of each (cost, alignment) tuple in preds.""" 32 | return torch.stack([compute_alignment_cost(C=cost, P=alignment) for cost, alignment in preds], dim=0) 33 | 34 | @staticmethod 35 | def _compute_similarities(preds: List[Tuple[torch.FloatTensor, torch.FloatTensor]]) -> torch.FloatTensor: 36 | """Computes the alignment similarities (i.e. -cost) of each (cost, alignment) tuple in preds.""" 37 | return -AlignmentMetric._compute_cost(preds) 38 | 39 | @abstractmethod 40 | def compute(self, 41 | preds: List[Tuple[torch.FloatTensor, torch.FloatTensor]], 42 | targets: Union[List[torch.LongTensor], List[Dict[str, torch.LongTensor]]], 43 | step: Optional[int]) -> torch.float: 44 | pass 45 | 46 | 47 | class AlignmentAverageMetric(AlignmentMetric): 48 | """Computes a metric and averages it across documents.""" 49 | 50 | def __init__(self, similar: Optional[bool] = None): 51 | self.similar = similar # Whether to only include similar or only dissimilar examples 52 | 53 | @abstractmethod 54 | def _compute_one(self, 55 | cost: torch.FloatTensor, 56 | alignment: torch.FloatTensor, 57 | target: int) -> torch.float: 58 | """ 59 | Computes the metric and count of aligning two documents. 60 | 61 | :param cost: The cost of aligning sentence i with sentence j (matrix is n x m). 62 | :param alignment: The probability of aligning sentence i with sentence j (matrix is n x m). 63 | :param target: Whether the documents are similar or not. 64 | :return: The value. 
65 | """ 66 | pass 67 | 68 | def _compute_count(self, 69 | cost: torch.FloatTensor, 70 | alignment: torch.FloatTensor, 71 | target: int) -> int: 72 | """ 73 | Computes the count of items associated with the documents for the purpose of averaging. 74 | 75 | :param cost: The cost of aligning sentence i with sentence j (matrix is n x m). 76 | :param alignment: The probability of aligning sentence i with sentence j (matrix is n x m). 77 | :param target: Whether the documents are similar or not. 78 | :return: The count (typically either # of sentences or 1). 79 | """ 80 | if self.similar is None: 81 | return 1 82 | 83 | return target == self.similar 84 | 85 | def compute(self, 86 | preds: List[Tuple[torch.FloatTensor, torch.FloatTensor]], 87 | targets: List[Dict[str, torch.LongTensor]]) -> torch.float: 88 | """ 89 | Computes metric across a list of instances of two sets of objects. 90 | 91 | :param preds: A list of (cost, alignment) tuples (each is n x m). 92 | :param targets: A list of LongTensors indicating the correct alignment. 93 | :return: The metric of the alignments. 94 | """ 95 | # Initialize 96 | metric, count = 0, 0 97 | 98 | # Extract targets 99 | targets = [t.item() for target in targets for t in target['targets']] 100 | 101 | # Check lengths 102 | assert len(preds) == len(targets) 103 | 104 | # Loop over alignments and add metric and count 105 | for (cost, alignment), target in zip(preds, targets): 106 | new_count = self._compute_count(cost, alignment, target) 107 | 108 | if new_count == 0: 109 | continue 110 | 111 | count += new_count 112 | metric += self._compute_one(cost, alignment, target) 113 | 114 | # Average metric 115 | metric = metric / count if count != 0 else 0 116 | 117 | return metric 118 | 119 | def __str__(self) -> str: 120 | super_str = super(AlignmentAverageMetric, self).__str__() 121 | 122 | if self.similar is None: 123 | return super_str 124 | 125 | return ('Similar' if self.similar else 'Dissimilar') + super_str 126 | -------------------------------------------------------------------------------- /similarity/metric/dev/__init__.py: -------------------------------------------------------------------------------- 1 | from similarity.metric.dev.similarity import AUC, MAP, MRR, Precision 2 | 3 | 4 | __all__ = ["AUC", "MAP", "MRR", "Precision"] 5 | -------------------------------------------------------------------------------- /similarity/metric/dev/auc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sklearn.metrics 3 | import torch 4 | 5 | from similarity.metric.abstract import Metric 6 | 7 | 8 | class AUC(Metric): 9 | def __init__(self, max_fpr: float = 1.0): 10 | """Initialize the AUC metric. 11 | 12 | Parameters 13 | ---------- 14 | max_fpr : float, optional 15 | Maximum false positive rate to compute the area under 16 | 17 | """ 18 | self.max_fpr = max_fpr 19 | 20 | def compute(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 21 | """Compute AUC at the given max false positive rate. 
22 | 23 | Parameters 24 | ---------- 25 | pred : torch.Tensor 26 | The model predictions 27 | target : torch.Tensor 28 | The binary targets 29 | 30 | Returns 31 | ------- 32 | torch.Tensor 33 | The computed AUC 34 | 35 | """ 36 | scores = np.array(pred) 37 | targets = np.array(target) 38 | 39 | # Case when number of elements added are 0 40 | if ( 41 | not scores.size 42 | or not targets.size 43 | or (~np.isfinite(scores)).sum() 44 | or (~np.isfinite(targets)).sum() 45 | ): 46 | return torch.tensor(0.5) 47 | 48 | fpr, tpr, _ = sklearn.metrics.roc_curve(targets, scores, sample_weight=None) 49 | 50 | # Compute the area under the curve using trapezoidal rule 51 | max_index = np.searchsorted(fpr, [self.max_fpr], side="right").item() 52 | 53 | # Ensure we integrate up to max_fpr 54 | fpr, tpr = fpr.tolist(), tpr.tolist() 55 | fpr, tpr = fpr[:max_index], tpr[:max_index] 56 | fpr.append(self.max_fpr) 57 | tpr.append(max(tpr)) 58 | 59 | area = np.trapz(tpr, fpr) 60 | 61 | return torch.tensor(area / self.max_fpr).float() 62 | -------------------------------------------------------------------------------- /similarity/metric/dev/similarity.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Dict, List, Tuple 3 | 4 | import torch 5 | 6 | from similarity.metric.abstract import AlignmentMetric 7 | from similarity.metric.dev.auc import AUC as RawAUC 8 | 9 | 10 | class AUC(RawAUC, AlignmentMetric): 11 | """Computes AUC of aligning documents.""" 12 | 13 | def compute(self, 14 | preds: List[Tuple[torch.FloatTensor, torch.FloatTensor]], 15 | targets: List[Dict[str, torch.LongTensor]]) -> torch.float: 16 | similarities = self._compute_similarities(preds) 17 | 18 | all_true, all_pred = [], [] 19 | for target in targets: 20 | # Get true and pred for example 21 | true, pred = target['targets'], similarities[target['scope']] 22 | 23 | all_true += true.numpy().tolist() 24 | all_pred += pred.numpy().tolist() 25 | 26 | auc = super(AUC, self).compute(all_pred, all_true) 27 | 28 | return auc 29 | 30 | def __str__(self) -> str: 31 | if self.max_fpr == 1.0: 32 | return self.__class__.__name__ 33 | return f'{self.__class__.__name__}_{self.max_fpr}' 34 | 35 | 36 | class SimilarityMetric(AlignmentMetric): 37 | """Computes the mean of document similarity metrics.""" 38 | 39 | @abstractmethod 40 | def _compute_one(self, 41 | true: torch.FloatTensor, 42 | pred: torch.FloatTensor) -> torch.float: 43 | """ 44 | Computes the metric for one example. 45 | 46 | :param true: The true binary values (sorted in order of pred score). 47 | :param pred: The predicted scores (sorted in order of pred score). 48 | :return: The metric computed on true and pred. 49 | """ 50 | pass 51 | 52 | def compute(self, 53 | preds: List[Tuple[torch.FloatTensor, torch.FloatTensor]], 54 | targets: List[Dict[str, torch.LongTensor]]) -> torch.float: 55 | """ 56 | Computes the mean metric. 57 | 58 | :param preds: A list of (cost, alignment) tuples. 59 | :param targets: A list of dictionaries mapping to the indices of the targets and scope. 60 | :return: The mean metric. 
61 | """ 62 | similarities = self._compute_similarities(preds) 63 | 64 | metrics = [] 65 | for target in targets: 66 | # Get true and pred for example 67 | true, pred = target['targets'], similarities[target['scope']] 68 | 69 | # Sort based on pred 70 | argsort = torch.argsort(pred, descending=True) 71 | true, pred = true[argsort], pred[argsort] 72 | 73 | # Convert true to float 74 | true = true.float() 75 | 76 | # Compute metric 77 | metric = self._compute_one(true, pred) 78 | metrics.append(metric) 79 | 80 | mean_metric = torch.mean(torch.FloatTensor(metrics)) 81 | 82 | return mean_metric 83 | 84 | 85 | class MAP(SimilarityMetric): 86 | """Computes mean average precision.""" 87 | 88 | def _compute_one(self, 89 | true: torch.FloatTensor, 90 | pred: torch.FloatTensor) -> torch.float: 91 | cumsum = torch.cumsum(true, dim=0) 92 | rank = torch.arange(len(true), dtype=torch.float) + 1 93 | precisions = true * cumsum / rank 94 | average_precision = torch.sum(precisions) / torch.sum(true) 95 | 96 | return average_precision 97 | 98 | 99 | class MRR(SimilarityMetric): 100 | """Computes mean reciprocal rank.""" 101 | 102 | def _compute_one(self, 103 | true: torch.FloatTensor, 104 | pred: torch.FloatTensor) -> torch.float: 105 | # The rank is the index of the first nonzero element + 1 106 | rank = torch.nonzero(true)[0, 0] + 1 107 | reciprocal_rank = 1 / rank.float() 108 | 109 | return reciprocal_rank 110 | 111 | 112 | class Precision(SimilarityMetric): 113 | """Computes precision at n.""" 114 | 115 | def __init__(self, n: int): 116 | super(Precision, self).__init__() 117 | self.n = n 118 | 119 | def _compute_one(self, 120 | true: torch.FloatTensor, 121 | pred: torch.FloatTensor) -> torch.float: 122 | return torch.mean(true[:self.n]) 123 | 124 | def __str__(self) -> str: 125 | return f'{self.__class__.__name__}_at_{self.n}' 126 | -------------------------------------------------------------------------------- /similarity/metric/load.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | from similarity.metric import Metric 4 | from similarity.metric.dev import * 5 | from similarity.metric.loss import * 6 | from similarity.metric.train import * 7 | from utils.parsing import Arguments 8 | 9 | 10 | def load_loss_and_metrics( 11 | args: Arguments, 12 | ) -> Tuple[Metric, Metric, List[Metric], List[Metric]]: 13 | """ 14 | Defines the loss and metric functions that will be used during AskUbuntu training. 15 | 16 | :param args: Arguments. 
17 | :return: A tuple consisting of:
18 | 1) Training loss function
19 | 2) Dev metric function
20 | 3) A list of additional training metrics
21 | 4) A list of additional validation metrics
22 | """
23 | # Loss
24 | loss_fn = HingeLoss(
25 | margin=args.margin, pooling=args.hinge_pooling, alpha=args.hinge_alpha
26 | )
27 | 
28 | # Metrics
29 | metric_fn = AUC()
30 | extra_training_metrics = []
31 | 
32 | if args.alignment != "average":
33 | extra_training_metrics += [
34 | CostRange(),
35 | CostMin(),
36 | CostMax(),
37 | CostMean(),
38 | CostMedian(),
39 | AlignmentCount(),
40 | AlignmentCount(normalize="min"),
41 | AlignmentCount(normalize="full"),
42 | AlignmentSum(),
43 | AlignmentRowMarginalError(),
44 | AlignmentColumnMarginalError(),
45 | AlignmentCost(similar=True),
46 | AlignmentCost(similar=False),
47 | AlignmentEntropy(similar=True),
48 | AlignmentEntropy(similar=False),
49 | ]
50 | 
51 | if args.cost_fn in ["dot_product", "scaled_dot_product", "cosine_similarity"]:
52 | extra_training_metrics += [
53 | CostSign(positive=True, similar=True),
54 | CostSign(positive=True, similar=False),
55 | CostSign(positive=False, similar=True),
56 | CostSign(positive=False, similar=False),
57 | ]
58 | 
59 | extra_validation_metrics = [
60 | AUC(max_fpr=0.1),
61 | AUC(max_fpr=0.05),
62 | MAP(),
63 | MRR(),
64 | Precision(n=1),
65 | Precision(n=5),
66 | ]
67 | 
68 | if args.alignment != "average":
69 | extra_validation_metrics += [
70 | AlignmentCount(),
71 | AlignmentCount(normalize="min"),
72 | AlignmentCount(normalize="full"),
73 | ]
74 | 
75 | return loss_fn, metric_fn, extra_training_metrics, extra_validation_metrics
76 | 
-------------------------------------------------------------------------------- /similarity/metric/loss/__init__.py: --------------------------------------------------------------------------------
1 | from similarity.metric.loss.hinge import HingeLoss
2 | 
3 | __all__ = ["HingeLoss"]
4 | 
-------------------------------------------------------------------------------- /similarity/metric/loss/hinge.py: --------------------------------------------------------------------------------
1 | from typing import Dict, List, Tuple
2 | 
3 | import torch
4 | 
5 | from similarity.metric.abstract import AlignmentMetric
6 | from utils.utils import prod
7 | 
8 | 
9 | class HingeLoss(AlignmentMetric):
10 | """Computes the hinge loss between aligned and un-aligned document pairs (for AskUbuntu).
11 | 
12 | For each document, the loss is sum_ij max(0, negative_similarity_i - positive_similarity_j + margin),
13 | i.e.
sum over all positive/negative pairs 14 | """ 15 | 16 | def __init__(self, margin: float, pooling: str = "max", alpha: float = 0.5): 17 | super(HingeLoss, self).__init__() 18 | self.margin = margin 19 | self.pooling = pooling 20 | self.alpha = alpha 21 | 22 | def compute( 23 | self, 24 | preds: List[Tuple[torch.FloatTensor, torch.FloatTensor]], 25 | targets: List[Dict[str, torch.LongTensor]], 26 | step: int = 4, 27 | ) -> torch.float: 28 | similarities = self._compute_similarities(preds) 29 | 30 | loss = count = 0 31 | for target in targets: 32 | positive_similarities, negative_similarities = ( 33 | similarities[target["positives"]], 34 | similarities[target["negatives"]], 35 | ) 36 | diff_similarities = negative_similarities.unsqueeze( 37 | dim=1 38 | ) - positive_similarities.unsqueeze( 39 | dim=0 40 | ) # num_negatives x num_positives 41 | 42 | if self.pooling == "max": 43 | diff_similarities = diff_similarities.max(dim=0)[ 44 | 0 45 | ] # num_positives (max across negatives) 46 | elif self.pooling == "smoothmax": 47 | alpha = self.alpha / (step + 1) if step < 5 else 0 48 | diff_similarities = (1 - alpha) * diff_similarities.max(dim=0)[ 49 | 0 50 | ] + alpha * diff_similarities.mean( 51 | dim=0 52 | ) # num_positives (max across negatives) 53 | elif self.pooling == "average": 54 | diff_similarities = diff_similarities.mean( 55 | dim=0 56 | ) # num_positives (mean across negatives) 57 | elif self.pooling == "none": 58 | pass 59 | else: 60 | raise ValueError(f'Pooling type "{self.pooling}" not supported') 61 | 62 | loss += torch.sum(torch.clamp(diff_similarities + self.margin, min=0)) 63 | count += int(prod(diff_similarities.shape)) 64 | 65 | loss = loss / count 66 | 67 | return loss 68 | -------------------------------------------------------------------------------- /similarity/metric/train/__init__.py: -------------------------------------------------------------------------------- 1 | from similarity.metric.train.alignment import ( 2 | AlignmentSum, 3 | AlignmentCount, 4 | AlignmentRowMarginalError, 5 | AlignmentColumnMarginalError, 6 | AlignmentCost, 7 | AlignmentEntropy, 8 | AlignmentEpsilonEntropy, 9 | ) 10 | from similarity.metric.train.cost import ( 11 | CostRange, 12 | CostMin, 13 | CostMax, 14 | CostMean, 15 | CostMedian, 16 | CostSign, 17 | ) 18 | 19 | __all__ = [ 20 | "AlignmentSum", 21 | "AlignmentCount", 22 | "AlignmentRowMarginalError", 23 | "AlignmentColumnMarginalError", 24 | "AlignmentCost", 25 | "AlignmentEntropy", 26 | "AlignmentEpsilonEntropy", 27 | "CostRange", 28 | "CostMin", 29 | "CostMax", 30 | "CostMean", 31 | "CostMedian", 32 | "CostSign", 33 | ] 34 | -------------------------------------------------------------------------------- /similarity/metric/train/alignment.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from sinkhorn import compute_alignment_cost, compute_entropy 4 | import torch 5 | 6 | from similarity.metric.abstract import AlignmentAverageMetric 7 | 8 | 9 | class AlignmentSum(AlignmentAverageMetric): 10 | def _compute_one( 11 | self, cost: torch.FloatTensor, alignment: torch.FloatTensor, target: int 12 | ) -> torch.float: 13 | return alignment.sum() 14 | 15 | 16 | class AlignmentCount(AlignmentAverageMetric): 17 | """Returns the number of non-zero entries (possibly normalized).""" 18 | 19 | def __init__( 20 | self, 21 | normalize: str = "none", # 'none', 'min' to norm by min(n, m), 'full' to norm by (nm) 22 | threshold_scaling: float = 1.0, # how much to scale the 1 / (n * 
m) threshold 23 | similar: Optional[bool] = None, 24 | ): 25 | assert normalize in ["none", "min", "full"] 26 | 27 | super(AlignmentCount, self).__init__(similar=similar) 28 | self.normalize = normalize 29 | 30 | def _compute_one( 31 | self, cost: torch.FloatTensor, alignment: torch.FloatTensor, target: int 32 | ) -> torch.float: 33 | n, m = alignment.shape[-2:] 34 | count = torch.sum(alignment > 1e-10).float() 35 | 36 | if self.normalize == "full": 37 | return count / (n * m) 38 | elif self.normalize == "min": 39 | return count / min(n, m) 40 | 41 | return count 42 | 43 | def __str__(self) -> str: 44 | string = super(AlignmentCount, self).__str__() 45 | 46 | if self.normalize != "none": 47 | string += "_normalized_by_" + ( 48 | "min_nm" if self.normalize == "min" else "nm" 49 | ) 50 | 51 | return string 52 | 53 | 54 | class AlignmentMarginalError(AlignmentAverageMetric): 55 | """Returns the average absolute error of either the row or column marginal, assuming uniform marginal.""" 56 | 57 | def __init__(self, side: int, similar: Optional[bool] = None): 58 | super(AlignmentMarginalError, self).__init__(similar=similar) 59 | assert side in {0, 1} 60 | self.side = side 61 | 62 | def _compute_one( 63 | self, cost: torch.FloatTensor, alignment: torch.FloatTensor, target: int 64 | ) -> torch.float: 65 | device = alignment.device 66 | marginal_dim = -2 if self.side == 0 else -1 67 | sum_dim = -1 if self.side == 0 else -2 68 | marginal = torch.ones( 69 | alignment.size(marginal_dim), device=device 70 | ) / alignment.size(marginal_dim) 71 | marginal_hat = alignment.sum(sum_dim) 72 | error = torch.abs(marginal - marginal_hat).mean() 73 | 74 | return error 75 | 76 | 77 | class AlignmentRowMarginalError(AlignmentMarginalError): 78 | def __init__(self, similar: Optional[bool] = None): 79 | super(AlignmentRowMarginalError, self).__init__(side=0, similar=similar) 80 | 81 | 82 | class AlignmentColumnMarginalError(AlignmentMarginalError): 83 | def __init__(self, similar: Optional[bool] = None): 84 | super(AlignmentColumnMarginalError, self).__init__(side=1, similar=similar) 85 | 86 | 87 | class AlignmentCost(AlignmentAverageMetric): 88 | """Computes the cost of aligning two objects.""" 89 | 90 | def _compute_one( 91 | self, cost: torch.FloatTensor, alignment: torch.FloatTensor, target: int 92 | ) -> torch.float: 93 | return compute_alignment_cost(C=cost, P=alignment) 94 | 95 | 96 | class AlignmentEntropy(AlignmentAverageMetric): 97 | """Computes the entropy of aligning two objects.""" 98 | 99 | def _compute_one( 100 | self, cost: torch.FloatTensor, alignment: torch.FloatTensor, target: int 101 | ) -> torch.float: 102 | return compute_entropy(P=alignment) 103 | 104 | 105 | class AlignmentEpsilonEntropy(AlignmentAverageMetric): 106 | """Computes epsilon times the entropy of aligning two objects.""" 107 | 108 | def __init__(self, epsilon: float, similar: Optional[bool] = None): 109 | super(AlignmentEpsilonEntropy, self).__init__(similar=similar) 110 | self.epsilon = epsilon 111 | 112 | def _compute_one( 113 | self, cost: torch.FloatTensor, alignment: torch.FloatTensor, target: int 114 | ) -> torch.float: 115 | return self.epsilon * compute_entropy(P=alignment) 116 | -------------------------------------------------------------------------------- /similarity/metric/train/cost.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from similarity.metric.abstract import AlignmentAverageMetric 4 | 5 | 6 | class CostRange(AlignmentAverageMetric): 7 | def 
_compute_one(self, 8 | cost: torch.FloatTensor, 9 | alignment: torch.FloatTensor, 10 | target: int) -> torch.float: 11 | return cost.max() - cost.min() 12 | 13 | 14 | class CostMin(AlignmentAverageMetric): 15 | def _compute_one(self, 16 | cost: torch.FloatTensor, 17 | alignment: torch.FloatTensor, 18 | target: int) -> torch.float: 19 | return cost.min() 20 | 21 | 22 | class CostMax(AlignmentAverageMetric): 23 | def _compute_one(self, 24 | cost: torch.FloatTensor, 25 | alignment: torch.FloatTensor, 26 | target: int) -> torch.float: 27 | return cost.max() 28 | 29 | 30 | class CostMean(AlignmentAverageMetric): 31 | def _compute_one(self, 32 | cost: torch.FloatTensor, 33 | alignment: torch.FloatTensor, 34 | target: int) -> torch.float: 35 | return cost.mean() 36 | 37 | 38 | class CostMedian(AlignmentAverageMetric): 39 | def _compute_one(self, 40 | cost: torch.FloatTensor, 41 | alignment: torch.FloatTensor, 42 | target: int) -> torch.float: 43 | return cost.median() 44 | 45 | 46 | class CostSign(AlignmentAverageMetric): 47 | def __init__(self, positive: bool, similar: bool): 48 | super(CostSign, self).__init__(similar=similar) 49 | self.positive = positive # Whether to count positive or negative costs 50 | 51 | def _compute_one(self, 52 | cost: torch.FloatTensor, 53 | alignment: torch.FloatTensor, 54 | target: int) -> torch.float: 55 | if self.positive: 56 | return torch.sum(cost > 0).float() 57 | return torch.sum(cost < 0).float() 58 | 59 | def __str__(self) -> str: 60 | return f'Num{"Positive" if self.positive else "Negative"}CostWhen{"Similar" if self.similar else "Dissimilar"}' 61 | -------------------------------------------------------------------------------- /similarity/models/__init__.py: -------------------------------------------------------------------------------- 1 | from similarity.models.alignment import AlignmentModel 2 | from similarity.models.attention import SparsemaxFunction 3 | 4 | __all__ = ["AlignmentModel", "SparsemaxFunction"] 5 | -------------------------------------------------------------------------------- /similarity/models/alignment.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import math 3 | from typing import Any, Iterator, List, Optional, Tuple 4 | 5 | from sinkhorn import batch_sinkhorn, construct_cost_and_marginals 6 | from sru import SRU 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from utils.parsing import Arguments 12 | from utils.utils import compute_cost, prod, unpad_tensors 13 | from similarity.models.attention import load_attention_layer 14 | from similarity.models.encoder import Embedder 15 | 16 | 17 | class AlignmentModel(nn.Module): 18 | def __init__( 19 | self, 20 | args: Arguments, 21 | text_field, 22 | domain: Optional[str] = None, 23 | device: torch.device = torch.device("cpu"), 24 | ): 25 | """Constructs an AlignmentModel.""" 26 | super(AlignmentModel, self).__init__() 27 | 28 | # Save values 29 | self.args = args 30 | self.device = device 31 | self.embedder = Embedder(args=args, text_field=text_field, device=device) 32 | 33 | if self.args.alignment == "attention": 34 | self.atten = load_attention_layer(self.args, self.embedder.output_size) 35 | 36 | self.output_size = self.embedder.output_size 37 | # Move to device 38 | self.to(self.device) 39 | 40 | def forward( 41 | self, 42 | data: torch.LongTensor, # batch_size x seq_len 43 | scope: List[Tuple[torch.LongTensor, torch.LongTensor]], 44 | targets: List[torch.LongTensor], 
45 | threshold: float = 0, 46 | encoded: torch.Tensor = None, 47 | ) -> Tuple[List[Tuple[torch.FloatTensor, torch.FloatTensor]], List[Any]]: 48 | """ 49 | Aligns document pairs. 50 | 51 | :param data: Sentences represented as LongTensors of word indices. 52 | :param scope: A list of tuples of row_indices and column_indices indexing into data 53 | to extract the appropriate sentences for each document pair. 54 | :param targets: A list of targets for each document pair. 55 | :return: A tuple consisting of a list of (cost, alignment) tuples and a list of targets. 56 | """ 57 | if encoded is None: 58 | encoded, encoded_seq = self.embedder(data) 59 | # if self.word_to_word: 60 | # encoded = encodede_seq 61 | 62 | # Alignment 63 | costs, alignments = [], [] 64 | n_list, m_list = [], [] 65 | for row_indices, column_indices in scope: 66 | # Select sentence vectors using indices 67 | row_vecs, column_vecs = ( 68 | torch.index_select(encoded, 0, row_indices), 69 | torch.index_select(encoded, 0, column_indices), 70 | ) # (n/m)x 2*hidden_size 71 | 72 | # Get sizes 73 | n, m = len(row_vecs), len(column_vecs) 74 | n_list.append(n) 75 | m_list.append(m) 76 | 77 | # Average sentence embeddings 78 | if self.args.alignment == "average": 79 | row_vecs = row_vecs.mean(dim=0, keepdim=True) 80 | column_vecs = column_vecs.mean(dim=0, keepdim=True) 81 | 82 | # Compute cost 83 | cost = compute_cost(cost_fn=self.args.cost_fn, x1=row_vecs, x2=column_vecs) 84 | 85 | # Alignment-specific computation 86 | if self.args.alignment == "attention": 87 | cost, alignment = self.atten( 88 | row_vecs, column_vecs, cost / self.args.attention_temp, threshold 89 | ) 90 | alignments.append(alignment) 91 | 92 | # Add cost 93 | costs.append(cost) 94 | 95 | # Hack alignment matrix for models that don't do alignment 96 | if self.args.alignment not in ["attention", "sinkhorn"]: 97 | alignments.append( 98 | torch.ones_like(cost) / prod(cost.shape) 99 | ) # use alignment to compute average cost 100 | 101 | # Alignment via sinkhorn 102 | if self.args.alignment == "sinkhorn": 103 | # Add dummy node and get marginals 104 | costs, a_list, b_list = zip( 105 | *[ 106 | construct_cost_and_marginals( 107 | C=cost, 108 | one_to_k=self.args.one_to_k, 109 | split_dummy=self.args.split_dummy, 110 | max_num_aligned=self.args.max_num_aligned, 111 | optional_alignment=self.args.optional_alignment, 112 | ) 113 | for cost in costs 114 | ] 115 | ) 116 | costs, a_list, b_list = list(costs), list(a_list), list(b_list) 117 | 118 | # Prepare sinkhorn function 119 | batch_sinkhorn_func = partial( 120 | batch_sinkhorn, 121 | a_list=a_list, 122 | b_list=b_list, 123 | epsilon=self.args.epsilon, 124 | unbalanced_lambda=self.args.unbalanced_lambda, 125 | ) 126 | 127 | # Run sinkhorn 128 | alignments = batch_sinkhorn_func(C_list=costs) 129 | 130 | # Remove dummy nodes 131 | shapes = list(zip(n_list, m_list)) 132 | alignments = unpad_tensors(alignments, shapes) 133 | 134 | costs = unpad_tensors(costs, shapes) 135 | 136 | # Re-normalize alignment probabilities to one 137 | if ( 138 | self.args.alignment in ["attention", "sinkhorn"] 139 | and not self.args.optional_alignment 140 | ): 141 | alignments = [ 142 | alignment 143 | / alignment.sum(dim=-1, keepdim=True).sum(dim=-2, keepdim=True) 144 | for alignment in alignments 145 | ] 146 | 147 | # Combine costs and alignments into preds 148 | preds = list(zip(costs, alignments)) 149 | 150 | return preds, targets 151 | 152 | def forward_with_sentences( 153 | self, 154 | data: torch.LongTensor, 155 | scope: 
List[Tuple[torch.LongTensor, torch.LongTensor]], 156 | targets: List[torch.LongTensor], 157 | threshold: float = 0, 158 | ) -> Tuple[ 159 | List[Tuple[torch.LongTensor, torch.LongTensor]], 160 | List[Tuple[torch.FloatTensor, torch.FloatTensor]], 161 | List[Any], 162 | ]: 163 | """Makes predictions and returns input sentences, predictions, and targets.""" 164 | sentences = [ 165 | ( 166 | torch.index_select(data, 0, simple_indices), 167 | torch.index_select(data, 0, normal_indices), 168 | ) 169 | for simple_indices, normal_indices in scope 170 | ] 171 | preds, targets = self.forward(data, scope, targets, threshold) 172 | 173 | return sentences, preds, targets 174 | 175 | def num_parameters(self, trainable: bool = False) -> int: 176 | """Gets the number of parameters in the model. 177 | Returns 178 | ---------- 179 | int 180 | number of model params 181 | """ 182 | if trainable: 183 | model_params = list(self.trainable_params) 184 | else: 185 | model_params = list(self.parameters()) 186 | 187 | return sum(len(x.view(-1)) for x in model_params) 188 | 189 | @property 190 | def trainable_params(self) -> Iterator[nn.Parameter]: 191 | """Get all the parameters with `requires_grad=True`. 192 | Returns 193 | ------- 194 | Iterator[nn.Parameter] 195 | Iterator over the parameters 196 | """ 197 | return filter(lambda p: p.requires_grad, self.parameters()) 198 | 199 | @property 200 | def gradient_norm(self) -> float: 201 | """Compute the average gradient norm. 202 | 203 | Returns 204 | ------- 205 | float 206 | The current average gradient norm 207 | 208 | """ 209 | # Only compute over parameters that are being trained 210 | parameters = filter( 211 | lambda p: p.requires_grad and p.grad is not None, self.parameters() 212 | ) 213 | norm = math.sqrt(sum(param.grad.norm(p=2).item() ** 2 for param in parameters)) 214 | 215 | return norm 216 | 217 | @property 218 | def parameter_norm(self) -> float: 219 | """Compute the average parameter norm. 
220 | 221 | Returns 222 | ------- 223 | float 224 | The current average parameter norm 225 | 226 | """ 227 | # Only compute over parameters that are being trained 228 | parameters = filter(lambda p: p.requires_grad, self.parameters()) 229 | norm = math.sqrt(sum(param.norm(p=2).item() ** 2 for param in parameters)) 230 | 231 | return norm 232 | -------------------------------------------------------------------------------- /similarity/models/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Tuple 4 | 5 | from utils.parsing import Arguments 6 | from utils.utils import compute_cost, prod 7 | 8 | 9 | def load_attention_layer(args: Arguments, bidirectional: bool) -> nn.Module: 10 | if args.attention_type == 0 or args.attention_type == 1: 11 | return attention0(args, bidirectional) 12 | if args.attention_type == 2 or args.attention_type == 3: 13 | return attention2(args, bidirectional) 14 | if args.attention_type == 4: 15 | return attention4(args, bidirectional) 16 | if args.attention_type == 5: 17 | return attention5(args, bidirectional) 18 | 19 | 20 | def build_ffn(input_size: int, output_size: int, args: Arguments) -> nn.Module: 21 | """Builds a 2-layer feed-forward network.""" 22 | return nn.Sequential( 23 | # nn.Linear(input_size, self.args.hidden_size), 24 | # nn.Dropout(self.args.dropout), 25 | # nn.ReLU(), 26 | # nn.Linear(self.args.hidden_size, output_size) 27 | nn.Linear(input_size, output_size), 28 | nn.Dropout(args.dropout), 29 | nn.LeakyReLU(0.2), 30 | ) 31 | 32 | 33 | class attention5(nn.Module): 34 | def __init__(self, args: Arguments, input_size: int): 35 | super(attention5, self).__init__() 36 | self.args = args 37 | self.G = build_ffn( 38 | input_size=input_size, output_size=self.args.hidden_size, args=self.args 39 | ) 40 | self.sparsemax = SparsemaxFunction.apply 41 | device = ( 42 | torch.device(args.gpu) if torch.cuda.is_available() else torch.device("cpu") 43 | ) 44 | self.device = device 45 | 46 | def forward( 47 | self, 48 | row_vecs: torch.FloatTensor, 49 | column_vecs: torch.FloatTensor, 50 | cost: torch.FloatTensor, 51 | threshold: float = 0, 52 | ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: 53 | # https://arxiv.org/abs/1606.01933 54 | # Attend 55 | a, b = row_vecs, column_vecs # n x hidden_size, m x hidden_size 56 | 57 | if self.args.using_sparsemax: 58 | row_alignment = self.sparsemax(-cost, 1) 59 | column_alignment = self.sparsemax(-cost, 0) 60 | else: 61 | row_alignment = (-cost).softmax(dim=1) 62 | column_alignment = (-cost).softmax(dim=0) 63 | 64 | # if threshold: 65 | # mask = (torch.rand(row_alignment.size(0), row_alignment.size(1)) >= threshold ).float().cuda() 66 | # # threshold_alignments = [alignment * ( torch.rand(alignment.size(-2), alignment.size(-1)) >=0.5 ).float() for alignment in alignments] 67 | 68 | # row_alignment = row_alignment * mask #(row_alignment >= threshold/size).float() 69 | # column_alignment = column_alignment * mask #(column_alignment >= threshold/size).float() 70 | 71 | if threshold: 72 | # if self.args.absolute_threshold: 73 | # alignment_masks_r = (row_alignments >= threshold).float().to(self.device) 74 | # alignment_masks_c = (column_alignments >= threshold).float().to(self.device) 75 | # else: 76 | # mask = (torch.rand(row_alignments.size(0), row_alignments.size(1), row_alignments.size(2)) >= threshold ).float().to(self.device) 77 | # threshold_alignments = [alignment * ( torch.rand(alignment.size(-2), 
alignment.size(-1)) >=0.5 ).float() for alignment in alignments] 78 | # print('relative') 79 | alignment_masks_r = ( 80 | (row_alignment >= threshold / prod(row_alignment.shape[-2:])) 81 | .float() 82 | .to(self.device) 83 | ) 84 | alignment_masks_c = ( 85 | (column_alignment >= threshold / prod(column_alignment.shape[-2:])) 86 | .float() 87 | .to(self.device) 88 | ) 89 | 90 | row_alignment = ( 91 | row_alignment * alignment_masks_r * alignment_masks_c 92 | ) # (row_alignment >= threshold/size).float() 93 | column_alignment = ( 94 | column_alignment * alignment_masks_c * alignment_masks_r 95 | ) # (column_alignment >= threshold/size).float() 96 | 97 | beta = torch.sum( 98 | row_alignment.unsqueeze(dim=2) * b.unsqueeze(dim=0), dim=1 99 | ) # n x hidden_size 100 | alpha = torch.sum( 101 | column_alignment.unsqueeze(dim=2) * a.unsqueeze(dim=1), dim=0 102 | ) # m x hidden_size 103 | 104 | # Compare 105 | if self.args.force_attention_linear: 106 | v_1 = beta.mean(dim=0, keepdim=True) # n x hidden_size 107 | v_2 = alpha.mean(dim=0, keepdim=True) # m x hidden_size 108 | else: 109 | v_1 = self.G(beta).mean(dim=0, keepdim=True) # n x hidden_size 110 | v_2 = self.G(alpha).mean(dim=0, keepdim=True) # m x hidden_size 111 | 112 | y = compute_cost(cost_fn=self.args.cost_fn, x1=v_1, x2=v_2) 113 | 114 | cost_matrix = y * torch.ones_like(cost) 115 | alignment = torch.stack( 116 | (row_alignment.detach(), column_alignment.detach()), dim=0 117 | ) 118 | return cost_matrix, alignment 119 | 120 | 121 | """Sparsemax activation function. 122 | Pytorch implementation of Sparsemax function from: 123 | -- "From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification" 124 | -- André F. T. Martins, Ramón Fernandez Astudillo (http://arxiv.org/abs/1602.02068) 125 | """ 126 | """ 127 | An implementation of sparsemax (Martins & Astudillo, 2016). See 128 | :cite:`DBLP:journals/corr/MartinsA16` for detailed description. 
129 | By Ben Peters and Vlad Niculae 130 | """ 131 | # From: https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/modules/sparse_activations.py 132 | 133 | import torch 134 | from torch.autograd import Function 135 | import torch.nn as nn 136 | 137 | 138 | def _make_ix_like(input, dim=0): 139 | d = input.size(dim) 140 | rho = torch.arange(1, d + 1, device=input.device, dtype=input.dtype) 141 | view = [1] * input.dim() 142 | view[0] = -1 143 | return rho.view(view).transpose(0, dim) 144 | 145 | 146 | def _threshold_and_support(input, dim=0): 147 | """Sparsemax building block: compute the threshold 148 | Args: 149 | input: any dimension 150 | dim: dimension along which to apply the sparsemax 151 | Returns: 152 | the threshold value 153 | """ 154 | 155 | input_srt, _ = torch.sort(input, descending=True, dim=dim) 156 | input_cumsum = input_srt.cumsum(dim) - 1 157 | rhos = _make_ix_like(input, dim) 158 | support = rhos * input_srt > input_cumsum 159 | 160 | support_size = support.sum(dim=dim).unsqueeze(dim) 161 | tau = input_cumsum.gather(dim, support_size - 1) 162 | tau /= support_size.to(input.dtype) 163 | return tau, support_size 164 | 165 | 166 | class SparsemaxFunction(Function): 167 | @staticmethod 168 | def forward(ctx, input, dim=0): 169 | """sparsemax: normalizing sparse transform (a la softmax) 170 | Parameters: 171 | input (Tensor): any shape 172 | dim: dimension along which to apply sparsemax 173 | Returns: 174 | output (Tensor): same shape as input 175 | """ 176 | ctx.dim = dim 177 | max_val, _ = input.max(dim=dim, keepdim=True) 178 | input -= max_val # same numerical stability trick as for softmax 179 | tau, supp_size = _threshold_and_support(input, dim=dim) 180 | output = torch.clamp(input - tau, min=0) 181 | ctx.save_for_backward(supp_size, output) 182 | return output 183 | 184 | @staticmethod 185 | def backward(ctx, grad_output): 186 | supp_size, output = ctx.saved_tensors 187 | dim = ctx.dim 188 | grad_input = grad_output.clone() 189 | grad_input[output == 0] = 0 190 | 191 | v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze() 192 | v_hat = v_hat.unsqueeze(dim) 193 | grad_input = torch.where(output != 0, grad_input - v_hat, grad_input) 194 | return grad_input, None 195 | -------------------------------------------------------------------------------- /similarity/models/encoder.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import math 3 | from typing import Any, Iterator, List, Optional, Tuple 4 | 5 | from sinkhorn import batch_sinkhorn, construct_cost_and_marginals 6 | from sru import SRU 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from utils.parsing import Arguments 12 | from utils.utils import compute_cost, prod, unpad_tensors 13 | from similarity.models.attention import load_attention_layer 14 | 15 | 16 | class Embedder(nn.Module): 17 | def __init__( 18 | self, 19 | args: Arguments, 20 | text_field, 21 | bidirectional: bool = True, 22 | layer_norm: bool = False, 23 | highway_bias: float = 0.0, 24 | rescale: bool = True, 25 | device: torch.device = torch.device("cpu"), 26 | ): 27 | """Constructs an model to compute embeddings.""" 28 | super(Embedder, self).__init__() 29 | 30 | # Save values 31 | self.args = args 32 | self.device = device 33 | pad_index = text_field.pad_index() 34 | self.pad_index = pad_index 35 | 36 | if self.args.bert: 37 | from transformers import AutoModel 38 | 39 | self.encoder = 
AutoModel.from_pretrained(args.bert_type)
40 | print("finish loading bert encoder")
41 | self.output_size = self.encoder.config.hidden_size
42 | self.bidirectional = False
43 | self.bert_bs = args.bert_batch_size
44 | else:
45 | num_embeddings = len(text_field.vocabulary)
46 | print(f'Loading embeddings from "{args.embedding_path}"')
47 | embedding_matrix = text_field.load_embeddings(args.embedding_path)
48 | 
49 | self.num_embeddings = num_embeddings
50 | self.embedding_size = embedding_matrix.size(1)
51 | self.bidirectional = bidirectional
52 | self.layer_norm = layer_norm
53 | self.highway_bias = highway_bias
54 | self.rescale = rescale
55 | self.word_to_word = args.word_to_word
56 | self.output_size = self.args.hidden_size * (1 + self.bidirectional)
57 | 
58 | # Create models/parameters
59 | self.embedding = nn.Embedding(
60 | num_embeddings=self.num_embeddings,
61 | embedding_dim=self.embedding_size,
62 | padding_idx=self.pad_index,
63 | )
64 | self.embedding.weight.data = embedding_matrix
65 | self.embedding.weight.requires_grad = False
66 | 
67 | self.encoder = SRU(
68 | input_size=self.embedding_size,
69 | hidden_size=self.args.hidden_size,
70 | num_layers=self.args.num_layers,
71 | dropout=self.args.dropout,
72 | bidirectional=self.bidirectional,
73 | layer_norm=self.layer_norm,
74 | rescale=self.rescale,
75 | highway_bias=self.highway_bias,
76 | )
77 | 
78 | # Move to device
79 | self.to(self.device)
80 | 
81 | def rnn_encode(
82 | self, data: torch.LongTensor, # batch_size x seq_len
83 | ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
84 | """
85 | Encodes a batch of sentences with the SRU encoder and average pooling.
86 | 
87 | :param data: Sentences represented as a LongTensor of word indices,
88 | shape batch_size x seq_len, padded with the pad index.
89 | :return: A tuple (masked_h, None), where masked_h is a FloatTensor of shape
90 | batch_size x 2*hidden_size (for a bidirectional encoder) containing
91 | the mean-pooled sentence encodings.
92 | """
93 | # Transpose from batch first to sequence first
94 | data = data.transpose(0, 1) # seq_len x batch_size
95 | 
96 | # Create mask
97 | mask = (data != self.pad_index).float() # seq_len x batch_size
98 | 
99 | # Embed
100 | embedded = self.embedding(data) # seq_len x batch_size x embedding_size
101 | 
102 | # RNN encoder
103 | h_seq, _ = self.encoder(
104 | embedded, mask_pad=(1 - mask)
105 | ) # seq_len x batch_size x 2*hidden_size
106 | # output_states, c_states = sru(x) # forward pass
107 | # output_states is (length, batch size, number of directions * hidden size)
108 | # c_states is (layers, batch size, number of directions * hidden size)
109 | 
110 | masked_h_seq = h_seq * mask.unsqueeze(
111 | dim=2
112 | ) # seq_len x batch_size x 2*hidden_size
113 | 
114 | # Average pooling
115 | masked_h = masked_h_seq.sum(dim=0) / mask.sum(dim=0).unsqueeze(
116 | dim=1
117 | ) # batch_size x 2*hidden_size
118 | 
119 | masked_h_seq = masked_h_seq.transpose(0, 1)
120 | # return masked_h, masked_h_seq
121 | return masked_h, None
122 | 
123 | def bert_encode(
124 | self,
125 | data: torch.LongTensor, # batch_size x seq_len
126 | token_type_ids: Optional[torch.Tensor] = None,
127 | attention_mask: Optional[torch.Tensor] = None,
128 | position_ids: Optional[torch.Tensor] = None,
129 | head_mask: Optional[torch.Tensor] = None,
130 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
131 | """
132 | Uses a pretrained BERT-style encoder to embed a batch of token index sequences.
133 | :param data: A LongTensor of shape `(batch_size, sequence_length)` containing token indices.
134 | :param attention_mask: An optional mask over padding positions; if None, it is computed from the pad index.
135 | :return: A tuple (masked_h, None), where masked_h is a FloatTensor of shape `(batch_size, output_size)` containing the encoding for each sequence
136 | in the batch.
137 | """
138 | # print(data.shape)
139 | # Create mask for padding
140 | # max_len = lengths.max().item()
141 | # attention_mask = torch.zeros(len(data), max_len, dtype=torch.float)
142 | # for i in range(len(data)):
143 | # attention_mask[i, :lengths[i]] = 1
144 | 
145 | if attention_mask is None and self.pad_index is not None:
146 | attention_mask = (data != self.pad_index).float()
147 | 
148 | attention_mask = attention_mask.to(self.device)
149 | outputs = self.encoder(data, attention_mask=attention_mask)
150 | if "distil" not in self.args.bert_type:
151 | masked_h_seq = outputs[0]
152 | masked_h = outputs[1]
153 | else:
154 | masked_h = outputs[0]
155 | return masked_h, None
156 | 
157 | def forward(
158 | self, data: torch.LongTensor, # batch_size x seq_len
159 | ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
160 | if self.args.bert:
161 | if len(data) > self.bert_bs:
162 | encodings = []
163 | batch_size = self.bert_bs
164 | for batch_idx in range(len(data) // batch_size + 1):
165 | start_idx = batch_idx * batch_size
166 | end_idx = (batch_idx + 1) * batch_size
167 | batch = data[start_idx:end_idx]
168 | # print(data.shape)
169 | # print(batch.shape)
170 | if len(batch) == 0:
171 | break
172 | encoded, _ = self.bert_encode(batch)
173 | # print(encoded.shape)
174 | encodings.extend(encoded)
175 | del encoded
176 | encodings = torch.stack(encodings)
177 | # print(encodings.shape)
178 | return encodings, None
179 | return self.bert_encode(data) # batches that fit within bert_batch_size are encoded in a single pass
180 | return self.rnn_encode(data)
181 | 
--------------------------------------------------------------------------------
/sinkhorn/__init__.py:
--------------------------------------------------------------------------------
1 | from sinkhorn.cost_and_marginals import construct_cost_and_marginals
2 | from sinkhorn.sinkhorn import compute_alignment_cost, compute_entropy, batch_sinkhorn, sinkhorn, \
3 | sinkhorn_epsilon_scaling
4 | 
--------------------------------------------------------------------------------
/sinkhorn/cost_and_marginals.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 | 
3 | import torch
4 | 
5 | from sinkhorn.utils import compute_masks
6 | 
7 | 
8 | def construct_cost_and_marginals(C: torch.FloatTensor,
9 | one_to_k: Optional[int] = None,
10 | max_num_aligned: Optional[int] = None,
11 | optional_alignment: bool = False,
12 | split_dummy: bool = False,
13 | order_lambda: Optional[float] = None,
14 | ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
15 | """
16 | Given a cost matrix, adds dummy nodes (if necessary) and returns the new cost matrix and the row and column marginals.
17 | 
18 | Dummy nodes always have cost 0.
19 | 
20 | :param C: FloatTensor (n x m) with costs.
21 | :param one_to_k: The k for one-to-k alignment, where each node on the smaller side aligns to k nodes on the larger side.
22 | If k is -1, uses the maximum possible k, which is k = floor(max(n, m) / min(n, m)).
23 | :param max_num_aligned: If provided, constrains the total number of alignments to this number.
24 | :param optional_alignment: Whether to allow every node to align to a dummy node.
Requires a cost that can be + or -. 25 | :param split_dummy: Whether to split multi-marginal dummy nodes into multiple nodes with marginal 1 each. 26 | :param order_lambda: Weight for an added diagonal order preserving cost (0 = no weight, 1 = all weight). 27 | :return: A tuple containing the cost matrix, the row marginals, and the column marginals. 28 | """ 29 | # Checks 30 | assert len(C.shape) == 2 31 | assert one_to_k is None or max_num_aligned is None 32 | if max_num_aligned is not None: 33 | assert not optional_alignment 34 | 35 | # Setup 36 | n, m = C.shape 37 | device = C.device 38 | 39 | # Ensure n >= m for all the following for convenience, then return to original order at the end 40 | swap = n < m 41 | if swap: 42 | n, m = m, n 43 | 44 | # Default marginals 45 | a = [1] * n 46 | b = [1] * m 47 | 48 | # If max_num_aligned is more than a one-to-one alignment, just do a one-to-one alignment 49 | if max_num_aligned is not None and max_num_aligned >= m: 50 | max_num_aligned = None 51 | one_to_k = 1 52 | 53 | # If optional alignment, set one-to-one because it'll happen anyway 54 | if optional_alignment and one_to_k is None: 55 | one_to_k = 1 56 | 57 | # Dummy nodes 58 | if max_num_aligned is not None: 59 | assert max_num_aligned > 0 60 | 61 | if split_dummy: 62 | a = b = [1] * (n + m - max_num_aligned) 63 | else: 64 | a = [1] * n + [m - max_num_aligned] 65 | b = [1] * m + [n - max_num_aligned] 66 | elif one_to_k is not None: 67 | assert one_to_k == -1 or one_to_k > 0 68 | 69 | # Add enough dummy nodes to absorb the non-divisibility of n and m 70 | max_k = n // m 71 | k = max_k if one_to_k == -1 else min(one_to_k, max_k) 72 | rem = n - k * m 73 | 74 | if optional_alignment: 75 | if split_dummy: 76 | a = [1] * n + [k] * m 77 | b = [k] * m + [1] * n 78 | else: 79 | a = [1] * n + [k * m] 80 | b = [k] * m + [n] 81 | elif rem != 0: 82 | if split_dummy: 83 | a = [1] * n 84 | b = [k] * m + [1] * rem 85 | else: 86 | a = [1] * n 87 | b = [k] * m + [rem] 88 | 89 | # Return to original ordering 90 | if swap: 91 | n, m = m, n 92 | a, b = b, a 93 | 94 | # Add zero padding for dummy node cost if there are dummy nodes 95 | num_rows, num_cols = len(a), len(b) 96 | if (num_rows, num_cols) != (n, m): 97 | padded_cost = torch.zeros(num_rows, num_cols, device=device) 98 | padded_cost[:n, :m] = C 99 | C = padded_cost 100 | 101 | # Tensorize marginals and normalize to one 102 | a = torch.FloatTensor(a).to(device) / sum(a) 103 | b = torch.FloatTensor(b).to(device) / sum(b) 104 | 105 | # Add diagonal order-preserving cost 106 | if order_lambda is not None: 107 | C = add_order_cost(C=C, a=a, b=b, order_lambda=order_lambda) 108 | 109 | return C, a, b 110 | 111 | 112 | def add_order_cost(C: torch.FloatTensor, 113 | a: torch.FloatTensor, 114 | b: torch.FloatTensor, 115 | order_lambda: float) -> torch.FloatTensor: 116 | """ 117 | Adds diagonal order preserving cost to a cost matrix. 118 | 119 | :param C: FloatTensor (n x m or num_batches x n x m) with costs. 120 | Note: The device and dtype of C are used for all other variables. 121 | :param a: Row marginals (num_batches x n). 122 | :param b: Column marginals (num_batches x m). 123 | :param order_lambda: Weight for diagonal order preserving cost (0 = no weight, 1 = all weight). 124 | :return: The cost matrix with the order preserving cost added in. 
125 | """ 126 | # Checks 127 | assert len(C.shape) in [2, 3] 128 | batched = len(C.shape) == 3 129 | assert len(a.shape) == len(b.shape) == (2 if batched else 1) 130 | assert 0.0 <= order_lambda <= 1.0 131 | 132 | # Return C if not adding order cost 133 | if order_lambda == 0.0: 134 | return C 135 | 136 | # Setup 137 | dtype = C.dtype 138 | device = C.device 139 | 140 | # Compute order preserving cost 141 | I = torch.arange(C.size(-2), dtype=dtype, device=device).unsqueeze(dim=1) 142 | J = torch.arange(C.size(-1), dtype=dtype, device=device).unsqueeze(dim=0) 143 | 144 | # Compute masks 145 | mask, mask_n, mask_m = compute_masks(C=C, a=a, b=b) 146 | 147 | # Compute N and M 148 | N = mask_n.sum(dim=-1, keepdim=True) 149 | M = mask_m.sum(dim=-1, keepdim=True) 150 | 151 | if batched: 152 | I = I.unsqueeze(dim=0).repeat(C.size(0), 1, C.size(2)) 153 | J = J.unsqueeze(dim=0).repeat(C.size(0), C.size(1), 1) 154 | 155 | N = N.unsqueeze(dim=-1) 156 | M = M.unsqueeze(dim=-1) 157 | else: 158 | I = I.repeat(1, C.size(1)) 159 | J = J.repeat(C.size(0), 1) 160 | 161 | D = torch.abs(I / N - J / M) 162 | 163 | # Match average magnitudes so costs are on the same scale 164 | mask_sum = mask.sum(dim=-1).sum(dim=-1) 165 | C_magnitude = (mask * C.abs()).sum(dim=-1).sum(dim=-1) / mask_sum 166 | D_magnitude = (mask * D.abs()).sum(dim=-1).sum(dim=-1) / mask_sum 167 | D_magnitude[D_magnitude == 0] = 1 # prevent divide by 0 168 | D *= (C_magnitude / D_magnitude).unsqueeze(dim=-1).unsqueeze(dim=-1) 169 | 170 | # Add order preserving cost 171 | C = (1 - order_lambda) * C + order_lambda * D 172 | C *= mask 173 | 174 | return C 175 | -------------------------------------------------------------------------------- /sinkhorn/utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import torch 4 | 5 | 6 | def bmv(m: torch.FloatTensor, v: torch.FloatTensor) -> torch.FloatTensor: 7 | """ 8 | Performs a batched matrix-vector product. 9 | 10 | :param m: A 3-dimensional FloatTensor (num_batches x n1 x n2). 11 | :param v: A 2-dimensional FloatTensor (num_batches x n2). 12 | :return: Batched matrix-vector product mv (num_batches x n1). 13 | """ 14 | assert len(m.shape) == 3 15 | assert len(v.shape) == 2 16 | return torch.bmm(m, v.unsqueeze(dim=2)).squeeze(dim=2) 17 | 18 | 19 | def compute_masks(C: torch.FloatTensor, 20 | a: torch.FloatTensor, 21 | b: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: 22 | """ 23 | Computes masks for C, a, and b based on the zero entries in a and b. 24 | 25 | Masks have 1s for content and 0s for padding. 26 | 27 | :param C: Cost matrix (n x m or num_batches x n x m). 28 | Note: The device and dtype of C are used for all other variables. 29 | :param a: Row marginals (n or num_batches x n). 30 | :param b: Column marginals (m or num_batches x m). 31 | :return: A tuple containing a mask for C, a mask for a, and a mask for b. 32 | """ 33 | mask_n = (a != 0) 34 | mask_m = (b != 0) 35 | mask = mask_n.unsqueeze(dim=-1) & mask_m.unsqueeze(dim=-2) 36 | mask, mask_n, mask_m = mask.to(C), mask_n.to(C), mask_m.to(C) 37 | 38 | return mask, mask_n, mask_m 39 | 40 | 41 | def pad_tensors(tensor_list: List[torch.FloatTensor], 42 | padding: float = 0.0) -> torch.FloatTensor: 43 | """ 44 | Pads and stacks a list of tensors, each with the same number of dimensions. 45 | 46 | :param tensor_list: A list of FloatTensors to pad. 
47 | Note: The device and dtype of the first tensor are used for all other variables. 48 | :param padding: Padding value to use. 49 | :return: A FloatTensor containing the padded and stacked tensors in tensor_list. 50 | """ 51 | # Determine maximum size along each dimension 52 | shape_list = [tensor.shape for tensor in tensor_list] 53 | shape_max = torch.LongTensor(shape_list).max(dim=0)[0] 54 | 55 | # Create padding with shape (num_batches, *shape_max) 56 | tensor_batch = padding * torch.ones(len(tensor_list), *shape_max, 57 | dtype=tensor_list[0].dtype, device=tensor_list[0].device) 58 | 59 | # Put content of tensors into the batch tensor 60 | for i, (tensor, shape) in enumerate(zip(tensor_list, shape_list)): 61 | tensor_slice = [i, *[slice(size) for size in shape]] 62 | tensor_batch[tensor_slice] = tensor 63 | 64 | return tensor_batch 65 | 66 | 67 | def mask_log(x: torch.FloatTensor, mask: Optional[torch.Tensor] = None) -> torch.FloatTensor: 68 | """ 69 | Takes the logarithm such that the log of masked entries is zero. 70 | 71 | :param x: FloatTensor whose log will be computed. 72 | :param mask: Tensor with 1s for content and 0s for padding. 73 | Entries in x corresponding to 0s will have a log of 0. 74 | :return: log(x) such that entries where the mask is 0 have a log of 0. 75 | """ 76 | if mask is not None: 77 | # Set masked entries of x equal to 1 (in a differentiable way) so log(1) = 0 78 | mask = mask.float() 79 | x = x * mask + (1 - mask) 80 | 81 | return torch.log(x) 82 | 83 | 84 | def p_log_p(p: torch.FloatTensor) -> torch.FloatTensor: 85 | """ 86 | Computes p * log(p) so that 0 * log(0) = 0. 87 | 88 | :param p: A FloatTensor of probabilities between 0 and 1. 89 | :return: The elementwise computation p * log(p). 90 | """ 91 | return p * mask_log(p, mask=p != 0) 92 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import torch 5 | from torch.optim import Adam 6 | from torch.optim.lr_scheduler import ExponentialLR 7 | 8 | 9 | from utils.parsing import Arguments, parse_args 10 | from utils.utils import load_weights, makedirs, make_schedular, NoamLR 11 | 12 | 13 | def train(args: Arguments) -> None: 14 | """Trains an AlignmentModel to align sets of sentences.""" 15 | 16 | if args.task == "classify": 17 | from classify.compute import AlignmentTrainer 18 | from classify.data import load_data 19 | from classify.metric import load_loss_and_metrics 20 | 21 | if args.word_to_word: 22 | from classify.models.ot_atten import AlignmentModel 23 | else: 24 | from classify.models.ot_atten_sent import AlignmentModel 25 | 26 | elif args.task == "similarity": 27 | from similarity.compute import AlignmentTrainer 28 | from similarity.data import load_data 29 | from similarity.metric import load_loss_and_metrics 30 | from similarity.models import AlignmentModel 31 | 32 | # Determine device 33 | device = ( 34 | torch.device(args.gpu) if torch.cuda.is_available() else torch.device("cpu") 35 | ) 36 | # device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 37 | 38 | print("Loading data") 39 | text_field, train_sampler, dev_sampler, test_sampler = load_data(args, device) 40 | 41 | print("Building model") 42 | model = AlignmentModel( 43 | args=args, text_field=text_field, domain=args.dataset, device=device, 44 | ) 45 | 46 | saved_step = 0 47 | if args.checkpoint_path is not None: 48 | print(f"Loading checkpoint 
from: {args.checkpoint_path}") 49 | saved_step = 1 + int(args.checkpoint_path.split("_")[-1].replace(".pt", "")) 50 | print(f"trainig from step {saved_step}") 51 | load_weights(model, args.checkpoint_path) 52 | 53 | print(model) 54 | print(f"Number of parameters = {model.num_parameters(trainable=True):,}") 55 | 56 | print(f"Moving model to device: {device}") 57 | model.to(device) 58 | 59 | print("Defining loss and metrics") 60 | ( 61 | loss_fn, 62 | metric_fn, 63 | extra_training_metrics, 64 | extra_validation_metrics, 65 | ) = load_loss_and_metrics(args) 66 | 67 | print("Creating optimizer and scheduler") 68 | if args.bert: 69 | # Prepare optimizer and schedule (linear warmup and decay) 70 | from transformers import AdamW 71 | from transformers import get_linear_schedule_with_warmup 72 | 73 | no_decay = ["bias", "LayerNorm.weight"] 74 | optimizer_grouped_parameters = [ 75 | { 76 | "params": [ 77 | p 78 | for n, p in model.named_parameters() 79 | if not any(nd in n for nd in no_decay) 80 | ], 81 | "weight_decay": args.weight_decay, 82 | }, 83 | { 84 | "params": [ 85 | p 86 | for n, p in model.named_parameters() 87 | if any(nd in n for nd in no_decay) 88 | ], 89 | "weight_decay": 0.0, 90 | }, 91 | ] 92 | # optimizer_grouped_parameters = get_params(model) 93 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.lr, eps=1e-8) 94 | # num_batch_per_epoch = min(train_data.num_batches, args.max_batches_per_epoch) 95 | num_batch_per_epoch = len(train_sampler) 96 | t_total = int( 97 | num_batch_per_epoch // args.gradient_accumulation_steps * args.epochs 98 | ) 99 | # scheduler = WarmupLinearSchedule(optimizer, warmup_steps=int(t_total*0.06), t_total=t_total) 100 | scheduler = get_linear_schedule_with_warmup( 101 | optimizer, 102 | num_warmup_steps=int(t_total * args.warmup_ratio), 103 | num_training_steps=t_total, 104 | ) 105 | else: 106 | optimizer = Adam( 107 | model.trainable_params, lr=args.lr, weight_decay=args.weight_decay 108 | ) 109 | scheduler = make_schedular( 110 | args, optimizer, model.output_size, last_epoch=saved_step - 1 111 | ) 112 | 113 | print("Building Trainer") 114 | trainer = AlignmentTrainer( 115 | args=args, 116 | train_sampler=train_sampler, 117 | dev_sampler=dev_sampler, 118 | test_sampler=test_sampler, 119 | model=model, 120 | loss_fn=loss_fn, 121 | metric_fn=metric_fn, 122 | optimizer=optimizer, 123 | scheduler=scheduler, 124 | epochs=args.epochs, 125 | extra_training_metrics=extra_training_metrics, 126 | extra_validation_metrics=extra_validation_metrics, 127 | log_dir=args.log_dir, 128 | log_frequency=args.log_frequency, 129 | gradient_accumulation_steps=args.gradient_accumulation_steps, 130 | sparsity_thresholds=args.sparsity_thresholds, 131 | saved_step=saved_step, 132 | ) 133 | 134 | if args.epochs > 0: 135 | print("Training") 136 | while not trainer.step(): 137 | pass 138 | 139 | if args.preds_dir is not None or args.viz_dir is not None: 140 | print("Predicting") 141 | sentences, preds, targets = trainer.predict(num_predict=args.num_predict) 142 | 143 | # Extract targets 144 | targets = [target["targets"] for target in targets] 145 | targets = [t.item() for target in targets for t in target] 146 | 147 | # Convert indices back to tokens 148 | sentences = [ 149 | ( 150 | [text_field.deprocess(sentence) for sentence in doc_1], 151 | [text_field.deprocess(sentence) for sentence in doc_2], 152 | ) 153 | for doc_1, doc_2 in sentences 154 | ] 155 | 156 | # Save predictions 157 | if args.preds_dir is not None: 158 | makedirs(args.preds_dir) 159 | preds_path = 
os.path.join(args.preds_dir, "preds.pkl")
160 | with open(preds_path, "wb") as f:
161 | # save tokenized sentences, model predictions, and targets together
162 | pickle.dump((sentences, preds, targets), f)
163 | 
164 | elif args.epochs == 0:
165 | print("Evaluating")
166 | trainer.eval_step()
167 | 
168 | 
169 | if __name__ == "__main__":
170 | import sys
171 | 
172 | from utils.utils import Logger
173 | 
174 | # Parse args
175 | args = parse_args()
176 | 
177 | # Set up logging to console and file
178 | sys.stdout = Logger(
179 | pipe=sys.stdout, log_path=os.path.join(args.log_dir, "stdout.txt")
180 | )
181 | sys.stderr = Logger(
182 | pipe=sys.stderr, log_path=os.path.join(args.log_dir, "stderr.txt")
183 | )
184 | 
185 | # Print and save args
186 | print(args)
187 | args.save(os.path.join(args.log_dir, "args.json"))
188 | 
189 | # Train
190 | train(args)
191 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asappresearch/rationale-alignment/8d2bf06ba4c121863833094d5d4896bf34a9a73e/utils/__init__.py
--------------------------------------------------------------------------------
/utils/berttokenizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import AutoTokenizer
3 | from typing import List
4 | from torch.nn.utils.rnn import pad_sequence
5 | 
6 | 
7 | class BTTokenizer:
8 | """
9 | Tokenizer that wraps a pretrained huggingface (BERT-style) tokenizer.
10 | Example:
11 | >>> tokenizer = BTTokenizer(args)
12 | >>> tokenizer.process('Hi how may I help you?')
13 | # -> LongTensor of token ids, including special tokens such as [CLS] and [SEP]
14 | """
15 | 
16 | def __init__(self, args):
17 | self.tokenizer = AutoTokenizer.from_pretrained(args.bert_type)
18 | print("finish loading bert tokenizer")
19 | self.add_special_tokens = True
20 | self.max_len = args.max_sentence_length
21 | # self.pad_index = self.tokenizer.convert_tokens_to_ids(self.tokenizer.pad_token)
22 | self.pad_token = self.tokenizer.pad_token
23 | 
24 | def pad_index(self) -> int:
25 | """Get the padding index.
26 | 
27 | Returns
28 | -------
29 | int
30 | The padding index in the vocabulary
31 | 
32 | """
33 | # pad_token = tokenizer._convert_token_to_id(tokenizer.pad_token)
34 | pad_token = self.tokenizer.pad_token
35 | return self.tokenizer.convert_tokens_to_ids(pad_token)
36 | 
37 | def process(self, text: str) -> torch.LongTensor:
38 | """Encode text into a LongTensor of BERT token ids (with special tokens)."""
39 | tokens = self.tokenizer.encode(
40 | text, add_special_tokens=self.add_special_tokens, max_length=self.max_len
41 | )
42 | if len(tokens) > self.max_len:
43 | print(len(tokens))
44 | processed = torch.LongTensor(tokens)
45 | return processed
46 | 
47 | def deprocess(self, idx) -> str:
48 | """Decode a sequence of token ids back into text, stripping special tokens."""
49 | text = self.tokenizer.decode(idx.numpy(), skip_special_tokens=True)
50 | # print(text)
51 | text = (
52 | text.replace("[CLS]", "").replace("[SEP]", "").replace(self.pad_token, "")
53 | )
54 | text = text.strip()
55 | # print(text)
56 | return text
57 | 
58 | 
--------------------------------------------------------------------------------
/utils/version.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.0.1"
2 | 
--------------------------------------------------------------------------------
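For reference, the snippet below is a minimal usage sketch (not part of the repository) of how the Sinkhorn utilities above fit together: `construct_cost_and_marginals` pads a pairwise sentence cost matrix with dummy nodes and builds the row/column marginals, and `batch_sinkhorn` turns the padded costs into alignment matrices. The cost values and `epsilon` are placeholders, and `unbalanced_lambda=None` is an assumption about the default (balanced) setting; the keyword arguments otherwise mirror the call made in the alignment model's forward pass.

```python
import torch

from sinkhorn import batch_sinkhorn, construct_cost_and_marginals

# Toy cost matrix between the 3 sentences of one document and the 2 of another.
# In the models, these costs come from compute_cost over sentence encodings.
C = torch.rand(3, 2)

# one_to_k=1 requests a one-to-one alignment; since 3 > 2, a dummy column is
# appended to absorb the leftover sentence, giving a 3 x 3 padded cost matrix
# and uniform marginals a, b of length 3 that each sum to 1.
C_pad, a, b = construct_cost_and_marginals(C, one_to_k=1)
print(C_pad.shape)             # torch.Size([3, 3])
print(a.tolist(), b.tolist())  # two uniform marginal vectors of length 3

# The padded costs and marginals are then fed to batch_sinkhorn with the same
# keyword arguments the alignment models use; epsilon is the entropic
# regularizer, and unbalanced_lambda=None is assumed here to select the
# standard balanced Sinkhorn iteration.
alignments = batch_sinkhorn(
    C_list=[C_pad], a_list=[a], b_list=[b], epsilon=0.1, unbalanced_lambda=None
)
print(alignments[0].shape)  # padded alignment (transport plan) for the single pair
```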