├── LICENSE ├── README.md ├── config.yaml ├── grapher_fig.png ├── model.py ├── modules │   ├── base.py │   ├── data_processor.py │   ├── evaluator.py │   ├── filtering.py │   ├── layers.py │   ├── loss_functions.py │   ├── rel_rep.py │   ├── scorer.py │   ├── span_rep.py │   ├── token_rep.py │   ├── token_splitter.py │   └── utils.py └── train.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Urchade Zaratiana 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # *GraphER*: A Structure-aware Text-to-Graph Model for Entity and Relation Extraction 2 | 3 | ![GraphER model overview](grapher_fig.png) 4 | 5 | ### TO DO: 6 | - [ ] Implement negative sampling (both for entities and relations) 7 | - [ ] Add evaluation computation 8 | - [ ] Ablation study (scoring, ...)
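### Quick start (sketch)

A minimal usage sketch. This is hypothetical: the API is inferred from `model.py`, `modules/base.py` and `modules/data_processor.py`, and the checkpoint path, label names and sample data are placeholders.

```python
import torch
from model import GraphER

# GrapherBase mixes in PyTorchModelHubMixin, so a trained checkpoint
# (a local directory or a Hub repo id) can be loaded with from_pretrained.
model = GraphER.from_pretrained("path/to/checkpoint")
model.eval()

# One sample in the format consumed by GrapherData.collate_fn:
# entities are token-level spans (start, end, type); relations refer to
# entity indices (head_entity_idx, tail_entity_idx, type).
data = [{
    "tokenized_text": ["Barack", "Obama", "was", "born", "in", "Hawaii", "."],
    "entities": [(0, 1, "person"), (5, 5, "location")],
    "relations": [(0, 1, "born in")],
}]

loader = model.create_dataloader(
    data,
    entity_types=["person", "location"],
    relation_types=["born in"],
    batch_size=1,
    shuffle=False,
)

batch = next(iter(loader))
with torch.no_grad():
    entities, relations = model.predict(batch, threshold=0.5)
```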
9 | 10 | For now, you can play with the beta version in colab: [Open in Colab](https://colab.research.google.com/drive/1IinAMCtUotntrtoP9zNutriZJtJ4Hymd?usp=sharing) 11 | 12 | ```bibtex 13 | @misc{zaratiana2024grapher, 14 | title={GraphER: A Structure-aware Text-to-Graph Model for Entity and Relation Extraction}, 15 | author={Urchade Zaratiana and Nadi Tomeh and Niama El Khbir and Pierre Holat and Thierry Charnois}, 16 | year={2024}, 17 | eprint={2404.12491}, 18 | archivePrefix={arXiv}, 19 | primaryClass={cs.CL} 20 | } 21 | ``` 22 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | # Model Configuration 2 | model_name: microsoft/deberta-v3-base # Hugging Face model 3 | name: "grapher" 4 | max_width: 12 5 | hidden_size: 768 6 | dropout: 0.1 7 | fine_tune: true 8 | subtoken_pooling: first 9 | span_mode: markerV0 10 | num_heads: 4 11 | num_transformer_layers: 2 12 | ffn_mul: 4 13 | scorer: "dot" 14 | 15 | # Training Parameters 16 | num_steps: 30000 17 | train_batch_size: 8 18 | eval_every: 5000 19 | warmup_ratio: 0.1 20 | scheduler_type: "cosine" 21 | 22 | # Learning rate and weight decay configuration 23 | lr_encoder: 1e-5 24 | lr_others: 5e-5 25 | 26 | # Directory Paths 27 | root_dir: grapher_logs 28 | train_data: "data/rel_news_b.json" 29 | val_data_dir: "data/NER_datasets" 30 | 31 | # Pretrained Model Path 32 | # Use "none" if no pretrained model is being used 33 | prev_path: "none" 34 | 35 | # Advanced Training Settings 36 | size_sup: -1 37 | max_types: 25 38 | max_len: 384 39 | freeze_token_rep: false 40 | save_total_limit: 20 41 | max_top_k: 54 42 | add_top_k: 10 43 | shuffle_types: true 44 | 45 | random_drop: true 46 | max_neg_type_ratio: 2 47 | max_ent_types: 20 48 | max_rel_types: 20 -------------------------------------------------------------------------------- /grapher_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/urchade/GraphER/23cc4edc4fb04cebd40b99ec02e55700ca8252dd/grapher_fig.png -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | from modules.base import GrapherBase 6 | from modules.data_processor import TokenPromptProcessorTR 7 | from modules.filtering import FilteringLayer 8 | from modules.layers import MLP, LstmSeq2SeqEncoder, TransLayer, GraphEmbedder 9 | from modules.loss_functions import compute_matching_loss 10 | from modules.rel_rep import RelationRep 11 | from modules.scorer import ScorerLayer 12 | from modules.span_rep import SpanRepLayer 13 | from modules.token_rep import TokenRepLayer 14 | from modules.utils import get_ground_truth_relations, get_candidates 15 | 16 | 17 | class GraphER(GrapherBase): 18 | def __init__(self, config): 19 | super().__init__(config) 20 | 21 | self.config = config 22 | 23 | # special prompt tokens, added to the encoder vocabulary below 24 | self.ent_token = "<<ENT>>" 25 | self.rel_token = "<<REL>>" 26 | self.sep_token = "<<SEP>>" 27 | 28 | # usually a pretrained bidirectional transformer, returns first subtoken representation 29 | self.token_rep_layer = TokenRepLayer(model_name=config.model_name, fine_tune=config.fine_tune, 30 | subtoken_pooling=config.subtoken_pooling, hidden_size=config.hidden_size, 31 | add_tokens=[self.ent_token, self.rel_token, self.sep_token]) 32 | 33 | # token prompt processor
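# the processor builds a prompt "<<ENT>> type <<SEP>> ... <<REL>> type <<SEP>>", prepends it
# to the input tokens, and later splits the encoder output back into word representations
# and entity/relation type representations (see modules/data_processor.py)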
34 | self.token_prompt_processor = TokenPromptProcessorTR(self.ent_token, self.rel_token, self.sep_token) 35 | 36 | # hierarchical representation of tokens (Zaratiana et al., 2022) 37 | # https://arxiv.org/pdf/2203.14710.pdf 38 | self.rnn = LstmSeq2SeqEncoder( 39 | input_size=config.hidden_size, 40 | hidden_size=config.hidden_size // 2, 41 | num_layers=1, 42 | bidirectional=True 43 | ) 44 | 45 | # span representation 46 | self.span_rep_layer = SpanRepLayer( 47 | span_mode=config.span_mode, 48 | hidden_size=config.hidden_size, 49 | max_width=config.max_width, 50 | dropout=config.dropout 51 | ) 52 | 53 | # prompt representation (linear projections) 54 | self.ent_rep_layer = nn.Linear(config.hidden_size, config.hidden_size) 55 | self.rel_rep_layer = nn.Linear(config.hidden_size, config.hidden_size) 56 | 57 | # filtering layer for spans and relations 58 | self._span_filtering = FilteringLayer(config.hidden_size) 59 | self._rel_filtering = FilteringLayer(config.hidden_size) 60 | 61 | # relation representation 62 | self.relation_rep = RelationRep(config.hidden_size, config.dropout, config.ffn_mul) 63 | 64 | # graph embedder 65 | self.graph_embedder = GraphEmbedder(config.hidden_size) 66 | 67 | # transformer layer 68 | self.trans_layer = TransLayer( 69 | config.hidden_size, 70 | num_heads=config.num_heads, 71 | num_layers=config.num_transformer_layers 72 | ) 73 | 74 | # keep_mlp: scores whether each candidate node/edge should be kept in the output graph 75 | self.keep_mlp = MLP([config.hidden_size, config.hidden_size * config.ffn_mul, 1], dropout=0.1) 76 | 77 | # scoring layers 78 | self.scorer_ent = ScorerLayer( 79 | scoring_type=config.scorer, 80 | hidden_size=config.hidden_size, 81 | dropout=config.dropout 82 | ) 83 | 84 | self.scorer_rel = ScorerLayer( 85 | scoring_type=config.scorer, 86 | hidden_size=config.hidden_size, 87 | dropout=config.dropout 88 | ) 89 | 90 | def get_optimizer(self, lr_encoder, lr_others, freeze_token_rep=False): 91 | """ 92 | Parameters: 93 | - lr_encoder: Learning rate for the encoder layer. 94 | - lr_others: Learning rate for all other layers. 95 | - freeze_token_rep: whether the token representation layer should be frozen.
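Returns:
- a torch.optim.AdamW optimizer with one parameter group per sub-module: lr_others everywhere, plus lr_encoder for the pretrained token encoder unless it is frozen.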
96 | """ 97 | param_groups = [ 98 | # encoder 99 | {"params": self.rnn.parameters(), "lr": lr_others}, 100 | # projection layers 101 | {"params": self.span_rep_layer.parameters(), "lr": lr_others}, 102 | # prompt representation 103 | {"params": self.ent_rep_layer.parameters(), "lr": lr_others}, 104 | {"params": self.rel_rep_layer.parameters(), "lr": lr_others}, 105 | # filtering layers 106 | {"params": self._span_filtering.parameters(), "lr": lr_others}, 107 | {"params": self._rel_filtering.parameters(), "lr": lr_others}, 108 | # relation representation 109 | {"params": self.relation_rep.parameters(), "lr": lr_others}, 110 | # graph embedder 111 | {"params": self.graph_embedder.parameters(), "lr": lr_others}, 112 | # transformer layer 113 | {"params": self.trans_layer.parameters(), "lr": lr_others}, 114 | # keep_mlp 115 | {"params": self.keep_mlp.parameters(), "lr": lr_others}, 116 | # scoring layer 117 | {"params": self.scorer_ent.parameters(), "lr": lr_others}, 118 | {"params": self.scorer_rel.parameters(), "lr": lr_others} 119 | ] 120 | 121 | if not freeze_token_rep: 122 | # If token_rep_layer should not be frozen, add it to the optimizer with its learning rate 123 | param_groups.append({"params": self.token_rep_layer.parameters(), "lr": lr_encoder}) 124 | else: 125 | # If token_rep_layer should be frozen, explicitly set requires_grad to False for its parameters 126 | for param in self.token_rep_layer.parameters(): 127 | param.requires_grad = False 128 | 129 | optimizer = torch.optim.AdamW(param_groups) 130 | 131 | return optimizer 132 | 133 | def compute_score_train(self, x): 134 | span_idx = x['span_idx'] * x['span_mask'].unsqueeze(-1) 135 | 136 | # Process input 137 | word_rep, mask, entity_type_rep, entity_type_mask, rel_type_rep, relation_type_mask = self.token_prompt_processor.process( 138 | x, self.token_rep_layer, "train" 139 | ) 140 | 141 | # Compute representations 142 | word_rep = self.rnn(word_rep, mask) 143 | span_rep = self.span_rep_layer(word_rep, span_idx) 144 | entity_type_rep = self.ent_rep_layer(entity_type_rep) 145 | rel_type_rep = self.rel_rep_layer(rel_type_rep) 146 | 147 | # Compute number of entity and relation types 148 | num_ent, num_rel = entity_type_rep.shape[1], rel_type_rep.shape[1] 149 | 150 | return span_rep, num_ent, num_rel, entity_type_rep, entity_type_mask, rel_type_rep, relation_type_mask, ( 151 | word_rep, mask) 152 | 153 | @torch.no_grad() 154 | def compute_score_eval(self, x, device): 155 | span_idx = (x['span_idx'] * x['span_mask'].unsqueeze(-1)).to(device) 156 | 157 | # Process input 158 | word_rep, mask, entity_type_rep, relation_type_rep = self.token_prompt_processor.process( 159 | x, self.token_rep_layer, "eval" 160 | ) 161 | 162 | # Compute representations 163 | word_rep = self.rnn(word_rep, mask) 164 | span_rep = self.span_rep_layer(word_rep, span_idx) 165 | entity_type_rep = self.ent_rep_layer(entity_type_rep) 166 | relation_type_rep = self.rel_rep_layer(relation_type_rep) 167 | 168 | # Compute number of entity and relation types 169 | num_ent, num_rel = entity_type_rep.shape[1], relation_type_rep.shape[1] 170 | 171 | return span_rep, num_ent, num_rel, entity_type_rep, relation_type_rep, (word_rep, mask) 172 | 173 | def forward(self, x, prediction_mode=False): 174 | 175 | # clone span_label 176 | span_label = x['span_label'].clone() 177 | 178 | # compute span representation 179 | if prediction_mode: 180 | # Get the device of the model parameters 181 | device = next(self.parameters()).device 182 | 183 | # Compute scores for evaluation 184 | 
span_rep, num_ent, num_rel, entity_type_rep, rel_type_rep, (word_rep, word_mask) = self.compute_score_eval( 185 | x, device) 186 | 187 | # Create masks for relation and entity types, setting all values to 1 188 | relation_type_mask = torch.ones(size=(rel_type_rep.shape[0], num_rel), device=device) 189 | entity_type_mask = torch.ones(size=(entity_type_rep.shape[0], num_ent), device=device) 190 | else: 191 | # Compute scores for training 192 | span_rep, num_ent, num_rel, entity_type_rep, entity_type_mask, rel_type_rep, relation_type_mask, ( 193 | word_rep, mask) = self.compute_score_train(x) 194 | 195 | # Reshape span_rep from (B, L, K, D) to (B, L * K, D) 196 | B, L, K, D = span_rep.shape 197 | span_rep = span_rep.view(B, L * K, D) 198 | 199 | # Compute filtering scores for spans 200 | filter_score_span, filter_loss_span = self._span_filtering(span_rep, x['span_label']) 201 | 202 | # Determine the maximum number of candidates 203 | # If L is greater than the configured maximum, use the configured maximum plus an additional top K 204 | # Otherwise, use L plus an additional top K 205 | max_top_k = min(L, self.config.max_top_k) + self.config.add_top_k 206 | 207 | # Sort the filter scores for spans in descending order 208 | sorted_idx = torch.sort(filter_score_span, dim=-1, descending=True)[1] 209 | 210 | # Define the elements to get candidates for 211 | elements = [span_rep, span_label, x['span_mask'], x['span_idx']] 212 | 213 | # Use a list comprehension to get the candidates for each element 214 | candidate_span_rep, candidate_span_label, candidate_span_mask, candidate_spans_idx = [ 215 | get_candidates(sorted_idx, element, topk=max_top_k)[0] for element in elements 216 | ] 217 | 218 | # Calculate the lengths for the top K entities 219 | top_k_lengths = x["seq_length"].clone() + self.config.add_top_k 220 | 221 | # Create a condition mask where the range of top K is greater than or equal to the top K lengths 222 | condition_mask = torch.arange(max_top_k, device=span_rep.device).unsqueeze(0) >= top_k_lengths.unsqueeze(-1) 223 | 224 | # Apply the condition mask to the candidate span mask and label, setting the masked values to 0 and -1 225 | # respectively 226 | candidate_span_mask.masked_fill_(condition_mask, 0) 227 | candidate_span_label.masked_fill_(condition_mask, -1) 228 | 229 | # Get ground truth relations and represent them 230 | relation_classes = get_ground_truth_relations(x, candidate_spans_idx, candidate_span_label) 231 | rel_rep = self.relation_rep(candidate_span_rep).view(B, max_top_k * max_top_k, -1) # Reshape in the same line 232 | 233 | # Compute filtering scores for relations and sort them in descending order 234 | filter_score_rel, filter_loss_rel = self._rel_filtering(rel_rep, relation_classes) 235 | sorted_idx_pair = torch.sort(filter_score_rel, dim=-1, descending=True)[1] 236 | 237 | # Embed candidate span representations 238 | candidate_span_rep, cat_pair_rep = self.graph_embedder(candidate_span_rep) 239 | 240 | # Define the elements to get candidates for 241 | elements = [cat_pair_rep.view(B, max_top_k * max_top_k, -1), relation_classes.view(B, max_top_k * max_top_k)] 242 | 243 | # Use a list comprehension to get the candidates for each element 244 | candidate_pair_rep, candidate_pair_label = [get_candidates(sorted_idx_pair, element, topk=max_top_k)[0] for 245 | element in elements] 246 | 247 | # Get the top K relation indices 248 | topK_rel_idx = sorted_idx_pair[:, :max_top_k] 249 | 250 | # Mask the candidate pair labels using the condition mask and refine the relation 
representation 251 | candidate_pair_label.masked_fill_(condition_mask, -1) 252 | candidate_pair_mask = candidate_pair_label > -1 253 | 254 | # Concatenate span and relation representations 255 | concat_span_pair = torch.cat((candidate_span_rep, candidate_pair_rep), dim=1) 256 | mask_span_pair = torch.cat((candidate_span_mask, candidate_pair_mask), dim=1) 257 | 258 | # Apply transformer layer and keep_mlp 259 | out_trans = self.trans_layer(concat_span_pair, mask_span_pair) 260 | keep_score = self.keep_mlp(out_trans).squeeze(-1) # Shape: (B, 2 * max_top_k) 261 | 262 | # (disabled) applying sigmoid here; binary_cross_entropy_with_logits below expects raw logits 263 | # keep_score = torch.sigmoid(keep_score).squeeze(-1) # Shape: (B, max_top_k + max_top_k) 264 | 265 | # Split keep_score into keep_ent and keep_rel 266 | keep_ent, keep_rel = keep_score.split([max_top_k, max_top_k], dim=1) 267 | 268 | # NOTE: the transformer layer's output is not used downstream for now 269 | # Split out_trans 270 | # candidate_span_rep, candidate_pair_rep = out_trans.split([max_top_k, max_top_k], dim=1) 271 | 272 | # Compute scores for entities and relations 273 | scores_ent = self.scorer_ent(candidate_span_rep, entity_type_rep) # Shape: [B, N, C] 274 | scores_rel = self.scorer_rel(candidate_pair_rep, rel_type_rep) # Shape: [B, N, C] 275 | 276 | if prediction_mode: 277 | return { 278 | "entity_logits": scores_ent, 279 | "relation_logits": scores_rel, 280 | "keep_ent": keep_ent, 281 | "keep_rel": keep_rel, 282 | "candidate_spans_idx": candidate_spans_idx, 283 | "candidate_pair_label": candidate_pair_label, 284 | "max_top_k": max_top_k, 285 | "topK_rel_idx": topK_rel_idx 286 | } 287 | # Compute losses for relation and entity classifiers 288 | relation_loss = compute_matching_loss(scores_rel, candidate_pair_label, relation_type_mask, num_rel) 289 | entity_loss = compute_matching_loss(scores_ent, candidate_span_label, entity_type_mask, num_ent) 290 | 291 | # Concatenate labels for binary classification and compute binary classification loss 292 | ent_rel_label = (torch.cat((candidate_span_label, candidate_pair_label), dim=1) > 0).float() 293 | filter_loss = F.binary_cross_entropy_with_logits(keep_score, ent_rel_label, reduction='none') 294 | 295 | # Compute structure loss and total loss 296 | structure_loss = (filter_loss * mask_span_pair.float()).sum() 297 | total_loss = sum([filter_loss_span, filter_loss_rel, relation_loss, entity_loss, structure_loss]) 298 | 299 | return total_loss 300 | -------------------------------------------------------------------------------- /modules/base.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from abc import abstractmethod 4 | from pathlib import Path 5 | from typing import Union, Optional, Dict 6 | 7 | import torch 8 | import torch.nn as nn 9 | import yaml 10 | from huggingface_hub import PyTorchModelHubMixin, hf_hub_download 11 | from torch.utils.data import DataLoader 12 | 13 | from .data_processor import GrapherData 14 | from .evaluator import Evaluator 15 | from .token_splitter import WhitespaceTokenSplitter, MecabKoTokenSplitter, SpaCyTokenSplitter 16 | from .utils import er_decoder, get_relation_with_span 17 | from types import SimpleNamespace 18 | 19 | 20 | class GrapherBase(nn.Module, PyTorchModelHubMixin): 21 | def __init__(self, config): 22 | super().__init__() 23 | 24 | self.config = config 25 | self.data_proc = GrapherData(config) 26 | 27 | if not hasattr(config, 'token_splitter'): 28 | self.token_splitter = WhitespaceTokenSplitter()
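# a custom splitter can be selected via config.token_splitter ("spacy" or "mecab-ko",
# handled below); whitespace/punctuation splitting is the default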
29 | elif self.config.token_splitter == "spacy": 30 | lang = getattr(config, 'token_splitter_lang', None) 31 | self.token_splitter = SpaCyTokenSplitter(lang=lang) 32 | elif self.config.token_splitter == "mecab-ko": 33 | self.token_splitter = MecabKoTokenSplitter() 34 | 35 | @abstractmethod 36 | def forward(self, x): 37 | pass 38 | 39 | def adjust_logits(self, logits, keep): 40 | """Adjust logits based on the keep tensor.""" 41 | keep = torch.sigmoid(keep) 42 | keep = (keep > 0.5).unsqueeze(-1).float() 43 | adjusted_logits = logits + (1 - keep) * -1e9 44 | return adjusted_logits 45 | 46 | def predict(self, x, threshold=0.5, output_confidence=False): 47 | """Predict entities and relations.""" 48 | out = self.forward(x, prediction_mode=True) 49 | 50 | # Adjust relation and entity logits 51 | out["entity_logits"] = self.adjust_logits(out["entity_logits"], out["keep_ent"]) 52 | out["relation_logits"] = self.adjust_logits(out["relation_logits"], out["keep_rel"]) 53 | 54 | # Get entities and relations 55 | entities, relations = er_decoder(x, out["entity_logits"], out["relation_logits"], out["topK_rel_idx"], 56 | out["max_top_k"], out["candidate_spans_idx"], threshold=threshold, 57 | output_confidence=output_confidence, token_splitter=self.token_splitter) 58 | return entities, relations 59 | 60 | def evaluate(self, test_data, threshold=0.5, batch_size=12, relation_types=None): 61 | self.eval() 62 | data_loader = self.create_dataloader(test_data, batch_size=batch_size, relation_types=relation_types, 63 | shuffle=False) 64 | device = next(self.parameters()).device 65 | all_preds = [] 66 | all_trues = [] 67 | for x in data_loader: 68 | for k, v in x.items(): 69 | if isinstance(v, torch.Tensor): 70 | x[k] = v.to(device) 71 | batch_predictions = self.predict(x, threshold) 72 | all_preds.extend(batch_predictions) 73 | all_trues.extend(get_relation_with_span(x)) 74 | evaluator = Evaluator(all_trues, all_preds) 75 | out, f1 = evaluator.evaluate() 76 | return out, f1 77 | 78 | def create_dataloader(self, data, entity_types=None, **kwargs) -> DataLoader: 79 | return self.data_proc.create_dataloader(data, entity_types, **kwargs) 80 | 81 | def save_pretrained( 82 | self, 83 | save_directory: Union[str, Path], 84 | *, 85 | config: Optional[Union[dict, "DataclassInstance"]] = None, 86 | repo_id: Optional[str] = None, 87 | push_to_hub: bool = False, 88 | **push_to_hub_kwargs, 89 | ) -> Optional[str]: 90 | """ 91 | Save weights in a local directory. 92 | 93 | Args: 94 | save_directory (`str` or `Path`): 95 | Path to directory in which the model weights and configuration will be saved. 96 | config (`dict` or `DataclassInstance`, *optional*): 97 | Model configuration specified as a key/value dictionary or a dataclass instance. 98 | push_to_hub (`bool`, *optional*, defaults to `False`): 99 | Whether or not to push your model to the Hugging Face Hub after saving it. 100 | repo_id (`str`, *optional*): 101 | ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if 102 | not provided. 103 | kwargs: 104 | Additional keyword arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.
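Example (hypothetical local path):
    >>> model.save_pretrained("grapher_ckpt")
    >>> reloaded = GraphER.from_pretrained("grapher_ckpt")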
105 | """ 106 | save_directory = Path(save_directory) 107 | save_directory.mkdir(parents=True, exist_ok=True) 108 | 109 | # save model weights/files 110 | torch.save(self.state_dict(), save_directory / "pytorch_model.bin") 111 | 112 | # save config (if provided) 113 | if config is None: 114 | config = self.config 115 | if config is not None: 116 | if isinstance(config, argparse.Namespace) or isinstance(config, SimpleNamespace): 117 | config = vars(config) 118 | (save_directory / "config.json").write_text(json.dumps(config, indent=2)) 119 | 120 | # push to the Hub if required 121 | if push_to_hub: 122 | kwargs = push_to_hub_kwargs.copy() # soft-copy to avoid mutating input 123 | if config is not None: # kwarg for `push_to_hub` 124 | kwargs["config"] = config 125 | if repo_id is None: 126 | repo_id = save_directory.name # Defaults to `save_directory` name 127 | return self.push_to_hub(repo_id=repo_id, **kwargs) 128 | return None 129 | 130 | @classmethod 131 | def _from_pretrained( 132 | cls, 133 | *, 134 | model_id: str, 135 | revision: Optional[str], 136 | cache_dir: Optional[Union[str, Path]], 137 | force_download: bool, 138 | proxies: Optional[Dict], 139 | resume_download: bool, 140 | local_files_only: bool, 141 | token: Union[str, bool, None], 142 | map_location: str = "cpu", 143 | strict: bool = False, 144 | **model_kwargs, 145 | ): 146 | 147 | # 2. Newer format: Use "pytorch_model.bin" and "gliner_config.json" 148 | model_file = Path(model_id) / "pytorch_model.bin" 149 | if not model_file.exists(): 150 | model_file = hf_hub_download( 151 | repo_id=model_id, 152 | filename="pytorch_model.bin", 153 | revision=revision, 154 | cache_dir=cache_dir, 155 | force_download=force_download, 156 | proxies=proxies, 157 | resume_download=resume_download, 158 | token=token, 159 | local_files_only=local_files_only, 160 | ) 161 | config_file = Path(model_id) / "config.json" 162 | if not config_file.exists(): 163 | config_file = hf_hub_download( 164 | repo_id=model_id, 165 | filename="config.json", 166 | revision=revision, 167 | cache_dir=cache_dir, 168 | force_download=force_download, 169 | proxies=proxies, 170 | resume_download=resume_download, 171 | token=token, 172 | local_files_only=local_files_only, 173 | ) 174 | config = load_config_as_namespace(config_file) 175 | model = cls(config) 176 | state_dict = torch.load(model_file, map_location=torch.device(map_location)) 177 | model.load_state_dict(state_dict, strict=strict, assign=True) 178 | model.to(map_location) 179 | return model 180 | 181 | def to(self, device): 182 | super().to(device) 183 | import flair 184 | flair.device = device 185 | return self 186 | 187 | 188 | def load_config_as_namespace(config_file): 189 | with open(config_file, "r") as f: 190 | config_dict = yaml.safe_load(f) 191 | return argparse.Namespace(**config_dict) 192 | -------------------------------------------------------------------------------- /modules/data_processor.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import List, Tuple, Dict 3 | 4 | import torch 5 | from torch.nn.utils.rnn import pad_sequence 6 | from torch.utils.data import DataLoader 7 | import random 8 | 9 | 10 | # Abstract base class for handling data processing 11 | class GrapherData(object): 12 | def __init__(self, config): 13 | self.config = config 14 | 15 | @staticmethod 16 | def get_dict(spans: List[Tuple[int, int, str]], classes_to_id: Dict[str, int]) -> Dict[Tuple[int, int], int]: 17 | """Get a dictionary 
of spans.""" 18 | dict_tag = defaultdict(int) 19 | for span in spans: 20 | if span[-1] not in classes_to_id: 21 | continue 22 | dict_tag[(span[0], span[1])] = classes_to_id[span[-1]] 23 | return dict_tag 24 | 25 | def preprocess_spans(self, tokens: List[str], ner: List[Tuple[int, int, str]], rel: List[Tuple[int, int, str]], 26 | classes_to_id: Dict[str, int]) -> Dict: 27 | """Preprocess spans for a given text.""" 28 | # Set the maximum length for tokens 29 | max_token_length = self.config.max_len 30 | 31 | # If the number of tokens exceeds the maximum length, truncate the tokens 32 | if len(tokens) > max_token_length: 33 | token_length = max_token_length 34 | tokens = tokens[:max_token_length] 35 | else: 36 | token_length = len(tokens) 37 | 38 | # Initialize a list to store span indices 39 | span_indices = [] 40 | for i in range(token_length): 41 | span_indices.extend([(i, i + j) for j in range(self.config.max_width)]) 42 | 43 | # Get the dictionary of labels 44 | label_dict = self.get_dict(ner, classes_to_id) if ner else defaultdict(int) 45 | 46 | # Initialize the span labels with the corresponding label from the dictionary 47 | span_labels = torch.LongTensor([label_dict[i] for i in span_indices]) 48 | span_indices = torch.LongTensor(span_indices) 49 | 50 | # Create a mask for valid spans 51 | valid_span_mask = span_indices[:, 1] > token_length - 1 52 | 53 | # Mask invalid positions in the span labels 54 | span_labels = span_labels.masked_fill(valid_span_mask, -1) 55 | 56 | # Return a dictionary with the preprocessed spans 57 | return { 58 | 'tokens': tokens, 59 | 'span_idx': span_indices, 60 | 'span_label': span_labels, 61 | 'seq_length': token_length, 62 | 'entities': ner, 63 | 'relations': rel, 64 | } 65 | 66 | def create_mapping(self, types: List[str]) -> Tuple[Dict[str, int], Dict[int, str]]: 67 | """Create a mapping from type to id and id to type.""" 68 | if not types: 69 | types = ["None"] 70 | type_to_id = {type: id for id, type in enumerate(types, start=1)} 71 | id_to_type = {id: type for type, id in type_to_id.items()} 72 | return type_to_id, id_to_type 73 | 74 | def batch_generate_class_mappings(self, batch_list: List[Dict]) -> Tuple[ 75 | List[Dict[str, int]], List[Dict[int, str]], List[Dict[str, int]], List[Dict[int, str]]]: 76 | """Generate class mappings for a batch of data.""" 77 | all_ent_to_id, all_id_to_ent, all_rel_to_id, all_id_to_rel = [], [], [], [] 78 | 79 | for b in batch_list: 80 | ent_types = list(set([el[-1] for el in b['entities']])) 81 | rel_types = list(set([el[-1] for el in b['relations']])) 82 | 83 | if self.config.shuffle_types: 84 | random.shuffle(ent_types) 85 | random.shuffle(rel_types) 86 | 87 | if len(ent_types) == 0: 88 | ent_types = ["none"] 89 | if len(rel_types) == 0: 90 | rel_types = ["none"] 91 | 92 | ent_to_id, id_to_ent = self.create_mapping(ent_types) 93 | rel_to_id, id_to_rel = self.create_mapping(rel_types) 94 | 95 | all_ent_to_id.append(ent_to_id) 96 | all_id_to_ent.append(id_to_ent) 97 | all_rel_to_id.append(rel_to_id) 98 | all_id_to_rel.append(id_to_rel) 99 | 100 | return all_ent_to_id, all_id_to_ent, all_rel_to_id, all_id_to_rel 101 | 102 | def collate_fn(self, batch_list: List[Dict], entity_types: List[str] = None, 103 | relation_types: List[str] = None) -> Dict: 104 | """Collate a batch of data.""" 105 | 106 | if entity_types is None or relation_types is None: 107 | ent_to_id, id_to_ent, rel_to_id, id_to_rel = self.batch_generate_class_mappings(batch_list) 108 | else: 109 | ent_to_id, id_to_ent = self.create_mapping(entity_types) 110 
| rel_to_id, id_to_rel = self.create_mapping(relation_types) 111 | 112 | batch = [self.preprocess_spans(b["tokenized_text"], b["entities"], b["relations"], 113 | ent_to_id if not isinstance(ent_to_id, list) else ent_to_id[i]) for i, b in 114 | enumerate(batch_list)] 115 | 116 | return self.create_batch_dict(batch, ent_to_id, id_to_ent, rel_to_id, id_to_rel) 117 | 118 | def create_batch_dict(self, batch: List[Dict], ent_to_id: List[Dict[str, int]], id_to_ent: List[Dict[int, str]], 119 | rel_to_id: List[Dict[str, int]], id_to_rel: List[Dict[int, str]]) -> Dict: 120 | """Create a dictionary for a batch of data.""" 121 | 122 | # Extract necessary information from the batch 123 | tokens = [el["tokens"] for el in batch] 124 | span_idx = pad_sequence([b["span_idx"] for b in batch], batch_first=True, padding_value=0) 125 | span_label = pad_sequence([el["span_label"] for el in batch], batch_first=True, padding_value=-1) 126 | seq_length = torch.LongTensor([el["seq_length"] for el in batch]) 127 | entities = [el["entities"] for el in batch] 128 | relations = [el["relations"] for el in batch] 129 | 130 | # Create a mask for valid spans 131 | span_mask = span_label != -1 132 | 133 | # Return a dictionary with the preprocessed spans 134 | return { 135 | 'seq_length': seq_length, 136 | 'span_idx': span_idx, 137 | 'tokens': tokens, 138 | 'span_mask': span_mask, 139 | 'span_label': span_label, 140 | 'entities': entities, 141 | 'relations': relations, 142 | 'ent_to_id': ent_to_id, 143 | 'id_to_ent': id_to_ent, 144 | 'rel_to_id': rel_to_id, 145 | 'id_to_rel': id_to_rel 146 | } 147 | 148 | def create_dataloader(self, data, entity_types=None, relation_types=None, **kwargs) -> DataLoader: 149 | return DataLoader(data, collate_fn=lambda x: self.collate_fn(x, entity_types, relation_types), **kwargs) 150 | 151 | 152 | class TokenPromptProcessorTR: 153 | def __init__(self, entity_token, relation_token, sep_token): 154 | self.entity_token = entity_token 155 | self.sep_token = sep_token 156 | self.relation_token = relation_token 157 | 158 | def process(self, x, token_rep_layer, mode): 159 | if mode == "train": 160 | return self._process_train(x, token_rep_layer) 161 | elif mode == "eval": 162 | return self._process_eval(x, token_rep_layer) 163 | else: 164 | raise ValueError("Invalid mode specified. 
Choose 'train' or 'eval'.") 165 | 166 | def _process_train(self, x, token_rep_layer): 167 | 168 | device = next(token_rep_layer.parameters()).device 169 | 170 | new_length = x["seq_length"].clone() 171 | new_tokens = [] 172 | all_len_prompt = [] 173 | num_classes_all = [] 174 | num_relations_all = [] 175 | 176 | for i in range(len(x["tokens"])): 177 | all_types_i = list(x["ent_to_id"][i].keys()) 178 | all_relations_i = list(x["rel_to_id"][i].keys()) 179 | entity_prompt = [] 180 | relation_prompt = [] 181 | num_classes_all.append(len(all_types_i)) 182 | num_relations_all.append(len(all_relations_i)) 183 | 184 | for entity_type in all_types_i: 185 | entity_prompt.append(self.entity_token) 186 | entity_prompt.append(entity_type) 187 | entity_prompt.append(self.sep_token) 188 | 189 | for relation_type in all_relations_i: 190 | relation_prompt.append(self.relation_token) 191 | relation_prompt.append(relation_type) 192 | relation_prompt.append(self.sep_token) 193 | 194 | combined_prompt = entity_prompt + relation_prompt 195 | tokens_p = combined_prompt + x["tokens"][i] 196 | new_length[i] += len(combined_prompt) 197 | new_tokens.append(tokens_p) 198 | all_len_prompt.append(len(combined_prompt)) 199 | 200 | max_num_classes = max(num_classes_all) 201 | entity_type_pos = torch.arange(max_num_classes).unsqueeze(0).expand(len(num_classes_all), -1).to(device) 202 | entity_type_mask = entity_type_pos < torch.tensor(num_classes_all).unsqueeze(-1).to(device) 203 | 204 | max_num_relations = max(num_relations_all) 205 | relation_type_pos = torch.arange(max_num_relations).unsqueeze(0).expand(len(num_relations_all), -1).to(device) 206 | relation_type_mask = relation_type_pos < torch.tensor(num_relations_all).unsqueeze(-1).to(device) 207 | 208 | bert_output = token_rep_layer(new_tokens, new_length) 209 | word_rep_w_prompt = bert_output["embeddings"] 210 | mask_w_prompt = bert_output["mask"] 211 | 212 | word_rep = [] 213 | mask = [] 214 | entity_type_rep = [] 215 | relation_type_rep = [] 216 | 217 | for i in range(len(x["tokens"])): 218 | prompt_entity_length = all_len_prompt[i] 219 | entity_len = 2 * len(list(x["ent_to_id"][i].keys())) + 1 220 | relation_len = 2 * len(list(x["rel_to_id"][i].keys())) + 1 221 | 222 | word_rep.append(word_rep_w_prompt[i, prompt_entity_length:new_length[i]]) 223 | mask.append(mask_w_prompt[i, prompt_entity_length:new_length[i]]) 224 | 225 | entity_rep = word_rep_w_prompt[i, :entity_len - 1] 226 | entity_rep = entity_rep[0::2] 227 | entity_type_rep.append(entity_rep) 228 | 229 | relation_rep = word_rep_w_prompt[i, entity_len:entity_len + relation_len - 1] 230 | relation_rep = relation_rep[0::2] 231 | relation_type_rep.append(relation_rep) 232 | 233 | word_rep = pad_sequence(word_rep, batch_first=True) 234 | mask = pad_sequence(mask, batch_first=True) 235 | entity_type_rep = pad_sequence(entity_type_rep, batch_first=True) 236 | relation_type_rep = pad_sequence(relation_type_rep, batch_first=True) 237 | 238 | return word_rep, mask, entity_type_rep, entity_type_mask, relation_type_rep, relation_type_mask 239 | 240 | def _process_eval(self, x, token_rep_layer): 241 | all_types = list(x["ent_to_id"].keys()) 242 | all_relations = list(x["rel_to_id"].keys()) 243 | entity_prompt = [] 244 | relation_prompt = [] 245 | 246 | for entity_type in all_types: 247 | entity_prompt.append(self.entity_token) 248 | entity_prompt.append(entity_type) 249 | entity_prompt.append(self.sep_token) 250 | 251 | for relation_type in all_relations: 252 | relation_prompt.append(self.relation_token) 253 | 
relation_prompt.append(relation_type) 254 | relation_prompt.append(self.sep_token) 255 | 256 | combined_prompt = entity_prompt + relation_prompt 257 | prompt_entity_length = len(combined_prompt) 258 | tokens_p = [combined_prompt + tokens for tokens in x["tokens"]] 259 | seq_length_p = x["seq_length"] + prompt_entity_length 260 | 261 | # Converting tokens_p to a format suitable for token_rep_layer 262 | out = token_rep_layer(tokens_p, seq_length_p) 263 | 264 | word_rep_w_prompt = out["embeddings"] 265 | mask_w_prompt = out["mask"] 266 | 267 | word_rep = word_rep_w_prompt[:, prompt_entity_length:, :] 268 | mask = mask_w_prompt[:, prompt_entity_length:] 269 | 270 | entity_type_rep = word_rep_w_prompt[:, :len(entity_prompt) - 1, :] 271 | entity_type_rep = entity_type_rep[:, 0::2, :] 272 | 273 | relation_type_rep = word_rep_w_prompt[:, len(entity_prompt):prompt_entity_length - 1, :] 274 | relation_type_rep = relation_type_rep[:, 0::2, :] 275 | 276 | return word_rep, mask, entity_type_rep, relation_type_rep 277 | -------------------------------------------------------------------------------- /modules/evaluator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Evaluator: 5 | def __init__(self, all_true, all_outs): 6 | self.all_true = all_true 7 | self.all_outs = all_outs 8 | 9 | @torch.no_grad() 10 | def evaluate(self): 11 | precision, recall, f1 = calculate_metrics(self.all_true, self.all_outs, with_type=False) 12 | output_str = f"P: {precision:.2%}\tR: {recall:.2%}\tF1: {f1:.2%}\n" 13 | return output_str, f1 14 | 15 | 16 | def calculate_metrics(all_rel_true, all_rel_pred, with_type=True): 17 | flatrue = [] 18 | for i, v in enumerate(all_rel_true): 19 | for j in v: 20 | try: # relations may come as (head, tail, type) or ((head, tail), type) 21 | head, tail, tp = j 22 | except ValueError: 23 | (head, tail), tp = j 24 | if with_type: 25 | flatrue.append((head, tail, tp, i)) 26 | else: 27 | flatrue.append(((head[0], head[1]), (tail[0], tail[1]), tp, i)) 28 | 29 | flapred = [] 30 | for i, v in enumerate(all_rel_pred): 31 | for j in v: 32 | try: 33 | head, tail, tp = j 34 | except ValueError: 35 | (head, tail), tp = j 36 | if with_type: 37 | flapred.append((head, tail, tp, i)) 38 | else: 39 | flapred.append(((head[0], head[1]), (tail[0], tail[1]), tp, i)) 40 | 41 | TP = len(set(flatrue).intersection(set(flapred))) 42 | FP = len(flapred) - TP 43 | FN = len(flatrue) - TP 44 | 45 | if (TP + FP) == 0: 46 | prec = 0 47 | else: 48 | prec = TP / (TP + FP) 49 | 50 | if (TP + FN) == 0: 51 | rec = 0 52 | else: 53 | rec = TP / (TP + FN) 54 | 55 | 56 | 57 | # Compute F1 as the harmonic mean of precision and recall.
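# Worked example (with_type=False): one true relation ((0, 1), (5, 5), "born in") in sample 0
# and two predictions for it, one correct -> TP = 1, FP = 1, FN = 0
# -> prec = 0.5, rec = 1.0, f1 = 2 * 0.5 * 1.0 / 1.5 ≈ 0.67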
58 | if (prec + rec) == 0: 59 | f1 = 0 60 | else: 61 | f1 = 2 * (prec * rec) / (prec + rec) 62 | 63 | return prec, rec, f1 64 | -------------------------------------------------------------------------------- /modules/filtering.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from .loss_functions import down_weight_loss 3 | 4 | class FilteringLayer(nn.Module): 5 | def __init__(self, hidden_size): 6 | super().__init__() 7 | 8 | self.filter_layer = nn.Linear(hidden_size, 2) 9 | 10 | def forward(self, embeds, label): 11 | 12 | # Extract dimensions 13 | B, num_spans, D = embeds.shape 14 | 15 | # Compute score using a predefined filtering function 16 | score = self.filter_layer(embeds) # Shape: [B, num_spans, num_classes] 17 | 18 | # Modify label to binary (0 for negative class, 1 for positive) 19 | label_m = label.clone() 20 | label_m[label_m > 0] = 1 21 | 22 | # Initialize the loss 23 | filter_loss = 0 24 | if self.training: 25 | # Compute the loss if in training mode 26 | filter_loss = down_weight_loss(score.view(B * num_spans, -1), 27 | label_m.view(-1), 28 | sample_rate=0., 29 | is_logit=True) 30 | 31 | # Compute the filter score (difference between positive and negative class scores) 32 | filter_score = score[..., 1] - score[..., 0] # Shape: [B, num_spans] 33 | 34 | # Mask out filter scores for ignored labels 35 | filter_score = filter_score.masked_fill(label == -1, float('-inf')) 36 | 37 | if self.training: 38 | filter_score = filter_score.masked_fill(label_m > 0, float('inf')) 39 | 40 | return filter_score, filter_loss 41 | -------------------------------------------------------------------------------- /modules/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 4 | 5 | 6 | def MLP(units, dropout, activation=nn.ReLU): 7 | units = [int(u) for u in units] 8 | assert len(units) >= 2 9 | layers = [] 10 | for i in range(len(units) - 2): 11 | layers.append(nn.Linear(units[i], units[i + 1])) 12 | layers.append(activation()) 13 | layers.append(nn.Dropout(dropout)) 14 | layers.append(nn.Linear(units[-2], units[-1])) 15 | return nn.Sequential(*layers) 16 | 17 | 18 | def create_transformer_encoder(d_model, nhead, num_layers, ffn_mul=4, dropout=0.1): 19 | layer = nn.TransformerEncoderLayer( 20 | d_model=d_model, nhead=nhead, batch_first=True, norm_first=False, dim_feedforward=d_model * ffn_mul, 21 | dropout=dropout) 22 | encoder = nn.TransformerEncoder(layer, num_layers=num_layers) 23 | return encoder 24 | 25 | 26 | class TransLayer(nn.Module): 27 | def __init__(self, d_model, num_heads, num_layers, ffn_mul=4, dropout=0.1): 28 | super(TransLayer, self).__init__() 29 | 30 | if num_layers > 0: 31 | self.transformer_encoder = create_transformer_encoder(d_model, num_heads, num_layers, ffn_mul, dropout) 32 | 33 | def forward(self, x, mask): 34 | mask = mask == False 35 | if not hasattr(self, 'transformer_encoder'): 36 | return x 37 | else: 38 | return self.transformer_encoder(src=x, src_key_padding_mask=mask) 39 | 40 | 41 | class GraphEmbedder(nn.Module): 42 | def __init__(self, d_model): 43 | super().__init__() 44 | 45 | # Project node to half of its dimension 46 | self.project_node = nn.Linear(d_model, d_model // 2) 47 | 48 | # Initialize identifier with zeros 49 | self.identifier = nn.Parameter(torch.randn(2, d_model)) 50 | nn.init.zeros_(self.identifier) 51 | 52 | def 
forward(self, candidate_span_rep): 53 | max_top_k = candidate_span_rep.size()[1] 54 | 55 | # Project nodes 56 | nodes = self.project_node(candidate_span_rep) 57 | 58 | # Split nodes into heads and tails 59 | heads = nodes.unsqueeze(2).expand(-1, -1, max_top_k, -1) 60 | tails = nodes.unsqueeze(1).expand(-1, max_top_k, -1, -1) 61 | 62 | # Concatenate heads and tails to form edges 63 | edges = torch.cat([heads, tails], dim=-1) 64 | 65 | # Duplicate nodes along the last dimension 66 | nodes = torch.cat([nodes, nodes], dim=-1) 67 | 68 | # Add identifier to nodes and edges 69 | nodes += self.identifier[0] 70 | edges += self.identifier[1] 71 | 72 | return nodes, edges 73 | 74 | 75 | class LstmSeq2SeqEncoder(nn.Module): 76 | def __init__(self, input_size, hidden_size, num_layers=1, dropout=0., bidirectional=False): 77 | super(LstmSeq2SeqEncoder, self).__init__() 78 | 79 | self.lstm = nn.LSTM(input_size=input_size, 80 | hidden_size=hidden_size, 81 | num_layers=num_layers, 82 | dropout=dropout, 83 | bidirectional=bidirectional, 84 | batch_first=True) 85 | 86 | def forward(self, x, mask, hidden=None): 87 | # Packing the input sequence 88 | lengths = mask.sum(dim=1).cpu() 89 | packed_x = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False) 90 | 91 | # Passing packed sequence through LSTM 92 | packed_output, hidden = self.lstm(packed_x, hidden) 93 | 94 | # Unpacking the output sequence 95 | output, _ = pad_packed_sequence(packed_output, batch_first=True) 96 | 97 | return output 98 | -------------------------------------------------------------------------------- /modules/loss_functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def compute_matching_loss(logits, labels, mask, num_classes): 6 | B, _, _ = logits.size() 7 | 8 | logits_label = logits.view(-1, num_classes) 9 | labels = labels.view(-1) # (batch_size * num_spans) 10 | mask_label = labels != -1 # (batch_size * num_spans) 11 | labels.masked_fill_(~mask_label, 0) # Set the labels of padding tokens to 0 12 | 13 | # one-hot encoding 14 | labels_one_hot = torch.zeros(labels.size(0), num_classes + 1, dtype=torch.float32).to(logits.device) 15 | labels_one_hot.scatter_(1, labels.unsqueeze(1), 1) # Set the corresponding index to 1 16 | labels_one_hot = labels_one_hot[:, 1:] # Remove the first column 17 | 18 | # loss for classifier 19 | loss = F.binary_cross_entropy_with_logits(logits_label, labels_one_hot, reduction='none') 20 | # mask loss using mask (B, C) 21 | masked_loss = loss.view(B, -1, num_classes) * mask.unsqueeze(1) 22 | loss = masked_loss.view(-1, num_classes) 23 | # expand mask_label to loss 24 | mask_label = mask_label.unsqueeze(-1).expand_as(loss) 25 | 26 | 27 | # apply mask 28 | loss = loss * mask_label.float() 29 | loss = loss.sum() 30 | 31 | return loss 32 | 33 | 34 | def down_weight_loss(logits, y, sample_rate=0.1, is_logit=True): 35 | 36 | if is_logit: 37 | loss_func = F.cross_entropy 38 | else: 39 | loss_func = F.nll_loss 40 | 41 | loss_entity = loss_func(logits, y.masked_fill(y == 0, -1), ignore_index=-1, reduction='sum') 42 | loss_non_entity = loss_func(logits, y.masked_fill(y > 0, -1), ignore_index=-1, reduction='sum') 43 | # down-weight the loss on negative (non-entity) examples by (1 - sample_rate) 44 | return loss_entity + loss_non_entity * (1 - sample_rate) 45 | -------------------------------------------------------------------------------- /modules/rel_rep.py:
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .layers import MLP 4 | 5 | 6 | class RelationRep(nn.Module): 7 | def __init__(self, hidden_size, dropout, ffn_mul): 8 | super().__init__() 9 | 10 | self.head_mlp = nn.Linear(hidden_size, hidden_size // 2) 11 | self.tail_mlp = nn.Linear(hidden_size, hidden_size // 2) 12 | self.out_mlp = MLP([hidden_size, hidden_size * ffn_mul, hidden_size], dropout) 13 | 14 | def forward(self, span_reps): 15 | """ 16 | :param span_reps [B, topk, D] 17 | :return relation_reps [B, topk, topk, D] 18 | """ 19 | 20 | heads, tails = span_reps, span_reps 21 | 22 | # Apply MLPs to heads and tails 23 | heads = self.head_mlp(heads) 24 | tails = self.tail_mlp(tails) 25 | 26 | # Expand heads and tails to create relation representations 27 | heads = heads.unsqueeze(2).expand(-1, -1, heads.shape[1], -1) 28 | tails = tails.unsqueeze(1).expand(-1, tails.shape[1], -1, -1) 29 | 30 | # Concatenate heads and tails to create relation representations 31 | relation_reps = torch.cat([heads, tails], dim=-1) 32 | 33 | # Apply MLP to relation representations 34 | relation_reps = self.out_mlp(relation_reps) 35 | 36 | return relation_reps 37 | -------------------------------------------------------------------------------- /modules/scorer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | from .layers import MLP 6 | 7 | 8 | class ScorerLayer(nn.Module): 9 | def __init__(self, scoring_type="dot", hidden_size=768, dropout=0.1): 10 | super().__init__() 11 | 12 | self.scoring_type = scoring_type 13 | 14 | if scoring_type == "concat_proj": 15 | self.proj = MLP([hidden_size * 4, hidden_size * 4, 1], dropout) 16 | elif scoring_type == "dot_thresh": 17 | self.proj_thresh = MLP([hidden_size, hidden_size * 4, 2], dropout) 18 | self.proj_type = MLP([hidden_size, hidden_size * 4, hidden_size], dropout) 19 | elif scoring_type == "dot_norm": self.bias = nn.Parameter(torch.zeros(1)); self.dy_bias_type = nn.Linear(hidden_size, 1); self.dy_bias_rel = nn.Linear(hidden_size, 1) # bias terms required by the "dot_norm" branch of forward 20 | def forward(self, candidate_pair_rep, rel_type_rep): 21 | # candidate_pair_rep: [B, N, D] 22 | # rel_type_rep: [B, T, D] 23 | if self.scoring_type == "dot": 24 | return torch.einsum("bnd,btd->bnt", candidate_pair_rep, rel_type_rep) 25 | 26 | elif self.scoring_type == "dot_thresh": 27 | # compute the scaling factor and threshold 28 | B, T, D = rel_type_rep.size() 29 | scaler = self.proj_thresh(rel_type_rep) # [B, T, 2] 30 | # alpha: scaling factor, beta: threshold 31 | alpha, beta = scaler[..., 0].view(B, 1, T), scaler[..., 1].view(B, 1, T) 32 | alpha = F.softplus(alpha) # reason: alpha should be positive 33 | # project the relation type representation 34 | rel_type_rep = self.proj_type(rel_type_rep) # [B, T, D] 35 | # compute the score (before sigmoid) 36 | score = torch.einsum("bnd,btd->bnt", candidate_pair_rep, rel_type_rep) # [B, N, T] 37 | return (score + beta) * alpha # [B, N, T] 38 | 39 | elif self.scoring_type == "dot_norm": 40 | score = torch.einsum("bnd,btd->bnt", candidate_pair_rep, rel_type_rep) # [B, N, T] 41 | bias_1 = self.dy_bias_type(rel_type_rep).transpose(1, 2) # [B, 1, T] 42 | bias_2 = self.dy_bias_rel(candidate_pair_rep) # [B, N, 1] 43 | return score + self.bias + bias_1 + bias_2 44 | 45 | elif self.scoring_type == "concat_proj": 46 | prod_features = candidate_pair_rep.unsqueeze(2) * rel_type_rep.unsqueeze(1) # [B, N, T, D] 47 | diff_features = candidate_pair_rep.unsqueeze(2) - rel_type_rep.unsqueeze(1) # [B, N, T, D] 48 | expanded_pair_rep =
candidate_pair_rep.unsqueeze(2).repeat(1, 1, rel_type_rep.size(1), 1) 49 | expanded_rel_type_rep = rel_type_rep.unsqueeze(1).repeat(1, candidate_pair_rep.size(1), 1, 1) 50 | features = torch.cat([prod_features, diff_features, expanded_pair_rep, expanded_rel_type_rep], 51 | dim=-1) # [B, N, T, 4D] 52 | return self.proj(features).squeeze(-1) -------------------------------------------------------------------------------- /modules/span_rep.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | 6 | def create_projection_layer(hidden_size: int, dropout: float, out_dim: int = None) -> nn.Sequential: 7 | """ 8 | Creates a projection layer with specified configurations. 9 | """ 10 | if out_dim is None: 11 | out_dim = hidden_size 12 | 13 | return nn.Sequential( 14 | nn.Linear(hidden_size, int(out_dim * 3 / 2)), 15 | nn.ReLU(), 16 | nn.Dropout(dropout), 17 | nn.Linear(int(out_dim * 3 / 2), out_dim) 18 | ) 19 | 20 | 21 | class SpanConvBlock(nn.Module): 22 | def __init__(self, hidden_size, kernel_size, span_mode='conv_normal'): 23 | super().__init__() 24 | 25 | if span_mode == 'conv_conv': 26 | self.conv = nn.Conv1d(hidden_size, hidden_size, 27 | kernel_size=kernel_size) 28 | 29 | # initialize the weights 30 | nn.init.kaiming_uniform_(self.conv.weight, nonlinearity='relu') 31 | 32 | elif span_mode == 'conv_max': 33 | self.conv = nn.MaxPool1d(kernel_size=kernel_size, stride=1) 34 | elif span_mode == 'conv_mean' or span_mode == 'conv_sum': 35 | self.conv = nn.AvgPool1d(kernel_size=kernel_size, stride=1) 36 | 37 | self.span_mode = span_mode 38 | 39 | self.pad = kernel_size - 1 40 | 41 | def forward(self, x): 42 | 43 | x = torch.einsum('bld->bdl', x) 44 | 45 | if self.pad > 0: 46 | x = F.pad(x, (0, self.pad), "constant", 0) 47 | 48 | x = self.conv(x) 49 | 50 | if self.span_mode == "conv_sum": 51 | x = x * (self.pad + 1) 52 | 53 | return torch.einsum('bdl->bld', x) 54 | 55 | 56 | class SpanConv(nn.Module): 57 | def __init__(self, hidden_size, max_width, span_mode): 58 | super().__init__() 59 | 60 | kernels = [i + 2 for i in range(max_width - 1)] 61 | 62 | self.convs = nn.ModuleList() 63 | 64 | for kernel in kernels: 65 | self.convs.append(SpanConvBlock(hidden_size, kernel, span_mode)) 66 | 67 | self.project = nn.Sequential( 68 | nn.ReLU(), 69 | nn.Linear(hidden_size, hidden_size) 70 | ) 71 | 72 | def forward(self, x, *args): 73 | 74 | span_reps = [x] 75 | 76 | for conv in self.convs: 77 | h = conv(x) 78 | span_reps.append(h) 79 | 80 | span_reps = torch.stack(span_reps, dim=-2) 81 | 82 | return self.project(span_reps) 83 | 84 | 85 | class ConvShare(nn.Module): 86 | def __init__(self, hidden_size, max_width): 87 | super().__init__() 88 | 89 | self.max_width = max_width 90 | 91 | self.conv_weight = nn.Parameter( 92 | torch.randn(hidden_size, hidden_size, max_width)) 93 | 94 | nn.init.kaiming_uniform_(self.conv_weight, nonlinearity='relu') 95 | 96 | self.project = nn.Sequential( 97 | nn.ReLU(), 98 | nn.Linear(hidden_size, hidden_size) 99 | ) 100 | 101 | def forward(self, x, *args): 102 | span_reps = [] 103 | 104 | x = torch.einsum('bld->bdl', x) 105 | 106 | for i in range(self.max_width): 107 | pad = i 108 | x_i = F.pad(x, (0, pad), "constant", 0) 109 | conv_w = self.conv_weight[:, :, :i + 1] 110 | out_i = F.conv1d(x_i, conv_w) 111 | span_reps.append(out_i.transpose(-1, -2)) 112 | 113 | out = torch.stack(span_reps, dim=-2) 114 | 115 | return self.project(out) 116 | 117 | 118 | def
extract_elements(sequence, indices): 119 | B, L, D = sequence.shape 120 | K = indices.shape[1] 121 | 122 | # Expand indices to [B, K, D] 123 | expanded_indices = indices.unsqueeze(2).expand(-1, -1, D) 124 | 125 | # Gather the elements 126 | extracted_elements = torch.gather(sequence, 1, expanded_indices) 127 | 128 | return extracted_elements 129 | 130 | 131 | class SpanMarker(nn.Module): 132 | 133 | def __init__(self, hidden_size, max_width, dropout=0.4): 134 | super().__init__() 135 | 136 | self.max_width = max_width 137 | 138 | self.project_start = nn.Sequential( 139 | nn.Linear(hidden_size, hidden_size * 2, bias=True), 140 | nn.ReLU(), 141 | nn.Dropout(dropout), 142 | nn.Linear(hidden_size * 2, hidden_size, bias=True), 143 | ) 144 | 145 | self.project_end = nn.Sequential( 146 | nn.Linear(hidden_size, hidden_size * 2, bias=True), 147 | nn.ReLU(), 148 | nn.Dropout(dropout), 149 | nn.Linear(hidden_size * 2, hidden_size, bias=True), 150 | ) 151 | 152 | self.out_project = nn.Linear(hidden_size * 2, hidden_size, bias=True) 153 | 154 | def forward(self, h, span_idx): 155 | # h of shape [B, L, D] 156 | # span_idx of shape [B, L * max_width, 2] 157 | 158 | B, L, D = h.size() 159 | 160 | # project start and end 161 | start_rep = self.project_start(h) 162 | end_rep = self.project_end(h) 163 | 164 | start_span_rep = extract_elements(start_rep, span_idx[:, :, 0]) 165 | end_span_rep = extract_elements(end_rep, span_idx[:, :, 1]) 166 | 167 | # concat start and end 168 | cat = torch.cat([start_span_rep, end_span_rep], dim=-1).relu() 169 | 170 | # project 171 | cat = self.out_project(cat) 172 | 173 | # reshape 174 | return cat.view(B, L, self.max_width, D) 175 | 176 | 177 | class SpanMarkerV0(nn.Module): 178 | """ 179 | Marks and projects span endpoints using an MLP. 180 | """ 181 | 182 | def __init__(self, hidden_size: int, max_width: int, dropout: float = 0.4): 183 | super().__init__() 184 | self.max_width = max_width 185 | 186 | self.project_start = create_projection_layer(hidden_size, dropout) 187 | self.project_end = create_projection_layer(hidden_size, dropout) 188 | self.out_project = create_projection_layer(hidden_size * 2, dropout, hidden_size) 189 | 190 | def forward(self, h: torch.Tensor, span_idx: torch.Tensor) -> torch.Tensor: 191 | B, L, D = h.size() 192 | 193 | start_rep = self.project_start(h) 194 | end_rep = self.project_end(h) 195 | 196 | start_span_rep = extract_elements(start_rep, span_idx[:, :, 0]) 197 | end_span_rep = extract_elements(end_rep, span_idx[:, :, 1]) 198 | 199 | cat = torch.cat([start_span_rep, end_span_rep], dim=-1).relu() 200 | 201 | return self.out_project(cat).view(B, L, self.max_width, D) 202 | 203 | 204 | class SpanRepLayer(nn.Module): 205 | """ 206 | Various span representation approaches 207 | """ 208 | 209 | def __init__(self, hidden_size, max_width, span_mode, **kwargs): 210 | super().__init__() 211 | 212 | if span_mode == 'marker': 213 | self.span_rep_layer = SpanMarker(hidden_size, max_width, **kwargs) 214 | elif span_mode == 'markerV0': 215 | self.span_rep_layer = SpanMarkerV0(hidden_size, max_width, **kwargs) 216 | elif span_mode == 'conv_conv': 217 | self.span_rep_layer = SpanConv( 218 | hidden_size, max_width, span_mode='conv_conv') 219 | elif span_mode == 'conv_max': 220 | self.span_rep_layer = SpanConv( 221 | hidden_size, max_width, span_mode='conv_max') 222 | elif span_mode == 'conv_mean': 223 | self.span_rep_layer = SpanConv( 224 | hidden_size, max_width, span_mode='conv_mean') 225 | elif span_mode == 'conv_sum': 226 | self.span_rep_layer = SpanConv( 227 | 
hidden_size, max_width, span_mode='conv_sum') 228 | elif span_mode == 'conv_share': 229 | self.span_rep_layer = ConvShare(hidden_size, max_width) 230 | else: 231 | raise ValueError(f'Unknown span mode {span_mode}') 232 | 233 | def forward(self, x, *args): 234 | 235 | return self.span_rep_layer(x, *args) 236 | -------------------------------------------------------------------------------- /modules/token_rep.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | from flair.data import Sentence 5 | from flair.embeddings import TransformerWordEmbeddings 6 | from torch import nn 7 | from torch.nn.utils.rnn import pad_sequence 8 | 9 | 10 | class TokenRepLayer(nn.Module): 11 | def __init__(self, model_name: str = "bert-base-cased", fine_tune: bool = True, subtoken_pooling: str = "first", 12 | hidden_size: int = 768, 13 | add_tokens=["[SEP]", "[ENT]"] 14 | ): 15 | super().__init__() 16 | 17 | self.bert_layer = TransformerWordEmbeddings( 18 | model_name, 19 | fine_tune=fine_tune, 20 | subtoken_pooling=subtoken_pooling, 21 | allow_long_sentences=True 22 | ) 23 | 24 | # add tokens to vocabulary 25 | self.bert_layer.tokenizer.add_tokens(add_tokens) 26 | 27 | # resize token embeddings 28 | self.bert_layer.model.resize_token_embeddings(len(self.bert_layer.tokenizer)) 29 | 30 | bert_hidden_size = self.bert_layer.embedding_length 31 | 32 | if hidden_size != bert_hidden_size: 33 | self.projection = nn.Linear(bert_hidden_size, hidden_size) 34 | 35 | def forward(self, tokens: List[List[str]], lengths: torch.Tensor): 36 | token_embeddings = self.compute_word_embedding(tokens) 37 | 38 | if hasattr(self, "projection"): 39 | token_embeddings = self.projection(token_embeddings) 40 | 41 | B = len(lengths) 42 | max_length = lengths.max() 43 | mask = (torch.arange(max_length).view(1, -1).repeat(B, 1) < lengths.cpu().unsqueeze(1)).to( 44 | token_embeddings.device).long() 45 | return {"embeddings": token_embeddings, "mask": mask} 46 | 47 | def compute_word_embedding(self, tokens): 48 | sentences = [Sentence(i) for i in tokens] 49 | self.bert_layer.embed(sentences) 50 | token_embeddings = pad_sequence([torch.stack([t.embedding for t in k]) for k in sentences], batch_first=True) 51 | return token_embeddings 52 | -------------------------------------------------------------------------------- /modules/token_splitter.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | class TokenSplitterBase(): 5 | def __init__(self): 6 | pass 7 | 8 | def __call__(self, text) -> (str, int, int): 9 | pass 10 | 11 | 12 | class WhitespaceTokenSplitter(TokenSplitterBase): 13 | def __init__(self): 14 | self.whitespace_pattern = re.compile(r'\w+(?:[-_]\w+)*|\S') 15 | 16 | def __call__(self, text): 17 | for match in self.whitespace_pattern.finditer(text): 18 | yield match.group(), match.start(), match.end() 19 | 20 | 21 | class SpaCyTokenSplitter(TokenSplitterBase): 22 | def __init__(self, lang=None): 23 | try: 24 | import spacy # noqa 25 | except ModuleNotFoundError as error: 26 | raise error.__class__( 27 | "Please install spacy with: `pip install spacy`" 28 | ) 29 | if lang is None: 30 | lang = 'en' # Default to English if no language is specified 31 | self.nlp = spacy.blank(lang) 32 | 33 | def __call__(self, text): 34 | doc = self.nlp(text) 35 | for token in doc: 36 | yield token.text, token.idx, token.idx + len(token.text) 37 | 38 | 39 | class MecabKoTokenSplitter(TokenSplitterBase): 40 | def 
/modules/token_splitter.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import Iterator, Tuple
3 | 
4 | 
5 | class TokenSplitterBase():
6 |     def __init__(self):
7 |         pass
8 | 
9 |     def __call__(self, text) -> Iterator[Tuple[str, int, int]]:
10 |         pass
11 | 
12 | 
13 | class WhitespaceTokenSplitter(TokenSplitterBase):
14 |     def __init__(self):
15 |         # words, optionally joined by '-' or '_', or any single non-space character
16 |         self.whitespace_pattern = re.compile(r'\w+(?:[-_]\w+)*|\S')
17 | 
18 |     def __call__(self, text):
19 |         for match in self.whitespace_pattern.finditer(text):
20 |             yield match.group(), match.start(), match.end()
21 | 
22 | 
23 | class SpaCyTokenSplitter(TokenSplitterBase):
24 |     def __init__(self, lang=None):
25 |         try:
26 |             import spacy  # noqa
27 |         except ModuleNotFoundError as error:
28 |             raise error.__class__(
29 |                 "Please install spacy with: `pip install spacy`"
30 |             )
31 |         if lang is None:
32 |             lang = 'en'  # Default to English if no language is specified
33 |         self.nlp = spacy.blank(lang)
34 | 
35 |     def __call__(self, text):
36 |         doc = self.nlp(text)
37 |         for token in doc:
38 |             yield token.text, token.idx, token.idx + len(token.text)
39 | 
40 | 
41 | class MecabKoTokenSplitter(TokenSplitterBase):
42 |     def __init__(self):
43 |         try:
44 |             import mecab  # noqa
45 |         except ModuleNotFoundError as error:
46 |             raise error.__class__(
47 |                 "Please install python-mecab-ko with: `pip install python-mecab-ko`"
48 |             )
49 |         self.tagger = mecab.MeCab()
50 | 
51 |     def __call__(self, text):
52 |         tokens = self.tagger.morphs(text)
53 | 
54 |         last_idx = 0
55 |         for morph in tokens:
56 |             # MeCab returns morphemes without offsets; recover them by scanning forward in the text
57 |             start_idx = text.find(morph, last_idx)
58 |             end_idx = start_idx + len(morph)
59 |             last_idx = end_idx
60 |             yield morph, start_idx, end_idx
--------------------------------------------------------------------------------
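All three splitters share the same contract: they yield `(token, start_char, end_char)` triples over the raw string. A dependency-free check with the regex-based splitter (expected output in comments):

```python
from modules.token_splitter import WhitespaceTokenSplitter

splitter = WhitespaceTokenSplitter()
for token, start, end in splitter("GraphER extracts entity-relation graphs."):
    print(token, start, end)
# GraphER 0 7
# extracts 8 16
# entity-relation 17 32
# graphs 33 39
# . 39 40
```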
/modules/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | 
4 | 
5 | def decode_relations(id_to_rel, logits, pair_indices, max_pairs, span_indices, threshold=0.5, output_confidence=False):
6 |     # Apply sigmoid function to logits
7 |     probabilities = torch.sigmoid(logits)
8 | 
9 |     # Initialize list of relations
10 |     relations = [[] for _ in range(len(logits))]
11 | 
12 |     # Get indices where probability is greater than threshold
13 |     above_threshold_indices = (probabilities > threshold).nonzero(as_tuple=True)
14 | 
15 |     # Iterate over indices where probability is greater than threshold
16 |     for batch_idx, position, class_idx in zip(*above_threshold_indices):
17 |         # Get relation label (+1 because class 0 is reserved for "no relation")
18 |         label = id_to_rel[class_idx.item() + 1]
19 | 
20 |         # Get predicted pair index
21 |         predicted_pair_idx = pair_indices[batch_idx, position].item()
22 | 
23 |         # Unravel predicted pair index into head and tail
24 |         head_idx, tail_idx = np.unravel_index(predicted_pair_idx, (max_pairs, max_pairs))
25 | 
26 |         # Convert head and tail indices to tuples
27 |         head = tuple(span_indices[batch_idx, head_idx].tolist())
28 |         tail = tuple(span_indices[batch_idx, tail_idx].tolist())
29 | 
30 |         # Get confidence
31 |         confidence = probabilities[batch_idx, position, class_idx].item()
32 | 
33 |         # Append relation to list
34 |         if output_confidence:
35 |             relations[batch_idx.item()].append((head, tail, label, confidence))
36 |         else:
37 |             relations[batch_idx.item()].append((head, tail, label))
38 | 
39 |     return relations
40 | 
41 | 
42 | def decode_entities(id_to_ent, logits, span_indices, threshold=0.5, output_confidence=False):
43 |     # Apply sigmoid function to logits
44 |     probabilities = torch.sigmoid(logits)
45 | 
46 |     # Initialize list of entities
47 |     entities = []
48 | 
49 |     # Get indices where probability is greater than threshold
50 |     above_threshold_indices = (probabilities > threshold).nonzero(as_tuple=True)
51 | 
52 |     # Iterate over indices where probability is greater than threshold
53 |     for batch_idx, position, class_idx in zip(*above_threshold_indices):
54 |         # Get entity label (+1 because class 0 is reserved for the null class)
55 |         label = id_to_ent[class_idx.item() + 1]
56 | 
57 |         # Get confidence
58 |         confidence = probabilities[batch_idx, position, class_idx].item()
59 | 
60 |         # Append entity to list
61 |         if output_confidence:
62 |             entities.append((tuple(span_indices[batch_idx, position].tolist()), label, confidence))
63 |         else:
64 |             entities.append((tuple(span_indices[batch_idx, position].tolist()), label))
65 | 
66 |     return entities
67 | 
68 | 
69 | def er_decoder(x, entity_logits, rel_logits, topk_pair_idx, max_top_k, candidate_spans_idx, threshold=0.5,
70 |                output_confidence=False, token_splitter=None):
71 |     entities = decode_entities(x["id_to_ent"], entity_logits, candidate_spans_idx, threshold, output_confidence)
72 |     relations = decode_relations(x["id_to_rel"], rel_logits, topk_pair_idx, max_top_k, candidate_spans_idx, threshold,
73 |                                  output_confidence)
74 |     return entities, relations
75 | 
76 | 
77 | def get_relation_with_span(x):
78 |     entities, relations = x['entities'], x['relations']
79 |     B = len(entities)
80 |     relation_with_span = [[] for _ in range(B)]
81 |     for i in range(B):
82 |         rel_i = relations[i]
83 |         ent_i = entities[i]
84 |         for rel in rel_i:
85 |             act = (ent_i[rel[0]], ent_i[rel[1]], rel[2])
86 |             relation_with_span[i].append(act)
87 |     return relation_with_span
88 | 
89 | 
90 | def get_ground_truth_relations(x, candidate_spans_idx, candidate_span_label):
91 |     B, max_top_k = candidate_span_label.shape
92 | 
93 |     relation_classes = torch.zeros((B, max_top_k, max_top_k), dtype=torch.long, device=candidate_spans_idx.device)
94 | 
95 |     # Populate relation classes
96 |     for i in range(B):
97 |         rel_i = x["relations"][i]
98 |         ent_i = x["entities"][i]
99 | 
100 |         new_heads, new_tails, new_rel_type = [], [], []
101 | 
102 |         # Collect the head span, tail span and type of every gold relation in this example
103 |         for k in rel_i:
104 |             heads_i = [ent_i[k[0]][0], ent_i[k[0]][1]]
105 |             tails_i = [ent_i[k[1]][0], ent_i[k[1]][1]]
106 |             type_i = k[2]
107 |             new_heads.append(heads_i)
108 |             new_tails.append(tails_i)
109 |             new_rel_type.append(type_i)
110 | 
111 |         # Working aliases for the collected lists
112 |         heads_, tails_, rel_type = new_heads, new_tails, new_rel_type
113 | 
114 |         # idx of candidate spans
115 |         cand_i = candidate_spans_idx[i].tolist()
116 | 
117 |         for heads_i, tails_i, type_i in zip(heads_, tails_, rel_type):
118 |             # flag: is this relation type part of the label space for this example?
119 |             flag = False
120 |             if isinstance(x["rel_to_id"], dict):
121 |                 if type_i in x["rel_to_id"]:
122 |                     flag = True
123 |             elif isinstance(x["rel_to_id"], list):
124 |                 if type_i in x["rel_to_id"][i]:
125 |                     flag = True
126 | 
127 |             if heads_i in cand_i and tails_i in cand_i and flag:
128 |                 idx_head = cand_i.index(heads_i)
129 |                 idx_tail = cand_i.index(tails_i)
130 | 
131 |                 if isinstance(x["rel_to_id"], list):
132 |                     relation_classes[i, idx_head, idx_tail] = x["rel_to_id"][i][type_i]
133 |                 elif isinstance(x["rel_to_id"], dict):
134 |                     relation_classes[i, idx_head, idx_tail] = x["rel_to_id"][type_i]
135 | 
136 |     # flatten relation classes to [B, max_top_k * max_top_k]
137 |     relation_classes = relation_classes.view(-1, max_top_k * max_top_k)
138 | 
139 |     # set the class to -1 wherever the head or the tail candidate span is itself labelled -1
140 |     head_candidate_span_label = candidate_span_label.view(B, max_top_k, 1).repeat(1, 1, max_top_k).view(B, -1)
141 |     tail_candidate_span_label = candidate_span_label.view(B, 1, max_top_k).repeat(1, max_top_k, 1).view(B, -1)
142 | 
143 |     relation_classes.masked_fill_(head_candidate_span_label.view(B, max_top_k * max_top_k) == -1, -1)  # head
144 |     relation_classes.masked_fill_(tail_candidate_span_label.view(B, max_top_k * max_top_k) == -1, -1)  # tail
145 | 
146 |     return relation_classes
147 | 
148 | 
149 | def get_candidates(sorted_idx, tensor_elem, topk=10):
150 |     # sorted_idx [B, num_spans]
151 |     # tensor_elem [B, num_spans, D] or [B, num_spans]
152 | 
153 |     sorted_topk_idx = sorted_idx[:, :topk]
154 | 
155 |     if len(tensor_elem.shape) == 3:
156 |         B, num_spans, D = tensor_elem.shape
157 |         topk_tensor_elem = tensor_elem.gather(1, sorted_topk_idx.unsqueeze(-1).expand(-1, -1, D))
158 |     else:
159 |         # [B, topk]
160 |         topk_tensor_elem = tensor_elem.gather(1, sorted_topk_idx)
161 | 
162 |     return topk_tensor_elem, sorted_topk_idx
163 | 
--------------------------------------------------------------------------------
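The decoding contract of `decode_entities`, reduced to a toy example with made-up logits; `id_to_ent` starts at 1 because class 0 is the null class:

```python
import torch
from modules.utils import decode_entities

id_to_ent = {1: "person", 2: "organisation"}
# One batch element, two candidate spans, two entity classes
logits = torch.tensor([[[3.0, -2.0],     # span 0: sigmoid(3.0) > 0.5 -> "person"
                        [-1.0, -1.0]]])  # span 1: below threshold for both classes
span_indices = torch.tensor([[[0, 1], [2, 5]]])

print(decode_entities(id_to_ent, logits, span_indices, threshold=0.5))
# [((0, 1), 'person')]
```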
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import re
5 | from types import SimpleNamespace
6 | 
7 | from tqdm import tqdm
8 | from transformers import (
9 |     get_cosine_schedule_with_warmup,
10 |     get_linear_schedule_with_warmup,
11 |     get_constant_schedule_with_warmup,
12 |     get_polynomial_decay_schedule_with_warmup,
13 |     get_inverse_sqrt_schedule,
14 | )
15 | import torch
16 | import torch.distributed as dist
17 | import torch.multiprocessing as mp
18 | from torch.nn.parallel import DistributedDataParallel as DDP
19 | from torch.utils.data import DataLoader, TensorDataset
20 | from torch.utils.data.distributed import DistributedSampler
21 | 
22 | from model import GraphER
23 | from modules.base import load_config_as_namespace
24 | #from modules.run_evaluation import get_for_all_path
25 | 
26 | 
27 | def save_top_k_checkpoints(model: GraphER, save_path: str, checkpoint: str, top_k: int = 5):
28 |     """
29 |     Save the current checkpoint and keep only the latest top_k checkpoints.
30 |     Parameters:
31 |     model (GraphER): The model to save.
32 |     save_path (str): The directory path to save the checkpoints.
33 |     checkpoint (str): Name for the current checkpoint, e.g. 'model_5000'.
34 |     top_k (int): The number of latest checkpoints to keep. Defaults to 5.
35 |     """
36 |     # Save the current model and tokenizer
37 |     if isinstance(model, DDP):
38 |         model.module.save_pretrained(os.path.join(save_path, checkpoint))
39 |     else:
40 |         model.save_pretrained(os.path.join(save_path, checkpoint))
41 | 
42 |     # List all files in the directory
43 |     files = os.listdir(save_path)
44 | 
45 |     # Filter files to keep only the model checkpoints
46 |     checkpoint_folders = [file for file in files if re.search(r'model_\d+', file)]
47 | 
48 |     # Sort checkpoint folders by modification time (latest first)
49 |     checkpoint_folders.sort(key=lambda x: os.path.getmtime(os.path.join(save_path, x)), reverse=True)
50 | 
51 |     # Delete everything beyond the newest top_k checkpoints
52 |     for checkpoint_folder in checkpoint_folders[top_k:]:
53 |         checkpoint_folder = os.path.join(save_path, checkpoint_folder)
54 |         checkpoint_files = [os.path.join(checkpoint_folder, f) for f in os.listdir(checkpoint_folder)]
55 |         for file in checkpoint_files:
56 |             os.remove(file)
57 |         os.rmdir(checkpoint_folder)
58 | 
59 | 
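How the rotation above behaves in practice — illustrative only, assuming `model` is an already-constructed GraphER instance:

```python
# After the sixth call, the oldest folder (grapher_logs/model_1000) is deleted,
# leaving the five most recent: model_2000 ... model_6000.
for step in range(1000, 7000, 1000):
    save_top_k_checkpoints(model, "grapher_logs", f"model_{step}", top_k=5)
```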
60 | class Trainer:
61 |     def __init__(self, config, allow_distributed, device='cuda'):
62 |         self.config = config
63 |         self.lr_encoder = float(self.config.lr_encoder)
64 |         self.lr_others = float(self.config.lr_others)
65 | 
66 |         self.device = device
67 | 
68 |         if config.prev_path != "none":  # fine-tuning mode
69 |             self.model_config = SimpleNamespace(
70 |                 max_types=config.max_types,
71 |                 max_len=config.max_len,
72 |                 max_top_k=config.max_top_k,
73 |                 add_top_k=config.add_top_k,
74 |                 shuffle_types=config.shuffle_types,
75 |                 random_drop=config.random_drop,
76 |                 max_neg_type_ratio=config.max_neg_type_ratio,
77 |                 max_ent_types=config.max_ent_types,
78 |                 max_rel_types=config.max_rel_types,
79 |             )
80 |         else:
81 |             self.model_config = SimpleNamespace(
82 |                 model_name=config.model_name,
83 |                 name=config.name,
84 |                 max_width=config.max_width,
85 |                 hidden_size=config.hidden_size,
86 |                 dropout=config.dropout,
87 |                 fine_tune=config.fine_tune,
88 |                 subtoken_pooling=config.subtoken_pooling,
89 |                 span_mode=config.span_mode,
90 |                 max_types=config.max_types,
91 |                 max_len=config.max_len,
92 |                 num_heads=config.num_heads,
93 |                 num_transformer_layers=config.num_transformer_layers,
94 |                 ffn_mul=config.ffn_mul,
95 |                 scorer=config.scorer,
96 |                 max_top_k=config.max_top_k,
97 |                 add_top_k=config.add_top_k,
98 |                 shuffle_types=config.shuffle_types,
99 |                 random_drop=config.random_drop,
100 |                 max_neg_type_ratio=config.max_neg_type_ratio,
101 |                 max_ent_types=config.max_ent_types,
102 |                 max_rel_types=config.max_rel_types,
103 |             )
104 | 
105 |         self.allow_distributed = allow_distributed
106 | 
107 |     def setup_distributed(self, rank, world_size):
108 |         os.environ['MASTER_ADDR'] = 'localhost'
109 |         os.environ['MASTER_PORT'] = '12356'
110 |         torch.cuda.set_device(rank)
111 |         dist.init_process_group("nccl", rank=rank, world_size=world_size)
112 | 
113 |     def cleanup_distributed(self):
114 |         dist.destroy_process_group()
115 | 
116 |     def setup_model_and_optimizer(self, rank=None, device=None):
117 |         if device is None:
118 |             device = self.device
119 |         if self.config.prev_path != "none":
120 |             model = GraphER.from_pretrained(self.config.prev_path).to(device)
121 | 
122 |             # Architecture parameters listed in keep_params retain their pretrained values;
123 |             # all other parameters are overwritten by the config file.
124 |             keep_params = ['model_name', 'name', 'max_width', 'hidden_size', 'dropout', 'subtoken_pooling', 'span_mode',
125 |                            "fine_tune", "max_types", "max_len", "num_heads", "num_transformer_layers", "ffn_mul",
126 |                            "scorer"]
127 | 
128 |             for param in keep_params:
129 |                 original_value = getattr(model.config, param)
130 |                 setattr(self.model_config, param, original_value)
131 | 
132 |             model.config = self.model_config
133 |         else:
134 |             model = GraphER(self.model_config).to(device)
135 | 
136 |         if rank is not None:
137 |             model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=False)
138 |             optimizer = model.module.get_optimizer(self.lr_encoder, self.lr_others, freeze_token_rep=self.config.freeze_token_rep)
139 |         else:
140 |             optimizer = model.get_optimizer(self.lr_encoder, self.lr_others, freeze_token_rep=self.config.freeze_token_rep)
141 |         return model, optimizer
142 | 
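The merge in `setup_model_and_optimizer` means architecture parameters always come from the checkpoint while training parameters come from the YAML. A toy sketch of the same pattern (the names here are illustrative, not from the repository):

```python
from types import SimpleNamespace

pretrained = SimpleNamespace(hidden_size=768, max_types=25)   # values stored in the checkpoint
new_config = SimpleNamespace(hidden_size=1024, max_types=30)  # values from config.yaml

for param in ['hidden_size']:  # plays the role of keep_params
    setattr(new_config, param, getattr(pretrained, param))

print(new_config.hidden_size, new_config.max_types)  # 768 30
```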
raise ValueError(
194 |                 f"Invalid scheduler_type value: '{scheduler_type}'. Supported scheduler types: 'cosine', 'linear', 'constant', 'polynomial', 'inverse_sqrt'"
195 |             )
196 |         return scheduler
197 | 
198 |     def train(self, model, optimizer, train_loader, num_steps, device='cuda', rank=None):
199 |         model.train()
200 |         pbar = tqdm(range(num_steps))
201 | 
202 |         warmup_ratio = self.config.warmup_ratio
203 |         eval_every = self.config.eval_every
204 |         save_total_limit = self.config.save_total_limit
205 |         log_dir = self.config.log_dir
206 |         val_data_dir = self.config.val_data_dir
207 | 
208 |         num_warmup_steps = int(num_steps * warmup_ratio) if warmup_ratio < 1 else int(warmup_ratio)  # fraction if < 1, absolute steps otherwise
209 | 
210 |         scheduler = self.init_scheduler(self.config.scheduler_type, optimizer, num_warmup_steps, num_steps)
211 |         iter_train_loader = iter(train_loader)
212 |         scaler = torch.cuda.amp.GradScaler()
213 | 
214 |         for step in pbar:
215 |             optimizer.zero_grad()
216 | 
217 |             try:
218 |                 x = next(iter_train_loader)
219 |             except StopIteration:
220 |                 iter_train_loader = iter(train_loader)
221 |                 x = next(iter_train_loader)
222 | 
223 |             for k, v in x.items():
224 |                 if isinstance(v, torch.Tensor):
225 |                     x[k] = v.to(device)
226 | 
227 |             with torch.cuda.amp.autocast(dtype=torch.float16):
228 |                 loss = model(x)
229 | 
230 |             if torch.isnan(loss).any():
231 |                 print("Warning: NaN loss detected")
232 |                 continue
233 | 
234 |             scaler.scale(loss).backward()
235 |             scaler.step(optimizer)
236 |             scaler.update()
237 |             scheduler.step()
238 | 
239 |             description = f"step: {step} | epoch: {step // len(train_loader)} | loss: {loss.item():.2f}"
240 |             pbar.set_description(description)
241 | 
242 |             if (step + 1) % eval_every == 0:
243 |                 if rank is None or rank == 0:
244 |                     checkpoint = f'model_{step + 1}'
245 |                     save_top_k_checkpoints(model, log_dir, checkpoint, save_total_limit)
246 |                     #if val_data_dir != "none":
247 |                         #get_for_all_path(model, step, log_dir, val_data_dir)
248 |                     model.train()
249 | 
250 |     def run(self):
251 |         with open(self.config.train_data, 'r') as f:
252 |             data = json.load(f)
253 | 
254 |         if torch.cuda.device_count() > 1 and self.allow_distributed:
255 |             world_size = torch.cuda.device_count()
256 |             mp.spawn(self.train_dist, args=(world_size, data), nprocs=world_size, join=True)
257 |         else:
258 |             model, optimizer = self.setup_model_and_optimizer()
259 | 
260 |             train_loader = model.create_dataloader(data, batch_size=self.config.train_batch_size, shuffle=True)
261 | 
262 |             self.train(model, optimizer, train_loader, num_steps=self.config.num_steps, device=self.device)
263 | 
264 | 
265 | def create_parser():
266 |     parser = argparse.ArgumentParser(description="grapher")
267 |     parser.add_argument("--config", type=str, default="config.yaml", help="Path to config file")
268 |     parser.add_argument('--log_dir', type=str, default='logs', help='Path to the log directory')
269 |     parser.add_argument('--allow_distributed', action='store_true',
270 |                         help='Allow distributed training when more than one GPU is available')
271 |     return parser
272 | 
273 | 
274 | if __name__ == "__main__":
275 |     parser = create_parser()
276 |     args = parser.parse_args()
277 |     config = load_config_as_namespace(args.config)
278 |     config.log_dir = args.log_dir
279 | 
280 |     trainer = Trainer(config, allow_distributed=args.allow_distributed,
281 |                       device='cuda' if torch.cuda.is_available() else 'cpu')
282 |     trainer.run()
283 | 
--------------------------------------------------------------------------------
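One detail of `train()` worth spelling out: `warmup_ratio` is read as a fraction of `num_steps` when it is below 1, and as an absolute step count otherwise. A small sketch of that rule (`resolve_warmup` is a hypothetical helper, not part of the repository):

```python
def resolve_warmup(num_steps: int, warmup_ratio: float) -> int:
    # Mirrors the logic in Trainer.train: < 1 means a fraction, >= 1 means absolute steps
    return int(num_steps * warmup_ratio) if warmup_ratio < 1 else int(warmup_ratio)

print(resolve_warmup(30000, 0.1))  # 3000 (10% of training)
print(resolve_warmup(30000, 500))  # 500  (fixed number of steps)
```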