├── cache
│   └── .gitignore
├── configs
│   ├── .gitignore
│   ├── hyperparams.json
│   └── data_path.json
├── data
│   └── .gitignore
├── model
│   ├── .gitignore
│   ├── basicmodules.py
│   ├── model.py
│   └── submodules.py
├── saved
│   └── .gitignore
├── helper.py
├── README.md
├── run.py
└── data_helper.py

/cache/.gitignore:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/configs/.gitignore:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/data/.gitignore:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/model/.gitignore:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/saved/.gitignore:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/configs/hyperparams.json:
--------------------------------------------------------------------------------
{
    "seed": 42,
    "do_train_dev": true,
    "do_test": true,
    "do_infer": false,
    "vocab_limit": null,
    "vocab_thresh": null,
    "weight_scheme": "tws",
    "tws_thresh": 10.0,
    "train_bs": 64,
    "dev_bs": 512,
    "test_bs": 512,
    "infer_bs": 512,
    "max_segments": 128,
    "max_segment_size": null,
    "num_mpath_samples": 8,
    "hidden_size": 200,
    "opt_lr": 1e-4,
    "opt_wt_decay": 1e-4,
    "sch_factor": 0.2,
    "sch_patience": 3,
    "num_epochs": 2,
    "pthresh": 0.51,
    "thetas": [1, 2, 3],
    "lambdas": [0.25, 0.75],
    "dropout": 0.1
}
--------------------------------------------------------------------------------
/configs/data_path.json:
--------------------------------------------------------------------------------
{
    "sec_src": "data/secs.jsonl",
    "train_src": "data/train.jsonl",
    "dev_src": "data/dev.jsonl",
    "test_src": "data/test.jsonl",
    "infer_src": null,
    "sec_cache": "cache/secs.pkl",
    "train_cache": "cache/train.pkl",
    "dev_cache": "cache/dev.pkl",
    "test_cache": "cache/test.pkl",
    "infer_cache": "cache/infer.pkl",
    "s2v_path": "data/ils2v.bin",
    "type_map": "data/type_map.json",
    "label_tree": "data/label_tree.json",
    "citation_network": "data/citation_network.json",
    "schemas": "data/schemas.json",
    "model_load": null,
    "metrics_load": null,
    "model_dump": "saved/best_model.pt",
    "dev_metrics_dump": "saved/dev_metrics.json",
    "test_metrics_dump": "saved/test_metrics.json",
    "infer_trg": null
}
--------------------------------------------------------------------------------
/model/basicmodules.py:
--------------------------------------------------------------------------------
import torch

class LstmNet(torch.nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.lstm = torch.nn.LSTM(hidden_size, hidden_size // 2, batch_first=True, bidirectional=True)

    def forward(self, inputs, mask=None): # [B, S, H], [B, S]
        mask = mask if mask is not None else torch.ones(inputs.size(0), inputs.size(1), device=inputs.device)
        lengths = mask.sum(dim=-1) # [B,]

        # need to pack inputs before passing to RNN and unpack obtained outputs
        pck_inputs = torch.nn.utils.rnn.pack_padded_sequence(inputs, lengths.cpu(), batch_first=True, enforce_sorted=False)
        pck_hidden_all = self.lstm(pck_inputs)[0]
        hidden_all = torch.nn.utils.rnn.pad_packed_sequence(pck_hidden_all, batch_first=True)[0] # [B, S, H]

        return hidden_all

class AttnNet(torch.nn.Module):
    def __init__(self, hidden_size, drop=0.1):
        super().__init__()

        self.hidden_size = hidden_size
        self.attn_fc = torch.nn.Linear(hidden_size, hidden_size)
        self.context = torch.nn.Parameter(torch.rand(hidden_size))

        self.dropout = torch.nn.Dropout(drop)

    def forward(self, inputs, mask=None, dyn_context=None): # [B, S, H], [B, S], [B, H]

        mask = mask if mask is not None else torch.ones(inputs.size(0), inputs.size(1), device=inputs.device)
        # use static (learned) context vector if dynamic context is unavailable
        context = dyn_context if dyn_context is not None else self.context.expand(inputs.size(0), self.hidden_size) # [B, H]

        act_inputs = torch.tanh(self.dropout(self.attn_fc(inputs)))

        scores = torch.bmm(act_inputs, context.unsqueeze(2)).squeeze(2) # [B, S]
        # fill padded positions with a large negative value (not -1e-32, which is ~0)
        # so that softmax assigns them near-zero attention weight
        msk_scores = scores.masked_fill((1 - mask).bool(), -1e9)
        msk_scores = torch.nn.functional.softmax(msk_scores, dim=1)

        hidden = torch.sum(inputs * msk_scores.unsqueeze(2), dim=1) # [B, H]
        return hidden

--------------------------------------------------------------------------------
/model/model.py:
--------------------------------------------------------------------------------
import torch
from model.submodules import HierAttnNet, MetapathAggrNet, MatchNet

class LeSICiN(torch.nn.Module):
    def __init__(self, hidden_size, num_labels, node_vocab_size, edge_vocab_size, vocab_size=None, label_weights=None, pthresh=0.65, lambdas=(0.5, 0.5), thetas=(3, 2, 3), drop=0.1):
        super().__init__()

        self.text_encoder = HierAttnNet(hidden_size, vocab_size=vocab_size)
        self.graph_encoder = MetapathAggrNet(node_vocab_size, edge_vocab_size, hidden_size)
        self.match_network = MatchNet(hidden_size, num_labels)

        self.match_context_transform = torch.nn.Linear(hidden_size, hidden_size)
        self.intra_context_transform = torch.nn.Linear(hidden_size, 2 * hidden_size) # We need double the hidden size for Struct Encoder dynamic context
        self.inter_context_transform = torch.nn.Linear(hidden_size, 2 * hidden_size)

        self.criterion = torch.nn.BCEWithLogitsLoss(pos_weight=label_weights)

        self.pred_threshold = pthresh
        self.lambdas = lambdas # weights for scores
        self.thetas = thetas # weights for losses
        self.dropout = torch.nn.Dropout(drop)

    def calculate_losses(self, logits_list, labels):
        loss = 0
        for i, logits in enumerate(logits_list):
            if logits is not None:
                loss += self.thetas[i] * self.criterion(logits, labels)
        return loss

    def forward(self, fact_batch, sec_batch, pthresh=None): # We have D documents in fact_batch and C sections in sec_batch
        if pthresh is not None:
            self.pred_threshold = pthresh

        # Encode fact text using HAN
        if not fact_batch.sent_vectorized:
            fact_attr_hidden = self.text_encoder(tokens=fact_batch.tokens, mask=fact_batch.mask) # [D, H]
        else:
            fact_attr_hidden = self.text_encoder(doc_inputs=fact_batch.doc_inputs, mask=fact_batch.mask) # [D, H]

        # Encode sec text using HAN
        if not sec_batch.sent_vectorized:
            sec_attr_hidden = self.text_encoder(tokens=sec_batch.tokens, mask=sec_batch.mask) # [C, H]
        else:
            sec_attr_hidden = self.text_encoder(doc_inputs=sec_batch.doc_inputs, mask=sec_batch.mask) # [C, H]

        # context vector for matching with fact attributes
        attr_match_context = self.dropout(self.match_context_transform(fact_attr_hidden)) # [D, H]

        # Attribute scores
        attr_logits, attr_scores = self.match_network(fact_attr_hidden, sec_attr_hidden, context=attr_match_context)

        # sec-side context vectors for Struct Encoder
        sec_intra_context = self.dropout(self.intra_context_transform(sec_attr_hidden)).repeat(sec_batch.num_mpath_samples, 1) # [M*C, H]
        sec_inter_context = self.dropout(self.inter_context_transform(sec_attr_hidden)) # [C, H]

        # Encode sec graph using MAGNN
        sec_struct_hidden = self.graph_encoder(sec_batch.node_tokens, sec_batch.edge_tokens, sec_batch.schemas, intra_context=sec_intra_context, inter_context=sec_inter_context) # [C, H]

        # Alignment scores
        align_logits, align_scores = self.match_network(fact_attr_hidden, sec_struct_hidden, context=attr_match_context)

        if fact_batch.sample_metapaths:
            # fact-side context vectors for Struct Encoder
            fact_intra_context = self.dropout(self.intra_context_transform(fact_attr_hidden)).repeat(fact_batch.num_mpath_samples, 1) # [M*D, H]
            fact_inter_context = self.dropout(self.inter_context_transform(fact_attr_hidden)) # [D, H]

            # Encode fact graph using MAGNN
            fact_struct_hidden = self.graph_encoder(fact_batch.node_tokens, fact_batch.edge_tokens, fact_batch.schemas, intra_context=fact_intra_context, inter_context=fact_inter_context) # [D, H]

            # context vector for matching with fact structure
            struct_match_context = self.dropout(self.match_context_transform(fact_struct_hidden)) # [D, H]

            # Structural scores
            struct_logits, struct_scores = self.match_network(fact_struct_hidden, sec_struct_hidden, context=struct_match_context)

        else:
            struct_logits = None

        # Combine scores and losses
        scores = (self.lambdas[0] * attr_scores + self.lambdas[-1] * align_scores)
        predictions = (scores > self.pred_threshold).float()

        if fact_batch.annotated:
            loss = self.calculate_losses([attr_logits, struct_logits, align_logits], fact_batch.labels)
        else:
            loss = None

        return loss, predictions

--------------------------------------------------------------------------------
/helper.py:
--------------------------------------------------------------------------------
import torch
import torch_sparse
from tqdm import tqdm
from collections import defaultdict

from data_helper import MiniBatch

# Create word vocab (if not sent vectorized) and label vocab from the specific order in the Section dataset
def generate_vocabs(train_data, label_data, limit=30000, thresh=1):
    if not train_data.sent_vectorized:
        freqs = defaultdict(int)
        for instance in tqdm(train_data.dataset + label_data.dataset, desc="Creating vocabulary"):
            for sent in instance['text']:
                for word in sent:
                    freqs[word] += 1
        # keep the `limit` most frequent words at or above the frequency threshold;
        # index 0 is reserved for padding (MiniBatch builds masks with tokens != 0),
        # so real word indices start from 1 -- size embedding matrices as len(vocab) + 1
        kept = sorted((w for w, f in freqs.items() if thresh is None or f >= thresh), key=lambda w: -freqs[w])
        if limit is not None:
            kept = kept[:limit]
        vocab = {k: i + 1 for i, k in enumerate(kept)}
    else:
        vocab = None

    label_vocab = {}
    for instance in tqdm(label_data.dataset, desc="Creating label vocabulary"):
        label_vocab[instance['id']] = len(label_vocab)

    return vocab, label_vocab

# Create the entire graph by combining the label tree and fact-sec citation network
def generate_graph(label_vocab, type_map, label_tree_edges, cit_net_edges, label_name='section'):
    node_vocab = defaultdict(dict) # each key is a node_type, and each value is a dict storing the vocab of all nodes under given node type
    node_vocab[label_name] = label_vocab # manually set this since we want the label vocab to be consistent with node vocab for labels

    edge_vocab = {}
    edge_indices = defaultdict(list) # each key is a tuple (src node type, relationship name, trg node type), and each value is a list storing the edges from src node type to trg node type

    for (node_a, edge_type, node_b) in label_tree_edges + cit_net_edges:
        # first get the node type
        node_a_type, node_b_type = type_map[node_a], type_map[node_b]

        # create new vocab entries for edges and nodes
        if edge_type not in edge_vocab:
            edge_vocab[edge_type] = len(edge_vocab)

        if node_a not in node_vocab[node_a_type]:
            node_vocab[node_a_type][node_a] = len(node_vocab[node_a_type])
        if node_b not in node_vocab[node_b_type]:
            node_vocab[node_b_type][node_b] = len(node_vocab[node_b_type])

        # get node indices
        node_a_token = node_vocab[node_a_type][node_a]
        node_b_token = node_vocab[node_b_type][node_b]

        edge_indices[(node_a_type, edge_type, node_b_type)].append([node_a_token, node_b_token])

    num_nodes = {ntype: len(nodes) for ntype, nodes in node_vocab.items()}

    # same as edge_indices except that the edges under each key are now stored as sparse matrices
    adjacency = {}
    for keys, edges in edge_indices.items():
        row, col = torch.tensor(edges).t()
        sizes = (num_nodes[keys[0]], num_nodes[keys[-1]])
        adj = torch_sparse.SparseTensor(row=row, col=col, sparse_sizes=sizes)
        adjacency[keys] = adj

    return node_vocab, edge_vocab, edge_indices, adjacency

# create label weights for BCE Loss since we have unbalanced class distribution
def generate_label_weights(train_data, label_vocab, dev='cuda:0', scheme="tws", thresh=10.):
    pos = torch.zeros(len(label_vocab), device=dev)
    for instance in tqdm(train_data, desc="Generating label weights"):
        for l in instance['labels']:
            pos[label_vocab[l]] += 1
    weights = torch.clamp(pos.max() / pos, max=thresh) if scheme == 'tws' else len(train_data) / pos
    return weights

# Unified code to deal with a single train / dev / test / inference pass over the dataset
def train_dev_pass(model, optimizer, fact_loader, sec_batch, metrics=None, pred_threshold=None, train=False, infer=False, label_vocab=None):
    model.train() if train else model.eval()

    if infer:
        outputs = []
        inv_label_vocab = {v: k for k, v in label_vocab.items()}

    for i, fact_batch in enumerate(tqdm(fact_loader, desc="Flowing data through model")):
        torch.cuda.empty_cache()

        # disable autograd outside training to save memory
        with torch.set_grad_enabled(train):
            loss, predictions = model(fact_batch.cuda(), sec_batch.cuda(), pthresh=pred_threshold)

        if train:
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        if not infer:
            batch_loss = loss.item()
            metrics(predictions, fact_batch.labels, loss=batch_loss)

        else:
            for j, instance_preds in enumerate(predictions):
                # gather true predictions
                pred_list_indices = torch.nonzero(instance_preds, as_tuple=False).squeeze(1)
                pred_list = [inv_label_vocab[idx.item()] for idx in pred_list_indices] # .item() turns each 0-d index tensor into a plain int key
                outputs.append({'id': fact_batch.example_ids[j], 'predictions': pred_list})

    return metrics.calculate_metrics() if not infer else outputs

class MultiLabelMetrics(torch.nn.Module):
    def __init__(self, num_classes, dev='cuda', loss=True):
        super().__init__()

        self.match = torch.zeros(num_classes, device=dev) # count no. of true positives for each label
        self.predictions = torch.zeros(num_classes, device=dev) # count no. of true positives + false positives for each label
        self.labels = torch.zeros(num_classes, device=dev) # count no. of true positives + false negatives for each label
        self.run_jacc = 0 # running sum of jaccard scores
        self.counter = 0 # count no. of batches
        if loss:
            self.run_loss = 0 # running sum of losses

    # to be called with a batch of predictions and true labels
    def forward(self, predictions, labels, loss=None):
        match = predictions * labels # true positives for this batch

        # increment counts
        self.match += match.sum(dim=0)
        self.predictions += predictions.sum(dim=0)
        self.labels += labels.sum(dim=0)
        # clamp the union to at least 1 to avoid 0/0 for instances with no predicted and no true labels
        self.run_jacc += torch.sum(torch.logical_and(predictions, labels).sum(dim=1) / torch.logical_or(predictions, labels).sum(dim=1).clamp(min=1)).item()
        self.counter += 1

        if loss is not None:
            self.run_loss += loss

    # reset counters
    def refresh(self):
        self.match.fill_(0)
        self.predictions.fill_(0)
        self.labels.fill_(0)
        self.run_jacc = 0
        self.counter = 0
        if 'run_loss' in self.__dict__:
            self.run_loss = 0
        return self

    # calculate the metrics and return self
    def calculate_metrics(self, refresh=True):
        prec = self.match / self.predictions # P = TP / (TP + FP)
        rec = self.match / self.labels # R = TP / (TP + FN)

        prec[prec.isnan()] = 0
        rec[rec.isnan()] = 0

        f1 = 2 * prec * rec / (prec + rec) # F1 = 2 * P * R / (P + R)
        f1[f1.isnan()] = 0

        # macro --> average across each label
        self.macro_prec = prec.mean().item()
        self.macro_rec = rec.mean().item()
        self.macro_f1 = f1.mean().item()

        match_total = self.match.sum().item()
        preds_total = self.predictions.sum().item()
        labels_total = self.labels.sum().item()

        # micro --> take total counts
        self.micro_prec = match_total / preds_total if preds_total > 0 else 0
        self.micro_rec = match_total / labels_total if labels_total > 0 else 0
        self.micro_f1 = 0 if self.micro_prec + self.micro_rec == 0 else 2 * self.micro_prec * self.micro_rec / (self.micro_prec + self.micro_rec)

        self.jacc = self.run_jacc / self.counter
        if 'run_loss' in self.__dict__:
            self.loss = self.run_loss / self.counter

        if refresh:
            self.refresh()

        return self

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# LeSICiN
Dataset and codes for the paper "LeSICiN: A Heterogeneous Graph-based Approach for Automatic Legal Statute Identification from Indian Legal Documents", accepted and to be published at AAAI 2022.

## About
The task of **Legal Statute Identification (LSI)** aims to identify the legal statutes that are relevant to a given description of facts or evidence of a legal case.
Existing methods only utilize the textual content of facts and legal articles to guide such a task. However, the citation network among case documents and legal statutes is a rich source of additional information, which is not considered by existing models.
In this work, we take the first step towards utilising both the text and the legal citation network for the LSI task.
We curate a large novel dataset for this task, including facts of cases from several major Indian Courts of Law, and statutes from the Indian Penal Code (IPC).
Modeling the statutes and training documents as a heterogeneous graph, our proposed model **LeSICiN** can learn rich textual and graphical features, and can also tune itself to correlate these features.
Thereafter, the model can be used to inductively predict links between test documents (new nodes whose graphical features are not available to the model) and statutes (existing nodes).
Extensive experiments on the dataset show that our model comfortably outperforms several state-of-the-art baselines, by exploiting the graphical structure along with textual features.

## Citation
If you use this dataset or the codes, please refer to the following paper:
```
@inproceedings{paul2022lesicin,
  author = {Paul, Shounak and Goyal, Pawan and Ghosh, Saptarshi},
  title = {{LeSICiN: A Heterogeneous Graph-based Approach for Automatic Legal Statute Identification from Indian Legal Documents}},
  booktitle = {{Proceedings of the 36th AAAI Conference on Artificial Intelligence (AAAI)}},
  year = {2022}
}
```

## Repo Organization
```
- model
    - basicmodules.py - Contains basic building blocks -- LSTM and Attention Networks
    - submodules.py - Contains Text and Graph Encoders and Matching Network
    - model.py - Contains main module LeSICiN
- data_helper.py - Helper codes for constructing Dataset and batching logic
- helper.py - Helper codes for creating vocabularies, label weights, training loop, metrics, etc.
- run.py - Script for running training and evaluation, and/or testing, or inference
```
## Data
Datasets are in the form of .jsonl files, with one instance (dict) per line, having the following keys
```
id [string]: id/name of the fact instance / section instance
text [list[string]]: text in the form of list of sentences (each sentence is a string)
labels [list[string] / null]: gold-standard labels (not needed for inference)
```
Apart from this, we need a type_map, which will be a dict mapping node ids to their type (Act / Chapter / Topic / Section / Fact). We also need two files, storing the edges of the label tree and the Fact-Section citation network in the following format
```
(src node id, relationship name, trg node id)
```
Finally, we also need the metapath schemas for each node type. Each individual schema is a list of the edges that make up the metapath. An edge is described by the tuple
```
(src node type, relationship name, trg node type)
```
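For concreteness, a fact instance, a citation network edge and a metapath schema might look like the following. Note that all ids, relationship names and node types shown here are purely illustrative; the actual values are defined by the dataset files linked below:
```
{"id": "fact_0001", "text": ["the accused was apprehended near the scene", "stolen property was recovered from his residence"], "labels": ["sec_379", "sec_411"]}

("fact_0001", "cited", "sec_379")

[("fact", "cited", "section"), ("section", "cited_by", "fact")]
```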
You can find all these files at: https://doi.org/10.5281/zenodo.6053791

## Configs
To make it easy to configure experiments on the go, we make use of two config files stored in the configs/ folder.

*data_path.json* - Specifies the full file paths for loading data, models, etc.
```
sec_src [string]: path to sec source data file
train_src [string/null]: path to train source data file (null if not running train-dev)
dev_src [string/null]: path to dev source data file (null if not running train-dev)
test_src [string/null]: path to test source data file (null if not running test)
infer_src [string/null]: path to infer source data file (null if not running infer)

sec_cache [string]: path to sec cached data file
train_cache [string/null]: path to train cached data file (null if not running train-dev)
dev_cache [string/null]: path to dev cached data file (null if not running train-dev)
test_cache [string/null]: path to test cached data file (null if not running test)
infer_cache [string/null]: path to infer cached data file (null if not running infer)

s2v_path [string/null]: path to pretrained sent2vec file (null if you do not want to sent vectorize)

type_map [string]: path to file that maps id of each node to its type
label_tree [string]: path to file that stores edges of the label tree
citation_network [string]: path to file that stores edges of the Fact-Section citation net
schemas [string]: path to file that stores the schemas

model_load [string/null]: checkpointed model to load from (null to train from scratch)
metrics_load [string/null]: saved validation metrics to act as benchmark (null to train from scratch)

model_dump [string]: path to file where trained model will be saved
dev_metrics_dump [string]: path to file where best validation metrics will be saved
test_metrics_dump [string]: path to file where test metrics will be saved
infer_trg [string/null]: path to file where inference predictions will be saved (null if not running infer)
```

*hyperparams.json* - Controls the model and experiment hyperparameters and a few other settings, like the seed.
```
seed [int]: universal seed for random, numpy and torch
do_train_dev [bool]: true if running train-dev
do_test [bool]: true if running test
do_infer [bool]: true if running infer
vocab_limit [int/null]: maximum vocabulary size [null if using sent vectorization]
vocab_thresh [int/null]: minimum frequency for a word to be considered in vocabulary [null if using sent vectorization]
weight_scheme {"tws", "vws"}: choose between Threshold-based Weighting Scheme and Vanilla Weighting Scheme as discussed in the paper
tws_thresh [float/null]: TWS threshold as discussed in the paper (null if using VWS)

train_bs [int]: batch size (no. of facts) for training
dev_bs [int]: batch size (no. of facts) for validation
test_bs [int]: batch size (no. of facts) for testing
infer_bs [int]: batch size (no. of facts) for inference
max_segments [int]: maximum no. of sentences per document (fact or section)
max_segment_size [int/null]: maximum no. of words per sentence (null if using sent vectorization)
num_mpath_samples [int]: no. of metapath instances to sample per metapath schema

hidden_size [int]: hidden dimension for all intermediate layers (if using sent vectorization, make sure this is equal to the dimension of the sent2vec embeddings)

opt_lr [float]: learning rate for optimizer
opt_wt_decay [float]: optimizer weight decay
sch_factor [float]: factor for the ReduceLROnPlateau scheduler
sch_patience [int]: patience for the ReduceLROnPlateau scheduler
num_epochs [int]: no. of training epochs

pthresh [float]: prediction threshold to be used by model
thetas [tuple[float]]: thetas in the order (attr, struct, align)
lambdas [tuple[float]]: lambdas in the order (attr, align)
dropout [float]: dropout factor for model layers
```

## Running the Script
All kinds of operations (train/dev/test/infer) can be performed by the run.py script, by appropriately configuring its settings. See the sections above to understand the different settings. You can run the script using:
```
python run.py
```
## Outputs
For train/dev and test runs, a metrics object is saved at the path specified by the dev / test metrics dump key in the data path config, which contains the following scores:
```
- macro
    - precision
    - recall
    - f1
- micro
    - precision
    - recall
    - f1
- jaccard
```
If training is performed, the model state corresponding to the best dev loss is also saved at the path specified by the model dump key in the data path config.
During inference, instead of metrics, a jsonl file is saved at the path specified by the infer trg key in the data path config. Each line in the jsonl file is a dict with the following keys:
```
id [string]: id of the fact instance
predictions [list[string]]: model predictions for this instance
```
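The metrics dumps are pickled `MultiLabelMetrics` objects (defined in helper.py), written with `pickle` even though the default paths end in `.json`. A minimal sketch of reading one back, assuming the default dump paths, that helper.py is importable from the working directory, and a CUDA-capable machine (the object's internal count tensors live on the GPU):
```
import pickle as pkl
from helper import MultiLabelMetrics  # makes the class resolvable for unpickling

# despite the .json extension, run.py pickles the whole metrics object
with open("saved/test_metrics.json", "rb") as fr:
    metrics = pkl.load(fr)

print("macro P/R/F1:", metrics.macro_prec, metrics.macro_rec, metrics.macro_f1)
print("micro P/R/F1:", metrics.micro_prec, metrics.micro_rec, metrics.micro_f1)
print("Jaccard:", metrics.jacc)
```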

--------------------------------------------------------------------------------
/model/submodules.py:
--------------------------------------------------------------------------------
import torch
from model.basicmodules import LstmNet, AttnNet

class HierAttnNet(torch.nn.Module):
    def __init__(self, hidden_size, vocab_size=None, drop=0.1):
        super().__init__()

        if vocab_size is not None:
            self.word_embedding = torch.nn.Embedding(vocab_size, hidden_size)
            self.sent_lstm = LstmNet(hidden_size)
            self.sent_attn = AttnNet(hidden_size, drop=drop)

        self.doc_lstm = LstmNet(hidden_size)
        self.doc_attn = AttnNet(hidden_size, drop=drop)

    def forward(self, tokens=None, doc_inputs=None, mask=None, sent_dyn_context=None, doc_dyn_context=None): # [B, S, W], [B, S, H], [B, S, W] / [B, S], [B, S, H], [B, H]
        if tokens is not None:
            sent_inputs = self.word_embedding(tokens)

            # flatten to 3-D
            sent_inputs = sent_inputs.view(-1, sent_inputs.size(2), sent_inputs.size(3))
            sent_mask = mask.view(-1, mask.size(2))

            if sent_dyn_context is not None:
                sent_dyn_context = sent_dyn_context.view(-1, sent_dyn_context.size(2))

            sent_hidden_all = self.sent_lstm(sent_inputs, sent_mask)
            sent_hidden = self.sent_attn(sent_hidden_all, sent_mask, dyn_context=sent_dyn_context)

            doc_inputs = sent_hidden.view(tokens.size(0), tokens.size(1), -1)
            doc_mask = (mask.sum(dim=2) > 0).float()
        else:
            doc_mask = mask

        doc_hidden_all = self.doc_lstm(doc_inputs, doc_mask)
        doc_hidden = self.doc_attn(doc_hidden_all, doc_mask, dyn_context=doc_dyn_context)
        return doc_hidden

class MetapathAggrNet(torch.nn.Module):
    def __init__(self, node_vocab_size, edge_vocab_size, hidden_size, drop=0.1, gdel=14.):
        super().__init__()
        self.emb_range = gdel / hidden_size

        self.node_embedding = torch.nn.ModuleDict({ntype: torch.nn.Embedding(num_nodes, hidden_size) for ntype, num_nodes in node_vocab_size.items()})
        for ntype, ntype_embedding in self.node_embedding.items():
            ntype_embedding.weight.data.uniform_(- self.emb_range, self.emb_range)

        self.scale_fc = torch.nn.ModuleDict({ntype: torch.nn.Linear(hidden_size, hidden_size) for ntype in node_vocab_size})

        self.edge_embedding = torch.nn.Embedding(edge_vocab_size, hidden_size // 2)
        self.edge_embedding.weight.data.uniform_(- self.emb_range, self.emb_range)

        self.intra_attention = AttnNet(2 * hidden_size, drop=drop)

        self.inter_fc = torch.nn.Linear(2 * hidden_size, 2 * hidden_size)
        self.inter_context = torch.nn.Parameter(torch.rand(2 * hidden_size))

        self.output_fc = torch.nn.Linear(2 * hidden_size, hidden_size)

        self.dropout = torch.nn.Dropout(drop)

    # Embed each node index using the node embedding matrix and then scale to generate same sized embeddings for each node type
    def embed_and_scale(self, tokens, edge_tokens, schema): # [B, L+1], [B, L]
        inputs, edge_inputs = [], []

        node_type = schema[0][0]
        node_input = self.dropout(self.node_embedding[node_type](tokens[:, 0])) # [B, H]
        inputs.append(self.dropout(self.scale_fc[node_type](node_input)))

        for i in range(edge_tokens.size(1)):
            node_type = schema[i][2]
            node_input = self.dropout(self.node_embedding[node_type](tokens[:, i+1])) # [B, H]
            inputs.append(self.dropout(self.scale_fc[node_type](node_input)))

            edge_inputs.append(self.dropout(self.edge_embedding(edge_tokens[:, i])))
        inputs = torch.stack(inputs, dim=1) # [B, L+1, H]
        edge_inputs = torch.stack(edge_inputs, dim=1) # [B, L, H/2]
        return inputs, edge_inputs

    # We are following the official implementation of the RotatE algorithm --- https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding
    def rotational_encoding(self, inputs, edge_inputs): # [B, L+1, H], [B, L, H/2]
        PI = 3.14159265358979323846
        hidden = inputs.clone()
        for i in reversed(range(edge_inputs.size(1))):
            hid_real, hid_imag = torch.chunk(hidden.clone()[:, i+1:, :], 2, dim=2) # [B, L-i, H/2], [B, L-i, H/2]
            inp_real, inp_imag = torch.chunk(inputs[:, i, :], 2, dim=1) # [B, H/2], [B, H/2]

            # map edge embeddings to phases in [-pi, pi] as in RotatE; the phases,
            # not the raw embeddings, are fed to cos/sin
            edge_phase = edge_inputs[:, i, :] / (self.emb_range / PI)
            edge_real, edge_imag = torch.cos(edge_phase), torch.sin(edge_phase) # [B, H/2], [B, H/2]

            out_real = inp_real.unsqueeze(1) + edge_real.unsqueeze(1) * hid_real - edge_imag.unsqueeze(1) * hid_imag # [B, L-i, H/2]
            out_imag = inp_imag.unsqueeze(1) + edge_imag.unsqueeze(1) * hid_real + edge_real.unsqueeze(1) * hid_imag # [B, L-i, H/2]

            hidden[:, i+1:, :] = torch.cat([out_real, out_imag], dim=2)
        path_lens = 1 + torch.arange(hidden.size(1), device=hidden.device) # [L+1]
        return hidden / path_lens.unsqueeze(0).unsqueeze(2)

    def forward(self, tokens, edge_tokens, schemas, intra_context=None, inter_context=None):
        hidden = []

        # serially perform intra-metapath aggregation across the different schemas
        for i in range(len(tokens)):
            # flatten out the multiple samples of the same schema
            mpath_tokens = tokens[i].view(-1, tokens[i].size(2)) # [M*D, L+1]
            mpath_edge_tokens = edge_tokens[i].view(-1, edge_tokens[i].size(2)) # [M*D, L]

            mpath_inputs, mpath_edge_inputs = self.embed_and_scale(mpath_tokens, mpath_edge_tokens, schemas[i])

            mpath_hidden_all = self.rotational_encoding(mpath_inputs, mpath_edge_inputs) # [M*D, L+1, H]

            # the first element in the sequence is the target node, the rest are transformed embeddings for other nodes in the metapath
            mpath_hidden_all = torch.cat([mpath_hidden_all[:, 0, :].unsqueeze(1).repeat(1, mpath_hidden_all.size(1) - 1, 1), mpath_hidden_all[:, 1:, :]], dim=2) # [M*D, L, 2H]
            mpath_hidden = torch.relu(self.intra_attention(mpath_hidden_all, dyn_context=intra_context)) # [M*D, 2H]

            # aggregate transformed embeddings from multiple samples of the same schema
            mpath_hidden = torch.sum(mpath_hidden.view(tokens[i].size(0), tokens[i].size(1), -1), dim=0) # [D, 2H]
            hidden.append(mpath_hidden)
        hidden = torch.stack(hidden, dim=1) # [D, N, 2H]

        # perform inter-metapath aggregation across transformed embeddings for each schema
        hidden_act = torch.mean(torch.tanh(self.dropout(self.inter_fc(hidden))), dim=0).expand_as(hidden) # [D, N, 2H]
        context = self.inter_context.unsqueeze(0).repeat(hidden_act.size(0), 1).unsqueeze(2) if inter_context is None else inter_context.unsqueeze(2)
        scores = torch.bmm(hidden_act, context) # [D, N, 1]

        outputs = torch.sum(hidden * scores, dim=1) # [D, 2H]
        outputs = self.dropout(self.output_fc(outputs)) # [D, H]

        return outputs

class MatchNet(torch.nn.Module):
    def __init__(self, hidden_size, num_labels, drop=0.1):
        super().__init__()

        self.match_lstm = LstmNet(hidden_size)
        self.match_attn = AttnNet(hidden_size, drop=drop)
        self.match_fc = torch.nn.Linear(2 * hidden_size, num_labels)

        self.dropout = torch.nn.Dropout(drop)

    def forward(self, fact_inputs, sec_inputs, context=None): # [D, H], [C, H]
        sec_inputs = sec_inputs.expand(fact_inputs.size(0), sec_inputs.size(0), sec_inputs.size(1)) # [D, C, H]

        sec_hidden_all = self.match_lstm(sec_inputs) # [D, C, H]
        sec_hidden = self.match_attn(sec_hidden_all, dyn_context=context) # [D, H]

        logits = self.dropout(self.match_fc(torch.cat([fact_inputs, sec_hidden], dim=1))) # [D, C]
        scores = torch.sigmoid(logits).detach() # [D, C]
        return logits, scores

--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
print("\nPreparing PyTorch environment")
print("==================================================")

import torch
import random
from tqdm import tqdm
from functools import partial
import sent2vec

from model.model import *
from data_helper import *   # also brings in json, numpy as np and pickle as pkl
from helper import *


torch.autograd.set_detect_anomaly(True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

with open("configs/data_path.json") as fr:
    dc = json.load(fr)
with open("configs/hyperparams.json") as fr:
    hc = json.load(fr)

SEED = hc['seed']

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

if dc['s2v_path'] is not None:
    sent2vec_model = sent2vec.Sent2vecModel()
    sent2vec_model.load_model(dc['s2v_path'])

print("\nPreparing Datasets")
print("==================================================")
# NOTE: the pipeline below assumes sent2vec embeddings, so s2v_path must be set
sec_dataset = LSIDataset(jsonl_file=dc['sec_src'])
sec_dataset.preprocess()
sec_dataset.sent_vectorize(sent2vec_model)
sec_dataset.save_data(dc['sec_cache'])

# sec_dataset = LSIDataset.load_data(dc['sec_cache'])

if hc['do_train_dev']:
    train_dataset = LSIDataset(jsonl_file=dc['train_src'])
    train_dataset.preprocess()
    train_dataset.sent_vectorize(sent2vec_model)
    train_dataset.save_data(dc['train_cache'])

    # train_dataset = LSIDataset.load_data(dc['train_cache'])

    dev_dataset = LSIDataset(jsonl_file=dc['dev_src'])
    dev_dataset.preprocess()
    dev_dataset.sent_vectorize(sent2vec_model)
    dev_dataset.save_data(dc['dev_cache'])

    # dev_dataset = LSIDataset.load_data(dc['dev_cache'])

if hc['do_test']:
    test_dataset = LSIDataset(jsonl_file=dc['test_src'])
    test_dataset.preprocess()
    test_dataset.sent_vectorize(sent2vec_model)
    test_dataset.save_data(dc['test_cache'])

    # test_dataset = LSIDataset.load_data(dc['test_cache'])

if hc['do_infer']:
    infer_dataset = LSIDataset(jsonl_file=dc['infer_src'])
    infer_dataset.preprocess()
    infer_dataset.sent_vectorize(sent2vec_model)
    infer_dataset.save_data(dc['infer_cache'])

    # infer_dataset = LSIDataset.load_data(dc['infer_cache'])

print("\nGathering other data")
print("==================================================")
# NOTE: vocabularies and label weights are derived from the training data, so the
# train dataset must be available (or loaded from cache) even for test/infer-only runs
vocab, label_vocab = generate_vocabs(train_dataset, sec_dataset, limit=hc['vocab_limit'], thresh=hc['vocab_thresh'])
with open(dc['type_map']) as fr:
    type_map = json.load(fr)
with open(dc['label_tree']) as fr:
    label_tree = json.load(fr)
with open(dc['citation_network']) as fr:
    citation_net = json.load(fr)
with open(dc['schemas']) as fr:
    schemas = json.load(fr)
for sch in schemas.values():
    for path in sch:
        for i, edge in enumerate(path):
            path[i] = tuple(edge)

node_vocab, edge_vocab, edge_indices, adjacency = generate_graph(label_vocab, type_map, label_tree, citation_net)
sec_weights = generate_label_weights(train_dataset, label_vocab)

L = len(label_vocab)
N = {k: len(v) for k, v in node_vocab.items()}
E = len(edge_vocab)

sec_loader = torch.utils.data.DataLoader(
    sec_dataset,
    batch_size=len(label_vocab),    # a single batch holding every section
    collate_fn=partial(
        collate_func,
        schemas=schemas['section'],
        type_map=type_map,
        node_vocab=node_vocab,
        edge_vocab=edge_vocab,
        adjacency=adjacency,
        max_segments=hc['max_segments'],
        max_segment_size=hc['max_segment_size'],
        num_mpath_samples=hc['num_mpath_samples']
    ),
    pin_memory=True,
    num_workers=4
)

if hc['do_train_dev']:
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=hc['train_bs'],
        collate_fn=partial(
            collate_func,
            label_vocab=label_vocab,
            schemas=schemas['fact'],
            type_map=type_map,
            node_vocab=node_vocab,
            edge_vocab=edge_vocab,
            adjacency=adjacency,
            max_segments=hc['max_segments'],
            max_segment_size=hc['max_segment_size'],
            num_mpath_samples=hc['num_mpath_samples']
        ),
        pin_memory=True,
        num_workers=4
    )

    dev_loader = torch.utils.data.DataLoader(
        dev_dataset,
        batch_size=hc['dev_bs'],
        collate_fn=partial(
            collate_func,
            label_vocab=label_vocab,
            max_segments=hc['max_segments'],
            max_segment_size=hc['max_segment_size']
        ),
        pin_memory=True,
        num_workers=4
    )
if hc['do_test']:
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=hc['test_bs'],
        collate_fn=partial(
            collate_func,
            label_vocab=label_vocab,
            max_segments=hc['max_segments'],
            max_segment_size=hc['max_segment_size']
        ),
        pin_memory=True,
        num_workers=4
    )

if hc['do_infer']:
    infer_loader = torch.utils.data.DataLoader(
        infer_dataset,
        batch_size=hc['infer_bs'],
        collate_fn=partial(
            collate_func,
            # no label_vocab here: inference data is unlabeled, and passing it would
            # make MiniBatch look for a 'labels' key that does not exist
            max_segments=hc['max_segments'],
            max_segment_size=hc['max_segment_size']
        ),
        pin_memory=True,
        num_workers=4
    )

sec_batch = next(iter(sec_loader))    # materialize the single all-sections batch once

print("\nPreparing Model")
print("==================================================")
lsc_model = LeSICiN(
    hc['hidden_size'],
    L,
    N,
    E,
    label_weights=sec_weights,
    lambdas=hc['lambdas'],
    thetas=hc['thetas'],
    pthresh=hc['pthresh'],
    drop=hc['dropout']
).cuda()

if dc['model_load'] is not None:
    lsc_model.load_state_dict(torch.load(dc['model_load'], map_location='cuda'))


optimizer = None    # only needed for training; train_dev_pass ignores it otherwise

if hc['do_train_dev']:
    if dc['metrics_load'] is not None:
        with open(dc['metrics_load'], 'rb') as fr:
            best_metrics = pkl.load(fr)
        best_loss = best_metrics.loss
    else:
        best_loss = float('inf')

    best_model = lsc_model.state_dict()


    optimizer = torch.optim.AdamW(lsc_model.parameters(), lr=hc['opt_lr'], weight_decay=hc['opt_wt_decay'])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=hc['sch_patience'], factor=hc['sch_factor'])
    train_mlmetrics = MultiLabelMetrics(L)
    dev_mlmetrics = MultiLabelMetrics(L)

    print("\nRunning Train/Dev")
    print("==================================================")
    for epoch in range(hc['num_epochs']):
        train_mlmetrics = train_dev_pass(lsc_model, optimizer, train_loader, sec_batch, metrics=train_mlmetrics, train=True, pred_threshold=hc['pthresh'])
        dev_mlmetrics = train_dev_pass(lsc_model, optimizer, dev_loader, sec_batch, metrics=dev_mlmetrics, pred_threshold=hc['pthresh'])

        train_loss, dev_loss = train_mlmetrics.loss, dev_mlmetrics.loss

        if dev_loss < best_loss:
            best_loss = dev_loss
            best_metrics = dev_mlmetrics
            best_model = lsc_model.state_dict()

        scheduler.step(dev_loss)

        print("%5d || %.4f | %.4f || %.4f | %.4f %.4f %.4f" % (epoch, train_loss, train_mlmetrics.macro_f1, dev_loss, dev_mlmetrics.macro_prec, dev_mlmetrics.macro_rec, dev_mlmetrics.macro_f1))

    print("\nCollecting outputs")
    print("==================================================")
    torch.save(best_model, dc['model_dump'])
    with open(dc['dev_metrics_dump'], 'wb') as fw:
        pkl.dump(best_metrics, fw)

    if hc['do_test']:
        lsc_model.load_state_dict(best_model)

    print("VALIDATION Results || %.4f | %.4f %.4f %.4f" % (best_loss, best_metrics.macro_prec, best_metrics.macro_rec, best_metrics.macro_f1))

if hc['do_test']:
    test_mlmetrics = MultiLabelMetrics(L)
    print("\nRunning Test")
    print("==================================================")
    test_mlmetrics = train_dev_pass(lsc_model, optimizer, test_loader, sec_batch, metrics=test_mlmetrics, pred_threshold=hc['pthresh'])
    with open(dc['test_metrics_dump'], 'wb') as fw:
        pkl.dump(test_mlmetrics, fw)
    print("TEST Results || %.4f | %.4f %.4f %.4f" % (test_mlmetrics.loss, test_mlmetrics.macro_prec, test_mlmetrics.macro_rec, test_mlmetrics.macro_f1))

if hc['do_infer']:
    print("\nRunning Inference")
    print("==================================================")
    infer_outputs = train_dev_pass(lsc_model, optimizer, infer_loader, sec_batch, infer=True, pred_threshold=hc['pthresh'], label_vocab=label_vocab)
    with open(dc['infer_trg'], 'w') as fw:
        fw.write('\n'.join([json.dumps(doc) for doc in infer_outputs]))


--------------------------------------------------------------------------------
/data_helper.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import string
import copy
import multiprocessing as mp
from tqdm import tqdm
import json
import pickle as pkl

class LSIDataset(torch.utils.data.Dataset):
    def __init__(self, jsonl_file=None, data_list=None):
        super().__init__()

        self.annotated = False
        self.sent_vectorized = False

        if data_list is not None:
            self.dataset = copy.deepcopy(data_list)
            for instance in tqdm(self.dataset, desc="Loading data from list"):
                if 'labels' in instance:
                    self.annotated = True
                    instance['labels'] = np.array(instance['labels'])

        elif jsonl_file is not None:
            self.dataset = []
            with open(jsonl_file) as fr:
                for line in tqdm(fr, desc="Loading data from file"):
                    doc = json.loads(line)
                    text = np.array([sent for sent in doc['text']])
                    newdoc = {'id': doc['id'], 'text': text}
                    if 'labels' in doc:
                        self.annotated = True
                        labels = np.array(doc['labels'])
                        newdoc['labels'] = labels
                    self.dataset.append(newdoc)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        return self.dataset[index]

    def save_data(self, data_file):
        with open(data_file, 'wb') as fw:
            pkl.dump(self, fw)

    @staticmethod
    def load_data(data_file):
        with open(data_file, 'rb') as fr:
            return pkl.load(fr)

    # remove punctuation and empty sentences
    def preprocess(self):
        for i, instance in enumerate(tqdm(self.dataset, desc="Preprocessing")):
            text = []
            for j, sent in enumerate(instance['text']):
                ppsent = sent.strip().lower().translate(str.maketrans('', '', string.punctuation))
                if len(ppsent.split()) > 1:
                    text.append(ppsent)
            instance['text'] = np.array(text)

    # break each sentence string into word tokens
    def tokenize(self):
        for i, instance in enumerate(tqdm(self.dataset, desc="Tokenizing")):
            text = []
            for j, sent in enumerate(instance['text']):
                toksent = np.array(sent.strip().split())
                text.append(toksent)
            instance['text'] = np.array(text, dtype=object)

    # generate a vector for each sentence using Sent2Vec
    def sent_vectorize(self, sent2vec_model):
        for i, instance in enumerate(tqdm(self.dataset, desc="Embedding sentences")):
            esents = sent2vec_model.embed_sentences(instance['text'])
            # drop all-zero rows (sentences with no words known to the sent2vec model)
            instance['text'] = np.delete(esents, np.where(esents.sum(axis=1) == 0)[0], axis=0)
        self.sent_vectorized = True


# unified code for generating mini batches of data for both facts and sections during train / dev / test / inference
class MiniBatch:
    def __init__(self, examples, vocab=None, label_vocab=None, schemas=None, type_map=None, node_vocab=None, edge_vocab=None, adjacency=None, hidden_size=200, max_segments=4, max_segment_size=8, num_mpath_samples=2):
        # provide vocab if not sent vectorized, None otherwise
        self.sent_vectorized = True if vocab is None else False
        # provide label_vocab if annotated, None otherwise
        self.annotated = True if label_vocab is not None else False
        # provide graph data if struct encoder is to be used on these examples, None otherwise
        self.sample_metapaths = True if schemas is not None else False

        self.max_segments = max_segments

        if not self.sent_vectorized:
            self.vocab = vocab
            self.max_segment_size = max_segment_size
        else:
            self.sent_hidden_size = hidden_size

        if self.annotated:
            self.label_vocab = label_vocab

        if self.sample_metapaths:
            self.schemas = schemas
            self.type_map = type_map
            self.node_vocab = node_vocab
            self.edge_vocab = edge_vocab
            self.adjacency = adjacency
            self.num_mpath_samples = num_mpath_samples

        max_len = max([len(d['text']) for d in examples])
        max_segments = min(self.max_segments, max_len)

        # expected shape of text tensors
        if not self.sent_vectorized:
            max_segment_size = min(self.max_segment_size, max([len(s) for d in examples for s in d['text']]))
            self.tokens = torch.zeros(len(examples), max_segments, max_segment_size, dtype=torch.long) # [D, S, W]
        else:
            self.doc_inputs = torch.zeros(len(examples), max_segments, self.sent_hidden_size) # [D, S, H]

        self.example_ids = []

        if self.annotated:
            # expected shape of true label indicator tensors
            self.labels = torch.zeros(len(examples), len(self.label_vocab)) # [D, C]

        for i, instance in enumerate(examples):
            if not self.sent_vectorized:
                # truncate to at most max_segments sentences of at most max_segment_size words each
                for j, sent in enumerate(instance['text'][:max_segments]):
                    sent = sent[:max_segment_size]
                    # fill up the j-th sentence of i-th example with word tokens
                    self.tokens[i, j, :len(sent)] = torch.from_numpy(np.array([self.vocab[w] for w in sent]))
            else:
                # fill up the i-th example with sentence embeddings, truncated to max_segments
                num_sents = min(len(instance['text']), max_segments)
                self.doc_inputs[i, :num_sents, :] = torch.from_numpy(instance['text'][:num_sents])

            self.example_ids.append(instance['id'])

            if self.annotated:
                label_list = torch.from_numpy(np.array([self.label_vocab[l] for l in instance['labels']]))
                self.labels[i].scatter_(0, label_list, 1.)

        if not self.sent_vectorized:
            self.mask = (self.tokens != 0).float() # [D, S, W]
        else:
            self.mask = (self.doc_inputs != 0).any(dim=2).float() # [D, S]

        if self.sample_metapaths:
            trg_node_tokens = torch.tensor([self.node_vocab[self.type_map[x]][x] for x in self.example_ids])
            self.node_tokens, self.edge_tokens = self.generate_metapaths(trg_node_tokens, self.schemas, self.adjacency, self.edge_vocab, num_samples=self.num_mpath_samples) # N * [M, D, L+1], N * [M, D, L]

    # sample metapaths using adjacency matrices
    def generate_metapaths(self, indices, schemas, adjacency, edge_vocab, num_samples=2): # [D,]
        indices = indices.repeat(num_samples) # [M*D,]

        tokens, edge_tokens = [], []

        # repeat over all schemas
        for i in range(len(schemas)):
            ins_tokens, ins_edge_tokens = [indices], []
            for keys in schemas[i]:
                neighbours = adjacency[keys].sample(num_neighbors=1, subset=ins_tokens[-1]).squeeze(1) # [M*D,]
                relations = torch.full(neighbours.shape, edge_vocab[keys[1]], dtype=torch.long) # [M*D,]

                ins_tokens.append(neighbours)
                ins_edge_tokens.append(relations)

            ins_tokens = torch.stack(ins_tokens, dim=1)
            ins_tokens = ins_tokens.view(num_samples, -1, ins_tokens.size(1)) # [M, D, L+1]

            ins_edge_tokens = torch.stack(ins_edge_tokens, dim=1)
            ins_edge_tokens = ins_edge_tokens.view(num_samples, -1, ins_edge_tokens.size(1)) # [M, D, L]

            tokens.append(ins_tokens)
            edge_tokens.append(ins_edge_tokens)

        return tokens, edge_tokens

    # automatic memory pinning for faster cpu to cuda transfer
    def pin_memory(self):
        # Tensor.pin_memory() is not in-place; it returns a pinned copy that must be assigned back
        if not self.sent_vectorized:
            self.tokens = self.tokens.pin_memory()
        else:
            self.doc_inputs = self.doc_inputs.pin_memory()
        self.mask = self.mask.pin_memory()
        if self.annotated:
            self.labels = self.labels.pin_memory()
        if self.sample_metapaths:
            for i in range(len(self.node_tokens)):
                self.node_tokens[i] = self.node_tokens[i].pin_memory()
                self.edge_tokens[i] = self.edge_tokens[i].pin_memory()
        return self

    # transfer pinned cpu tensors to cuda
    def cuda(self, dev='cuda'):
        if not self.sent_vectorized:
            self.tokens = self.tokens.cuda(dev, non_blocking=True)
        else:
            self.doc_inputs = self.doc_inputs.cuda(dev, non_blocking=True)
        self.mask = self.mask.cuda(dev, non_blocking=True)
        if self.annotated:
            self.labels = self.labels.cuda(dev, non_blocking=True)
        if self.sample_metapaths:
            for i in range(len(self.node_tokens)):
                self.node_tokens[i] = self.node_tokens[i].cuda(dev, non_blocking=True)
                self.edge_tokens[i] = self.edge_tokens[i].cuda(dev, non_blocking=True)
        return self

def collate_func(examples, **kwargs):
    return MiniBatch(examples, **kwargs)
--------------------------------------------------------------------------------