├── LICENSE
├── README.md
├── config.yaml
├── grapher_fig.png
├── model.py
├── modules
│   ├── base.py
│   ├── data_processor.py
│   ├── evaluator.py
│   ├── filtering.py
│   ├── layers.py
│   ├── loss_functions.py
│   ├── rel_rep.py
│   ├── scorer.py
│   ├── span_rep.py
│   ├── token_rep.py
│   ├── token_splitter.py
│   └── utils.py
└── train.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Urchade Zaratiana
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # *GraphER*: A Structure-aware Text-to-Graph Model for Entity and Relation Extraction
2 |
3 | 
4 |
5 | ### TO DO:
6 | - [ ] Implement negative sampling (both for entities and relations)
7 | - [ ] Add evaluation computation
8 | - [ ] Ablation study (scoring, ...)
9 |
10 | For now, you can play with the beta version in Colab: [Open In Colab](https://colab.research.google.com/drive/1IinAMCtUotntrtoP9zNutriZJtJ4Hymd?usp=sharing)
11 |
12 | ```bibtex
13 | @misc{zaratiana2024grapher,
14 | title={GraphER: A Structure-aware Text-to-Graph Model for Entity and Relation Extraction},
15 | author={Urchade Zaratiana and Nadi Tomeh and Niama El Khbir and Pierre Holat and Thierry Charnois},
16 | year={2024},
17 | eprint={2404.12491},
18 | archivePrefix={arXiv},
19 | primaryClass={cs.CL}
20 | }
21 | ```
22 |
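23 | ### Usage (sketch)
24 |
25 | A minimal, illustrative sketch of the intended inference flow. The sample sentence, type names, and threshold are made up, and since the model below is untrained its predictions will be arbitrary; the API may still change while the model is in beta:
26 |
27 | ```python
28 | from model import GraphER
29 | from modules.base import load_config_as_namespace
30 |
31 | config = load_config_as_namespace("config.yaml")
32 | model = GraphER(config)
33 |
34 | # one pre-tokenized example; entities are (start, end, type) spans
35 | data = [{
36 |     "tokenized_text": ["Alain", "Farley", "works", "at", "McGill", "University"],
37 |     "entities": [(0, 1, "person"), (4, 5, "organization")],
38 |     "relations": [((0, 1), (4, 5), "works for")],
39 | }]
40 |
41 | loader = model.create_dataloader(data, entity_types=["person", "organization"],
42 |                                  relation_types=["works for"], batch_size=1, shuffle=False)
43 | for batch in loader:
44 |     entities, relations = model.predict(batch, threshold=0.5)
45 |     print(entities, relations)
46 | ```
47 |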
--------------------------------------------------------------------------------
/config.yaml:
--------------------------------------------------------------------------------
1 | # Model Configuration
2 | model_name: microsoft/deberta-v3-base # Hugging Face model
3 | name: "grapher"
4 | max_width: 12
5 | hidden_size: 768
6 | dropout: 0.1
7 | fine_tune: true
8 | subtoken_pooling: first
9 | span_mode: markerV0
10 | num_heads: 4
11 | num_transformer_layers: 2
12 | ffn_mul: 4
13 | scorer: "dot"
14 |
15 | # Training Parameters
16 | num_steps: 30000
17 | train_batch_size: 8
18 | eval_every: 5000
19 | warmup_ratio: 0.1
20 | scheduler_type: "cosine"
21 |
22 | # Learning Rate and weight decay Configuration
23 | lr_encoder: 1e-5
24 | lr_others: 5e-5
25 |
26 | # Directory Paths
27 | root_dir: grapher_logs
28 | train_data: "data/rel_news_b.json"
29 | val_data_dir: "data/NER_datasets"
30 |
31 | # Pretrained Model Path
32 | # Use "none" if no pretrained model is being used
33 | prev_path: "none"
34 |
35 | # Advanced Training Settings
36 | size_sup: -1
37 | max_types: 25
38 | max_len: 384
39 | freeze_token_rep: false
40 | save_total_limit: 20
41 | max_top_k: 54
42 | add_top_k: 10
43 | shuffle_types: true
44 |
45 | random_drop: true
46 | max_neg_type_ratio: 2
47 | max_ent_types: 20
48 | max_rel_types: 20
--------------------------------------------------------------------------------
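The training script reads this file into a flat namespace (see `load_config_as_namespace` in `modules/base.py`), so keys become attributes. One gotcha worth knowing: PyYAML parses bare scientific notation such as `1e-5` as a string, not a float, so the learning rates above presumably get cast before reaching the optimizer. A small sketch:

```python
import yaml
from argparse import Namespace

with open("config.yaml") as f:
    config = Namespace(**yaml.safe_load(f))

print(config.model_name)            # microsoft/deberta-v3-base
print(type(config.lr_encoder))      # <class 'str'>  (PyYAML quirk with "1e-5")
lr_encoder = float(config.lr_encoder)
```
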
/grapher_fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/urchade/GraphER/23cc4edc4fb04cebd40b99ec02e55700ca8252dd/grapher_fig.png
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 |
5 | from modules.base import GrapherBase
6 | from modules.data_processor import TokenPromptProcessorTR
7 | from modules.filtering import FilteringLayer
8 | from modules.layers import MLP, LstmSeq2SeqEncoder, TransLayer, GraphEmbedder
9 | from modules.loss_functions import compute_matching_loss
10 | from modules.rel_rep import RelationRep
11 | from modules.scorer import ScorerLayer
12 | from modules.span_rep import SpanRepLayer
13 | from modules.token_rep import TokenRepLayer
14 | from modules.utils import get_ground_truth_relations, get_candidates
15 |
16 |
17 | class GraphER(GrapherBase):
18 | def __init__(self, config):
19 | super().__init__(config)
20 |
21 | self.config = config
22 |
23 |         # special prompt tokens: entity type, relation type, separator
24 |         self.ent_token = "<<ENT>>"
25 |         self.rel_token = "<<REL>>"
26 |         self.sep_token = "<<SEP>>"
27 |
28 | # usually a pretrained bidirectional transformer, returns first subtoken representation
29 | self.token_rep_layer = TokenRepLayer(model_name=config.model_name, fine_tune=config.fine_tune,
30 | subtoken_pooling=config.subtoken_pooling, hidden_size=config.hidden_size,
31 | add_tokens=[self.ent_token, self.rel_token, self.sep_token])
32 |
33 | # token prompt processor
34 | self.token_prompt_processor = TokenPromptProcessorTR(self.ent_token, self.rel_token, self.sep_token)
35 |
36 | # hierarchical representation of tokens (Zaratiana et al, 2022)
37 | # https://arxiv.org/pdf/2203.14710.pdf
38 | self.rnn = LstmSeq2SeqEncoder(
39 | input_size=config.hidden_size,
40 | hidden_size=config.hidden_size // 2,
41 | num_layers=1,
42 | bidirectional=True
43 | )
44 |
45 | # span representation
46 | self.span_rep_layer = SpanRepLayer(
47 | span_mode=config.span_mode,
48 | hidden_size=config.hidden_size,
49 | max_width=config.max_width,
50 | dropout=config.dropout
51 | )
52 |
53 | # prompt representation (FFN)
54 | self.ent_rep_layer = nn.Linear(config.hidden_size, config.hidden_size)
55 | self.rel_rep_layer = nn.Linear(config.hidden_size, config.hidden_size)
56 |
57 | # filtering layer for spans and relations
58 | self._span_filtering = FilteringLayer(config.hidden_size)
59 | self._rel_filtering = FilteringLayer(config.hidden_size)
60 |
61 | # relation representation
62 | self.relation_rep = RelationRep(config.hidden_size, config.dropout, config.ffn_mul)
63 |
64 | # graph embedder
65 | self.graph_embedder = GraphEmbedder(config.hidden_size)
66 |
67 | # transformer layer
68 | self.trans_layer = TransLayer(
69 | config.hidden_size,
70 | num_heads=config.num_heads,
71 | num_layers=config.num_transformer_layers
72 | )
73 |
74 | # keep_mlp
75 | self.keep_mlp = MLP([config.hidden_size, config.hidden_size * config.ffn_mul, 1], dropout=0.1)
76 |
77 | # scoring layers
78 | self.scorer_ent = ScorerLayer(
79 | scoring_type=config.scorer,
80 | hidden_size=config.hidden_size,
81 | dropout=config.dropout
82 | )
83 |
84 | self.scorer_rel = ScorerLayer(
85 | scoring_type=config.scorer,
86 | hidden_size=config.hidden_size,
87 | dropout=config.dropout
88 | )
89 |
90 | def get_optimizer(self, lr_encoder, lr_others, freeze_token_rep=False):
91 | """
92 | Parameters:
93 | - lr_encoder: Learning rate for the encoder layer.
94 | - lr_others: Learning rate for all other layers.
95 | - freeze_token_rep: whether the token representation layer should be frozen.
96 | """
97 | param_groups = [
98 | # encoder
99 | {"params": self.rnn.parameters(), "lr": lr_others},
100 | # projection layers
101 | {"params": self.span_rep_layer.parameters(), "lr": lr_others},
102 | # prompt representation
103 | {"params": self.ent_rep_layer.parameters(), "lr": lr_others},
104 | {"params": self.rel_rep_layer.parameters(), "lr": lr_others},
105 | # filtering layers
106 | {"params": self._span_filtering.parameters(), "lr": lr_others},
107 | {"params": self._rel_filtering.parameters(), "lr": lr_others},
108 | # relation representation
109 | {"params": self.relation_rep.parameters(), "lr": lr_others},
110 | # graph embedder
111 | {"params": self.graph_embedder.parameters(), "lr": lr_others},
112 | # transformer layer
113 | {"params": self.trans_layer.parameters(), "lr": lr_others},
114 | # keep_mlp
115 | {"params": self.keep_mlp.parameters(), "lr": lr_others},
116 | # scoring layer
117 | {"params": self.scorer_ent.parameters(), "lr": lr_others},
118 | {"params": self.scorer_rel.parameters(), "lr": lr_others}
119 | ]
120 |
121 | if not freeze_token_rep:
122 | # If token_rep_layer should not be frozen, add it to the optimizer with its learning rate
123 | param_groups.append({"params": self.token_rep_layer.parameters(), "lr": lr_encoder})
124 | else:
125 | # If token_rep_layer should be frozen, explicitly set requires_grad to False for its parameters
126 | for param in self.token_rep_layer.parameters():
127 | param.requires_grad = False
128 |
129 | optimizer = torch.optim.AdamW(param_groups)
130 |
131 | return optimizer
132 |
133 | def compute_score_train(self, x):
134 | span_idx = x['span_idx'] * x['span_mask'].unsqueeze(-1)
135 |
136 | # Process input
137 | word_rep, mask, entity_type_rep, entity_type_mask, rel_type_rep, relation_type_mask = self.token_prompt_processor.process(
138 | x, self.token_rep_layer, "train"
139 | )
140 |
141 | # Compute representations
142 | word_rep = self.rnn(word_rep, mask)
143 | span_rep = self.span_rep_layer(word_rep, span_idx)
144 | entity_type_rep = self.ent_rep_layer(entity_type_rep)
145 | rel_type_rep = self.rel_rep_layer(rel_type_rep)
146 |
147 | # Compute number of entity and relation types
148 | num_ent, num_rel = entity_type_rep.shape[1], rel_type_rep.shape[1]
149 |
150 | return span_rep, num_ent, num_rel, entity_type_rep, entity_type_mask, rel_type_rep, relation_type_mask, (
151 | word_rep, mask)
152 |
153 | @torch.no_grad()
154 | def compute_score_eval(self, x, device):
155 | span_idx = (x['span_idx'] * x['span_mask'].unsqueeze(-1)).to(device)
156 |
157 | # Process input
158 | word_rep, mask, entity_type_rep, relation_type_rep = self.token_prompt_processor.process(
159 | x, self.token_rep_layer, "eval"
160 | )
161 |
162 | # Compute representations
163 | word_rep = self.rnn(word_rep, mask)
164 | span_rep = self.span_rep_layer(word_rep, span_idx)
165 | entity_type_rep = self.ent_rep_layer(entity_type_rep)
166 | relation_type_rep = self.rel_rep_layer(relation_type_rep)
167 |
168 | # Compute number of entity and relation types
169 | num_ent, num_rel = entity_type_rep.shape[1], relation_type_rep.shape[1]
170 |
171 | return span_rep, num_ent, num_rel, entity_type_rep, relation_type_rep, (word_rep, mask)
172 |
173 | def forward(self, x, prediction_mode=False):
174 |
175 | # clone span_label
176 | span_label = x['span_label'].clone()
177 |
178 | # compute span representation
179 | if prediction_mode:
180 | # Get the device of the model parameters
181 | device = next(self.parameters()).device
182 |
183 | # Compute scores for evaluation
184 | span_rep, num_ent, num_rel, entity_type_rep, rel_type_rep, (word_rep, word_mask) = self.compute_score_eval(
185 | x, device)
186 |
187 | # Create masks for relation and entity types, setting all values to 1
188 | relation_type_mask = torch.ones(size=(rel_type_rep.shape[0], num_rel), device=device)
189 | entity_type_mask = torch.ones(size=(entity_type_rep.shape[0], num_ent), device=device)
190 | else:
191 | # Compute scores for training
192 | span_rep, num_ent, num_rel, entity_type_rep, entity_type_mask, rel_type_rep, relation_type_mask, (
193 | word_rep, mask) = self.compute_score_train(x)
194 |
195 | # Reshape span_rep from (B, L, K, D) to (B, L * K, D)
196 | B, L, K, D = span_rep.shape
197 | span_rep = span_rep.view(B, L * K, D)
198 |
199 | # Compute filtering scores for spans
200 | filter_score_span, filter_loss_span = self._span_filtering(span_rep, x['span_label'])
201 |
202 | # Determine the maximum number of candidates
203 | # If L is greater than the configured maximum, use the configured maximum plus an additional top K
204 | # Otherwise, use L plus an additional top K
205 | max_top_k = min(L, self.config.max_top_k) + self.config.add_top_k
206 |
207 | # Sort the filter scores for spans in descending order
208 | sorted_idx = torch.sort(filter_score_span, dim=-1, descending=True)[1]
209 |
210 | # Define the elements to get candidates for
211 | elements = [span_rep, span_label, x['span_mask'], x['span_idx']]
212 |
213 | # Use a list comprehension to get the candidates for each element
214 | candidate_span_rep, candidate_span_label, candidate_span_mask, candidate_spans_idx = [
215 | get_candidates(sorted_idx, element, topk=max_top_k)[0] for element in elements
216 | ]
217 |
218 |         # Bound the number of valid candidates per example: its sequence length plus the top-k buffer
219 |         top_k_lengths = x["seq_length"].clone() + self.config.add_top_k
220 |
221 | # Create a condition mask where the range of top K is greater than or equal to the top K lengths
222 | condition_mask = torch.arange(max_top_k, device=span_rep.device).unsqueeze(0) >= top_k_lengths.unsqueeze(-1)
223 |
224 | # Apply the condition mask to the candidate span mask and label, setting the masked values to 0 and -1
225 | # respectively
226 | candidate_span_mask.masked_fill_(condition_mask, 0)
227 | candidate_span_label.masked_fill_(condition_mask, -1)
228 |
229 | # Get ground truth relations and represent them
230 | relation_classes = get_ground_truth_relations(x, candidate_spans_idx, candidate_span_label)
231 |         rel_rep = self.relation_rep(candidate_span_rep).view(B, max_top_k * max_top_k, -1)  # Shape: [B, max_top_k * max_top_k, D]
232 |
233 | # Compute filtering scores for relations and sort them in descending order
234 | filter_score_rel, filter_loss_rel = self._rel_filtering(rel_rep, relation_classes)
235 | sorted_idx_pair = torch.sort(filter_score_rel, dim=-1, descending=True)[1]
236 |
237 | # Embed candidate span representations
238 | candidate_span_rep, cat_pair_rep = self.graph_embedder(candidate_span_rep)
239 |
240 | # Define the elements to get candidates for
241 | elements = [cat_pair_rep.view(B, max_top_k * max_top_k, -1), relation_classes.view(B, max_top_k * max_top_k)]
242 |
243 | # Use a list comprehension to get the candidates for each element
244 | candidate_pair_rep, candidate_pair_label = [get_candidates(sorted_idx_pair, element, topk=max_top_k)[0] for
245 | element in elements]
246 |
247 | # Get the top K relation indices
248 | topK_rel_idx = sorted_idx_pair[:, :max_top_k]
249 |
250 | # Mask the candidate pair labels using the condition mask and refine the relation representation
251 | candidate_pair_label.masked_fill_(condition_mask, -1)
252 | candidate_pair_mask = candidate_pair_label > -1
253 |
254 | # Concatenate span and relation representations
255 | concat_span_pair = torch.cat((candidate_span_rep, candidate_pair_rep), dim=1)
256 | mask_span_pair = torch.cat((candidate_span_mask, candidate_pair_mask), dim=1)
257 |
258 | # Apply transformer layer and keep_mlp
259 | out_trans = self.trans_layer(concat_span_pair, mask_span_pair)
260 |         keep_score = self.keep_mlp(out_trans).squeeze(-1)  # Shape: (B, 2 * max_top_k)
261 |
262 |         # keep_score is left as logits; F.binary_cross_entropy_with_logits below applies the sigmoid
263 |         # keep_score = torch.sigmoid(keep_score)  # Shape: (B, 2 * max_top_k)
264 |
265 | # Split keep_score into keep_ent and keep_rel
266 | keep_ent, keep_rel = keep_score.split([max_top_k, max_top_k], dim=1)
267 |
268 | """not use output from transformer layer for now"""
269 | # Split out_trans
270 | # candidate_span_rep, candidate_pair_rep = out_trans.split([max_top_k, max_top_k], dim=1)
271 |
272 | # Compute scores for entities and relations
273 | scores_ent = self.scorer_ent(candidate_span_rep, entity_type_rep) # Shape: [B, N, C]
274 | scores_rel = self.scorer_rel(candidate_pair_rep, rel_type_rep) # Shape: [B, N, C]
275 |
276 | if prediction_mode:
277 | return {
278 | "entity_logits": scores_ent,
279 | "relation_logits": scores_rel,
280 | "keep_ent": keep_ent,
281 | "keep_rel": keep_rel,
282 | "candidate_spans_idx": candidate_spans_idx,
283 | "candidate_pair_label": candidate_pair_label,
284 | "max_top_k": max_top_k,
285 | "topK_rel_idx": topK_rel_idx
286 | }
287 | # Compute losses for relation and entity classifiers
288 | relation_loss = compute_matching_loss(scores_rel, candidate_pair_label, relation_type_mask, num_rel)
289 | entity_loss = compute_matching_loss(scores_ent, candidate_span_label, entity_type_mask, num_ent)
290 |
291 | # Concatenate labels for binary classification and compute binary classification loss
292 | ent_rel_label = (torch.cat((candidate_span_label, candidate_pair_label), dim=1) > 0).float()
293 | filter_loss = F.binary_cross_entropy_with_logits(keep_score, ent_rel_label, reduction='none')
294 |
295 | # Compute structure loss and total loss
296 | structure_loss = (filter_loss * mask_span_pair.float()).sum()
297 | total_loss = sum([filter_loss_span, filter_loss_rel, relation_loss, entity_loss, structure_loss])
298 |
299 | return total_loss
300 |
--------------------------------------------------------------------------------
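The candidate-selection step in `GraphER.forward` above (score every span, sort, gather the top-k) is the model's pruning core; `get_candidates` itself lives in `modules/utils.py` and is only partially shown in this dump, but the underlying gather pattern looks roughly like this self-contained toy, with made-up shapes:

```python
import torch

B, N, D, k = 2, 5, 8, 3
span_rep = torch.randn(B, N, D)
filter_score = torch.randn(B, N)            # stand-in for FilteringLayer's keep-score

sorted_idx = torch.sort(filter_score, dim=-1, descending=True)[1]   # [B, N]
topk_idx = sorted_idx[:, :k]                                        # [B, k]

# gather needs the index tensor expanded to the trailing feature dim
candidate_rep = span_rep.gather(1, topk_idx.unsqueeze(-1).expand(-1, -1, D))
print(candidate_rep.shape)                  # torch.Size([2, 3, 8])
```
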
/modules/base.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | from abc import abstractmethod
4 | from pathlib import Path
5 | from typing import Union, Optional, Dict
6 |
7 | import torch
8 | import torch.nn as nn
9 | import yaml
10 | from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
11 | from torch.utils.data import DataLoader
12 |
13 | from .data_processor import GrapherData
14 | from .evaluator import Evaluator
15 | from .token_splitter import WhitespaceTokenSplitter, MecabKoTokenSplitter, SpaCyTokenSplitter
16 | from .utils import er_decoder, get_relation_with_span
17 | from types import SimpleNamespace
18 |
19 |
20 | class GrapherBase(nn.Module, PyTorchModelHubMixin):
21 | def __init__(self, config):
22 | super().__init__()
23 |
24 | self.config = config
25 | self.data_proc = GrapherData(config)
26 |
27 | if not hasattr(config, 'token_splitter'):
28 | self.token_splitter = WhitespaceTokenSplitter()
29 | elif self.config.token_splitter == "spacy":
30 | lang = getattr(config, 'token_splitter_lang', None)
31 | self.token_splitter = SpaCyTokenSplitter(lang=lang)
32 | elif self.config.token_splitter == "mecab-ko":
33 | self.token_splitter = MecabKoTokenSplitter()
34 |
35 | @abstractmethod
36 | def forward(self, x):
37 | pass
38 |
39 | def adjust_logits(self, logits, keep):
40 | """Adjust logits based on the keep tensor."""
41 | keep = torch.sigmoid(keep)
42 | keep = (keep > 0.5).unsqueeze(-1).float()
43 | adjusted_logits = logits + (1 - keep) * -1e9
44 | return adjusted_logits
45 |
46 | def predict(self, x, threshold=0.5, output_confidence=False):
47 | """Predict entities and relations."""
48 | out = self.forward(x, prediction_mode=True)
49 |
50 | # Adjust relation and entity logits
51 | out["entity_logits"] = self.adjust_logits(out["entity_logits"], out["keep_ent"])
52 | out["relation_logits"] = self.adjust_logits(out["relation_logits"], out["keep_rel"])
53 |
54 | # Get entities and relations
55 | entities, relations = er_decoder(x, out["entity_logits"], out["relation_logits"], out["topK_rel_idx"],
56 | out["max_top_k"], out["candidate_spans_idx"], threshold=threshold,
57 | output_confidence=output_confidence, token_splitter=self.token_splitter)
58 | return entities, relations
59 |
60 | def evaluate(self, test_data, threshold=0.5, batch_size=12, relation_types=None):
61 | self.eval()
62 | data_loader = self.create_dataloader(test_data, batch_size=batch_size, relation_types=relation_types,
63 | shuffle=False)
64 | device = next(self.parameters()).device
65 | all_preds = []
66 | all_trues = []
67 | for x in data_loader:
68 | for k, v in x.items():
69 | if isinstance(v, torch.Tensor):
70 | x[k] = v.to(device)
71 | batch_predictions = self.predict(x, threshold)
72 | all_preds.extend(batch_predictions)
73 | all_trues.extend(get_relation_with_span(x))
74 | evaluator = Evaluator(all_trues, all_preds)
75 | out, f1 = evaluator.evaluate()
76 | return out, f1
77 |
78 | def create_dataloader(self, data, entity_types=None, **kwargs) -> DataLoader:
79 | return self.data_proc.create_dataloader(data, entity_types, **kwargs)
80 |
81 | def save_pretrained(
82 | self,
83 | save_directory: Union[str, Path],
84 | *,
85 | config: Optional[Union[dict, "DataclassInstance"]] = None,
86 | repo_id: Optional[str] = None,
87 | push_to_hub: bool = False,
88 | **push_to_hub_kwargs,
89 | ) -> Optional[str]:
90 | """
91 | Save weights in local directory.
92 |
93 | Args:
94 | save_directory (`str` or `Path`):
95 | Path to directory in which the model weights and configuration will be saved.
96 | config (`dict` or `DataclassInstance`, *optional*):
97 | Model configuration specified as a key/value dictionary or a dataclass instance.
98 | push_to_hub (`bool`, *optional*, defaults to `False`):
99 | Whether or not to push your model to the Huggingface Hub after saving it.
100 | repo_id (`str`, *optional*):
101 | ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
102 | not provided.
103 | kwargs:
104 | Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.
105 | """
106 | save_directory = Path(save_directory)
107 | save_directory.mkdir(parents=True, exist_ok=True)
108 |
109 | # save model weights/files
110 | torch.save(self.state_dict(), save_directory / "pytorch_model.bin")
111 |
112 | # save config (if provided)
113 | if config is None:
114 | config = self.config
115 | if config is not None:
116 | if isinstance(config, argparse.Namespace) or isinstance(config, SimpleNamespace):
117 | config = vars(config)
118 | (save_directory / "config.json").write_text(json.dumps(config, indent=2))
119 |
120 | # push to the Hub if required
121 | if push_to_hub:
122 | kwargs = push_to_hub_kwargs.copy() # soft-copy to avoid mutating input
123 | if config is not None: # kwarg for `push_to_hub`
124 | kwargs["config"] = config
125 | if repo_id is None:
126 | repo_id = save_directory.name # Defaults to `save_directory` name
127 | return self.push_to_hub(repo_id=repo_id, **kwargs)
128 | return None
129 |
130 | @classmethod
131 | def _from_pretrained(
132 | cls,
133 | *,
134 | model_id: str,
135 | revision: Optional[str],
136 | cache_dir: Optional[Union[str, Path]],
137 | force_download: bool,
138 | proxies: Optional[Dict],
139 | resume_download: bool,
140 | local_files_only: bool,
141 | token: Union[str, bool, None],
142 | map_location: str = "cpu",
143 | strict: bool = False,
144 | **model_kwargs,
145 | ):
146 |
147 |         # Load "pytorch_model.bin" and "config.json" from a local path if present, otherwise from the Hub
148 | model_file = Path(model_id) / "pytorch_model.bin"
149 | if not model_file.exists():
150 | model_file = hf_hub_download(
151 | repo_id=model_id,
152 | filename="pytorch_model.bin",
153 | revision=revision,
154 | cache_dir=cache_dir,
155 | force_download=force_download,
156 | proxies=proxies,
157 | resume_download=resume_download,
158 | token=token,
159 | local_files_only=local_files_only,
160 | )
161 | config_file = Path(model_id) / "config.json"
162 | if not config_file.exists():
163 | config_file = hf_hub_download(
164 | repo_id=model_id,
165 | filename="config.json",
166 | revision=revision,
167 | cache_dir=cache_dir,
168 | force_download=force_download,
169 | proxies=proxies,
170 | resume_download=resume_download,
171 | token=token,
172 | local_files_only=local_files_only,
173 | )
174 | config = load_config_as_namespace(config_file)
175 | model = cls(config)
176 | state_dict = torch.load(model_file, map_location=torch.device(map_location))
177 | model.load_state_dict(state_dict, strict=strict, assign=True)
178 | model.to(map_location)
179 | return model
180 |
181 | def to(self, device):
182 | super().to(device)
183 | import flair
184 | flair.device = device
185 | return self
186 |
187 |
188 | def load_config_as_namespace(config_file):
189 | with open(config_file, "r") as f:
190 | config_dict = yaml.safe_load(f)
191 | return argparse.Namespace(**config_dict)
192 |
--------------------------------------------------------------------------------
/modules/data_processor.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from typing import List, Tuple, Dict
3 |
4 | import torch
5 | from torch.nn.utils.rnn import pad_sequence
6 | from torch.utils.data import DataLoader
7 | import random
8 |
9 |
10 | # Abstract base class for handling data processing
11 | class GrapherData(object):
12 | def __init__(self, config):
13 | self.config = config
14 |
15 | @staticmethod
16 | def get_dict(spans: List[Tuple[int, int, str]], classes_to_id: Dict[str, int]) -> Dict[Tuple[int, int], int]:
17 | """Get a dictionary of spans."""
18 | dict_tag = defaultdict(int)
19 | for span in spans:
20 | if span[-1] not in classes_to_id:
21 | continue
22 | dict_tag[(span[0], span[1])] = classes_to_id[span[-1]]
23 | return dict_tag
24 |
25 | def preprocess_spans(self, tokens: List[str], ner: List[Tuple[int, int, str]], rel: List[Tuple[int, int, str]],
26 | classes_to_id: Dict[str, int]) -> Dict:
27 | """Preprocess spans for a given text."""
28 | # Set the maximum length for tokens
29 | max_token_length = self.config.max_len
30 |
31 | # If the number of tokens exceeds the maximum length, truncate the tokens
32 | if len(tokens) > max_token_length:
33 | token_length = max_token_length
34 | tokens = tokens[:max_token_length]
35 | else:
36 | token_length = len(tokens)
37 |
38 | # Initialize a list to store span indices
39 | span_indices = []
40 | for i in range(token_length):
41 | span_indices.extend([(i, i + j) for j in range(self.config.max_width)])
42 |
43 | # Get the dictionary of labels
44 | label_dict = self.get_dict(ner, classes_to_id) if ner else defaultdict(int)
45 |
46 | # Initialize the span labels with the corresponding label from the dictionary
47 | span_labels = torch.LongTensor([label_dict[i] for i in span_indices])
48 | span_indices = torch.LongTensor(span_indices)
49 |
50 |         # Mark spans whose end index falls outside the (possibly truncated) sequence
51 |         invalid_span_mask = span_indices[:, 1] > token_length - 1
52 |
53 |         # Set the labels of those invalid spans to -1 so they are ignored downstream
54 |         span_labels = span_labels.masked_fill(invalid_span_mask, -1)
55 |
56 | # Return a dictionary with the preprocessed spans
57 | return {
58 | 'tokens': tokens,
59 | 'span_idx': span_indices,
60 | 'span_label': span_labels,
61 | 'seq_length': token_length,
62 | 'entities': ner,
63 | 'relations': rel,
64 | }
65 |
66 | def create_mapping(self, types: List[str]) -> Tuple[Dict[str, int], Dict[int, str]]:
67 | """Create a mapping from type to id and id to type."""
68 | if not types:
69 | types = ["None"]
70 | type_to_id = {type: id for id, type in enumerate(types, start=1)}
71 | id_to_type = {id: type for type, id in type_to_id.items()}
72 | return type_to_id, id_to_type
73 |
74 | def batch_generate_class_mappings(self, batch_list: List[Dict]) -> Tuple[
75 | List[Dict[str, int]], List[Dict[int, str]], List[Dict[str, int]], List[Dict[int, str]]]:
76 | """Generate class mappings for a batch of data."""
77 | all_ent_to_id, all_id_to_ent, all_rel_to_id, all_id_to_rel = [], [], [], []
78 |
79 | for b in batch_list:
80 | ent_types = list(set([el[-1] for el in b['entities']]))
81 | rel_types = list(set([el[-1] for el in b['relations']]))
82 |
83 | if self.config.shuffle_types:
84 | random.shuffle(ent_types)
85 | random.shuffle(rel_types)
86 |
87 | if len(ent_types) == 0:
88 | ent_types = ["none"]
89 | if len(rel_types) == 0:
90 | rel_types = ["none"]
91 |
92 | ent_to_id, id_to_ent = self.create_mapping(ent_types)
93 | rel_to_id, id_to_rel = self.create_mapping(rel_types)
94 |
95 | all_ent_to_id.append(ent_to_id)
96 | all_id_to_ent.append(id_to_ent)
97 | all_rel_to_id.append(rel_to_id)
98 | all_id_to_rel.append(id_to_rel)
99 |
100 | return all_ent_to_id, all_id_to_ent, all_rel_to_id, all_id_to_rel
101 |
102 | def collate_fn(self, batch_list: List[Dict], entity_types: List[str] = None,
103 | relation_types: List[str] = None) -> Dict:
104 | """Collate a batch of data."""
105 |
106 | if entity_types is None or relation_types is None:
107 | ent_to_id, id_to_ent, rel_to_id, id_to_rel = self.batch_generate_class_mappings(batch_list)
108 | else:
109 | ent_to_id, id_to_ent = self.create_mapping(entity_types)
110 | rel_to_id, id_to_rel = self.create_mapping(relation_types)
111 |
112 | batch = [self.preprocess_spans(b["tokenized_text"], b["entities"], b["relations"],
113 | ent_to_id if not isinstance(ent_to_id, list) else ent_to_id[i]) for i, b in
114 | enumerate(batch_list)]
115 |
116 | return self.create_batch_dict(batch, ent_to_id, id_to_ent, rel_to_id, id_to_rel)
117 |
118 | def create_batch_dict(self, batch: List[Dict], ent_to_id: List[Dict[str, int]], id_to_ent: List[Dict[int, str]],
119 | rel_to_id: List[Dict[str, int]], id_to_rel: List[Dict[int, str]]) -> Dict:
120 | """Create a dictionary for a batch of data."""
121 |
122 | # Extract necessary information from the batch
123 | tokens = [el["tokens"] for el in batch]
124 | span_idx = pad_sequence([b["span_idx"] for b in batch], batch_first=True, padding_value=0)
125 | span_label = pad_sequence([el["span_label"] for el in batch], batch_first=True, padding_value=-1)
126 | seq_length = torch.LongTensor([el["seq_length"] for el in batch])
127 | entities = [el["entities"] for el in batch]
128 | relations = [el["relations"] for el in batch]
129 |
130 | # Create a mask for valid spans
131 | span_mask = span_label != -1
132 |
133 | # Return a dictionary with the preprocessed spans
134 | return {
135 | 'seq_length': seq_length,
136 | 'span_idx': span_idx,
137 | 'tokens': tokens,
138 | 'span_mask': span_mask,
139 | 'span_label': span_label,
140 | 'entities': entities,
141 | 'relations': relations,
142 | 'ent_to_id': ent_to_id,
143 | 'id_to_ent': id_to_ent,
144 | 'rel_to_id': rel_to_id,
145 | 'id_to_rel': id_to_rel
146 | }
147 |
148 | def create_dataloader(self, data, entity_types=None, relation_types=None, **kwargs) -> DataLoader:
149 | return DataLoader(data, collate_fn=lambda x: self.collate_fn(x, entity_types, relation_types), **kwargs)
150 |
151 |
152 | class TokenPromptProcessorTR:
153 | def __init__(self, entity_token, relation_token, sep_token):
154 | self.entity_token = entity_token
155 | self.sep_token = sep_token
156 | self.relation_token = relation_token
157 |
158 | def process(self, x, token_rep_layer, mode):
159 | if mode == "train":
160 | return self._process_train(x, token_rep_layer)
161 | elif mode == "eval":
162 | return self._process_eval(x, token_rep_layer)
163 | else:
164 | raise ValueError("Invalid mode specified. Choose 'train' or 'eval'.")
165 |
166 | def _process_train(self, x, token_rep_layer):
167 |
168 | device = next(token_rep_layer.parameters()).device
169 |
170 | new_length = x["seq_length"].clone()
171 | new_tokens = []
172 | all_len_prompt = []
173 | num_classes_all = []
174 | num_relations_all = []
175 |
176 | for i in range(len(x["tokens"])):
177 | all_types_i = list(x["ent_to_id"][i].keys())
178 | all_relations_i = list(x["rel_to_id"][i].keys())
179 | entity_prompt = []
180 | relation_prompt = []
181 | num_classes_all.append(len(all_types_i))
182 | num_relations_all.append(len(all_relations_i))
183 |
184 |             # One marker token per type, then a single separator per block, so the
185 |             # "2 * num_types + 1" lengths and the [0::2] striding below line up
186 |             for entity_type in all_types_i:
187 |                 entity_prompt += [self.entity_token, entity_type]
188 |             entity_prompt.append(self.sep_token)
189 |
190 |             for relation_type in all_relations_i:
191 |                 relation_prompt += [self.relation_token, relation_type]
192 |             relation_prompt.append(self.sep_token)
193 |
194 | combined_prompt = entity_prompt + relation_prompt
195 | tokens_p = combined_prompt + x["tokens"][i]
196 | new_length[i] += len(combined_prompt)
197 | new_tokens.append(tokens_p)
198 | all_len_prompt.append(len(combined_prompt))
199 |
200 | max_num_classes = max(num_classes_all)
201 | entity_type_pos = torch.arange(max_num_classes).unsqueeze(0).expand(len(num_classes_all), -1).to(device)
202 | entity_type_mask = entity_type_pos < torch.tensor(num_classes_all).unsqueeze(-1).to(device)
203 |
204 | max_num_relations = max(num_relations_all)
205 | relation_type_pos = torch.arange(max_num_relations).unsqueeze(0).expand(len(num_relations_all), -1).to(device)
206 | relation_type_mask = relation_type_pos < torch.tensor(num_relations_all).unsqueeze(-1).to(device)
207 |
208 | bert_output = token_rep_layer(new_tokens, new_length)
209 | word_rep_w_prompt = bert_output["embeddings"]
210 | mask_w_prompt = bert_output["mask"]
211 |
212 | word_rep = []
213 | mask = []
214 | entity_type_rep = []
215 | relation_type_rep = []
216 |
217 | for i in range(len(x["tokens"])):
218 | prompt_entity_length = all_len_prompt[i]
219 | entity_len = 2 * len(list(x["ent_to_id"][i].keys())) + 1
220 | relation_len = 2 * len(list(x["rel_to_id"][i].keys())) + 1
221 |
222 | word_rep.append(word_rep_w_prompt[i, prompt_entity_length:new_length[i]])
223 | mask.append(mask_w_prompt[i, prompt_entity_length:new_length[i]])
224 |
225 | entity_rep = word_rep_w_prompt[i, :entity_len - 1]
226 | entity_rep = entity_rep[0::2]
227 | entity_type_rep.append(entity_rep)
228 |
229 | relation_rep = word_rep_w_prompt[i, entity_len:entity_len + relation_len - 1]
230 | relation_rep = relation_rep[0::2]
231 | relation_type_rep.append(relation_rep)
232 |
233 | word_rep = pad_sequence(word_rep, batch_first=True)
234 | mask = pad_sequence(mask, batch_first=True)
235 | entity_type_rep = pad_sequence(entity_type_rep, batch_first=True)
236 | relation_type_rep = pad_sequence(relation_type_rep, batch_first=True)
237 |
238 | return word_rep, mask, entity_type_rep, entity_type_mask, relation_type_rep, relation_type_mask
239 |
240 | def _process_eval(self, x, token_rep_layer):
241 | all_types = list(x["ent_to_id"].keys())
242 | all_relations = list(x["rel_to_id"].keys())
243 | entity_prompt = []
244 | relation_prompt = []
245 |
246 |         # Same layout as in training: one marker token per type, one separator per block
247 |         for entity_type in all_types:
248 |             entity_prompt += [self.entity_token, entity_type]
249 |         entity_prompt.append(self.sep_token)
250 |
251 |         for relation_type in all_relations:
252 |             relation_prompt += [self.relation_token, relation_type]
253 |         relation_prompt.append(self.sep_token)
254 |
255 |
256 | combined_prompt = entity_prompt + relation_prompt
257 | prompt_entity_length = len(combined_prompt)
258 | tokens_p = [combined_prompt + tokens for tokens in x["tokens"]]
259 | seq_length_p = x["seq_length"] + prompt_entity_length
260 |
261 | # Converting tokens_p to a format suitable for token_rep_layer
262 | out = token_rep_layer(tokens_p, seq_length_p)
263 |
264 | word_rep_w_prompt = out["embeddings"]
265 | mask_w_prompt = out["mask"]
266 |
267 | word_rep = word_rep_w_prompt[:, prompt_entity_length:, :]
268 | mask = mask_w_prompt[:, prompt_entity_length:]
269 |
270 | entity_type_rep = word_rep_w_prompt[:, :len(entity_prompt) - 1, :]
271 | entity_type_rep = entity_type_rep[:, 0::2, :]
272 |
273 | relation_type_rep = word_rep_w_prompt[:, len(entity_prompt):prompt_entity_length - 1, :]
274 | relation_type_rep = relation_type_rep[:, 0::2, :]
275 |
276 | return word_rep, mask, entity_type_rep, relation_type_rep
277 |
--------------------------------------------------------------------------------
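To make the prompt layout used by `TokenPromptProcessorTR` concrete: each type is preceded by its marker token and each block ends with a single separator, so the marker embeddings sit at even offsets within their block. A tiny standalone check (type names are illustrative):

```python
ent, rel, sep = "<<ENT>>", "<<REL>>", "<<SEP>>"
entity_prompt = [ent, "person", ent, "organization", sep]
relation_prompt = [rel, "works for", sep]
prompt = entity_prompt + relation_prompt

entity_len = 2 * 2 + 1                      # 2 tokens per entity type + one <<SEP>>
relation_len = 2 * 1 + 1
print(prompt[:entity_len - 1][0::2])        # ['<<ENT>>', '<<ENT>>'] -> entity type positions
print(prompt[entity_len:entity_len + relation_len - 1][0::2])   # ['<<REL>>']
```
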
/modules/evaluator.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class Evaluator:
5 | def __init__(self, all_true, all_outs):
6 | self.all_true = all_true
7 | self.all_outs = all_outs
8 |
9 | @torch.no_grad()
10 | def evaluate(self):
11 | precision, recall, f1 = calculate_metrics(self.all_true, self.all_outs, with_type=False)
12 | output_str = f"P: {precision:.2%}\tR: {recall:.2%}\tF1: {f1:.2%}\n"
13 | return output_str, f1
14 |
15 |
16 | def calculate_metrics(all_rel_true, all_rel_pred, with_type=True):
17 | flatrue = []
18 | for i, v in enumerate(all_rel_true):
19 | for j in v:
20 | try:
21 | head, tail, tp = j
22 |             except ValueError:
23 | (head, tail), tp = j
24 | if with_type:
25 | flatrue.append((head, tail, tp, i))
26 | else:
27 | flatrue.append(((head[0], head[1]), (tail[0], tail[1]), tp, i))
28 |
29 | flapred = []
30 | for i, v in enumerate(all_rel_pred):
31 | for j in v:
32 | try:
33 | head, tail, tp = j
34 |             except ValueError:
35 | (head, tail), tp = j
36 | if with_type:
37 | flapred.append((head, tail, tp, i))
38 | else:
39 | flapred.append(((head[0], head[1]), (tail[0], tail[1]), tp, i))
40 |
41 | TP = len(set(flatrue).intersection(set(flapred)))
42 | FP = len(flapred) - TP
43 | FN = len(flatrue) - TP
44 |
45 | if (TP + FP) == 0:
46 | prec = 0
47 | else:
48 | prec = TP / (TP + FP)
49 |
50 | if (TP + FN) == 0:
51 | rec = 0
52 | else:
53 | rec = TP / (TP + FN)
54 |
55 |     # F1 is the harmonic mean of precision and recall,
56 |     # defined as 0 when both are 0
57 |
58 | if (prec + rec) == 0:
59 | f1 = 0
60 | else:
61 | f1 = 2 * (prec * rec) / (prec + rec)
62 |
63 | return prec, rec, f1
64 |
--------------------------------------------------------------------------------
/modules/filtering.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from .loss_functions import down_weight_loss
3 |
4 | class FilteringLayer(nn.Module):
5 | def __init__(self, hidden_size):
6 | super().__init__()
7 |
8 | self.filter_layer = nn.Linear(hidden_size, 2)
9 |
10 | def forward(self, embeds, label):
11 |
12 | # Extract dimensions
13 | B, num_spans, D = embeds.shape
14 |
15 | # Compute score using a predefined filtering function
16 | score = self.filter_layer(embeds) # Shape: [B, num_spans, num_classes]
17 |
18 | # Modify label to binary (0 for negative class, 1 for positive)
19 | label_m = label.clone()
20 | label_m[label_m > 0] = 1
21 |
22 | # Initialize the loss
23 | filter_loss = 0
24 | if self.training:
25 | # Compute the loss if in training mode
26 | filter_loss = down_weight_loss(score.view(B * num_spans, -1),
27 | label_m.view(-1),
28 | sample_rate=0.,
29 | is_logit=True)
30 |
31 | # Compute the filter score (difference between positive and negative class scores)
32 | filter_score = score[..., 1] - score[..., 0] # Shape: [B, num_spans]
33 |
34 | # Mask out filter scores for ignored labels
35 | filter_score = filter_score.masked_fill(label == -1, float('-inf'))
36 |
37 | if self.training:
38 | filter_score = filter_score.masked_fill(label_m > 0, float('inf'))
39 |
40 | return filter_score, filter_loss
41 |
--------------------------------------------------------------------------------
/modules/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
4 |
5 |
6 | def MLP(units, dropout, activation=nn.ReLU):
7 | units = [int(u) for u in units]
8 | assert len(units) >= 2
9 | layers = []
10 | for i in range(len(units) - 2):
11 | layers.append(nn.Linear(units[i], units[i + 1]))
12 | layers.append(activation())
13 | layers.append(nn.Dropout(dropout))
14 | layers.append(nn.Linear(units[-2], units[-1]))
15 | return nn.Sequential(*layers)
16 |
17 |
18 | def create_transformer_encoder(d_model, nhead, num_layers, ffn_mul=4, dropout=0.1):
19 | layer = nn.TransformerEncoderLayer(
20 | d_model=d_model, nhead=nhead, batch_first=True, norm_first=False, dim_feedforward=d_model * ffn_mul,
21 | dropout=dropout)
22 | encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
23 | return encoder
24 |
25 |
26 | class TransLayer(nn.Module):
27 | def __init__(self, d_model, num_heads, num_layers, ffn_mul=4, dropout=0.1):
28 | super(TransLayer, self).__init__()
29 |
30 | if num_layers > 0:
31 | self.transformer_encoder = create_transformer_encoder(d_model, num_heads, num_layers, ffn_mul, dropout)
32 |
33 |     def forward(self, x, mask):
34 |         # nn.TransformerEncoder expects True at padding positions, so invert the mask
35 |         src_key_padding_mask = ~mask.bool()
36 |         if not hasattr(self, 'transformer_encoder'):
37 |             return x
38 |         return self.transformer_encoder(src=x, src_key_padding_mask=src_key_padding_mask)
39 |
40 |
41 | class GraphEmbedder(nn.Module):
42 | def __init__(self, d_model):
43 | super().__init__()
44 |
45 | # Project node to half of its dimension
46 | self.project_node = nn.Linear(d_model, d_model // 2)
47 |
48 | # Initialize identifier with zeros
49 | self.identifier = nn.Parameter(torch.randn(2, d_model))
50 | nn.init.zeros_(self.identifier)
51 |
52 | def forward(self, candidate_span_rep):
53 | max_top_k = candidate_span_rep.size()[1]
54 |
55 | # Project nodes
56 | nodes = self.project_node(candidate_span_rep)
57 |
58 | # Split nodes into heads and tails
59 | heads = nodes.unsqueeze(2).expand(-1, -1, max_top_k, -1)
60 | tails = nodes.unsqueeze(1).expand(-1, max_top_k, -1, -1)
61 |
62 | # Concatenate heads and tails to form edges
63 | edges = torch.cat([heads, tails], dim=-1)
64 |
65 | # Duplicate nodes along the last dimension
66 | nodes = torch.cat([nodes, nodes], dim=-1)
67 |
68 | # Add identifier to nodes and edges
69 | nodes += self.identifier[0]
70 | edges += self.identifier[1]
71 |
72 | return nodes, edges
73 |
74 |
75 | class LstmSeq2SeqEncoder(nn.Module):
76 | def __init__(self, input_size, hidden_size, num_layers=1, dropout=0., bidirectional=False):
77 | super(LstmSeq2SeqEncoder, self).__init__()
78 |
79 | self.lstm = nn.LSTM(input_size=input_size,
80 | hidden_size=hidden_size,
81 | num_layers=num_layers,
82 | dropout=dropout,
83 | bidirectional=bidirectional,
84 | batch_first=True)
85 |
86 | def forward(self, x, mask, hidden=None):
87 | # Packing the input sequence
88 | lengths = mask.sum(dim=1).cpu()
89 | packed_x = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
90 |
91 | # Passing packed sequence through LSTM
92 | packed_output, hidden = self.lstm(packed_x, hidden)
93 |
94 | # Unpacking the output sequence
95 | output, _ = pad_packed_sequence(packed_output, batch_first=True)
96 |
97 | return output
98 |
--------------------------------------------------------------------------------
/modules/loss_functions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def compute_matching_loss(logits, labels, mask, num_classes):
6 | B, _, _ = logits.size()
7 |
8 | logits_label = logits.view(-1, num_classes)
9 | labels = labels.view(-1) # (batch_size * num_spans)
10 | mask_label = labels != -1 # (batch_size * num_spans)
11 | labels.masked_fill_(~mask_label, 0) # Set the labels of padding tokens to 0
12 |
13 | # one-hot encoding
14 | labels_one_hot = torch.zeros(labels.size(0), num_classes + 1, dtype=torch.float32).to(logits.device)
15 | labels_one_hot.scatter_(1, labels.unsqueeze(1), 1) # Set the corresponding index to 1
16 | labels_one_hot = labels_one_hot[:, 1:] # Remove the first column
17 |
18 | # loss for classifier
19 | loss = F.binary_cross_entropy_with_logits(logits_label, labels_one_hot, reduction='none')
20 | # mask loss using mask (B, C)
21 | masked_loss = loss.view(B, -1, num_classes) * mask.unsqueeze(1)
22 | loss = masked_loss.view(-1, num_classes)
23 | # expand mask_label to loss
24 | mask_label = mask_label.unsqueeze(-1).expand_as(loss)
25 |
26 |
27 | # apply mask
28 | loss = loss * mask_label.float()
29 | loss = loss.sum()
30 |
31 | return loss
32 |
33 |
34 | def down_weight_loss(logits, y, sample_rate=0.1, is_logit=True):
35 |
36 | if is_logit:
37 | loss_func = F.cross_entropy
38 | else:
39 | loss_func = F.nll_loss
40 |
41 | loss_entity = loss_func(logits, y.masked_fill(y == 0, -1), ignore_index=-1, reduction='sum')
42 | loss_non_entity = loss_func(logits, y.masked_fill(y > 0, -1), ignore_index=-1, reduction='sum')
43 |
44 | return loss_entity + loss_non_entity * (1 - sample_rate)
45 |
--------------------------------------------------------------------------------
/modules/rel_rep.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from .layers import MLP
4 |
5 |
6 | class RelationRep(nn.Module):
7 | def __init__(self, hidden_size, dropout, ffn_mul):
8 | super().__init__()
9 |
10 | self.head_mlp = nn.Linear(hidden_size, hidden_size // 2)
11 | self.tail_mlp = nn.Linear(hidden_size, hidden_size // 2)
12 | self.out_mlp = MLP([hidden_size, hidden_size * ffn_mul, hidden_size], dropout)
13 |
14 | def forward(self, span_reps):
15 | """
16 | :param span_reps [B, topk, D]
17 | :return relation_reps [B, topk, topk, D]
18 | """
19 |
20 | heads, tails = span_reps, span_reps
21 |
22 | # Apply MLPs to heads and tails
23 | heads = self.head_mlp(heads)
24 | tails = self.tail_mlp(tails)
25 |
26 | # Expand heads and tails to create relation representations
27 | heads = heads.unsqueeze(2).expand(-1, -1, heads.shape[1], -1)
28 | tails = tails.unsqueeze(1).expand(-1, tails.shape[1], -1, -1)
29 |
30 | # Concatenate heads and tails to create relation representations
31 | relation_reps = torch.cat([heads, tails], dim=-1)
32 |
33 | # Apply MLP to relation representations
34 | relation_reps = self.out_mlp(relation_reps)
35 |
36 | return relation_reps
37 |
--------------------------------------------------------------------------------
/modules/scorer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 |
5 | from .layers import MLP
6 |
7 |
8 | class ScorerLayer(nn.Module):
9 | def __init__(self, scoring_type="dot", hidden_size=768, dropout=0.1):
10 | super().__init__()
11 |
12 | self.scoring_type = scoring_type
13 |
14 |         if scoring_type == "concat_proj":
15 |             self.proj = MLP([hidden_size * 4, hidden_size * 4, 1], dropout)
16 |         elif scoring_type == "dot_thresh":
17 |             self.proj_thresh = MLP([hidden_size, hidden_size * 4, 2], dropout)
18 |             self.proj_type = MLP([hidden_size, hidden_size * 4, hidden_size], dropout)
19 |         elif scoring_type == "dot_norm": self.bias = nn.Parameter(torch.tensor(0.0)); self.dy_bias_type = nn.Linear(hidden_size, 1); self.dy_bias_rel = nn.Linear(hidden_size, 1)
20 | def forward(self, candidate_pair_rep, rel_type_rep):
21 | # candidate_pair_rep: [B, N, D]
22 | # rel_type_rep: [B, T, D]
23 | if self.scoring_type == "dot":
24 | return torch.einsum("bnd,btd->bnt", candidate_pair_rep, rel_type_rep)
25 |
26 | elif self.scoring_type == "dot_thresh":
27 | # compute the scaling factor and threshold
28 | B, T, D = rel_type_rep.size()
29 | scaler = self.proj_thresh(rel_type_rep) # [B, T, 2]
30 | # alpha: scaling factor, beta: threshold
31 | alpha, beta = scaler[..., 0].view(B, 1, T), scaler[..., 1].view(B, 1, T)
32 | alpha = F.softplus(alpha) # reason: alpha should be positive
33 | # project the relation type representation
34 | rel_type_rep = self.proj_type(rel_type_rep) # [B, T, D]
35 | # compute the score (before sigmoid)
36 | score = torch.einsum("bnd,btd->bnt", candidate_pair_rep, rel_type_rep) # [B, N, T]
37 | return (score + beta) * alpha # [B, N, T]
38 |
39 | elif self.scoring_type == "dot_norm":
40 | score = torch.einsum("bnd,btd->bnt", candidate_pair_rep, rel_type_rep) # [B, N, T]
41 | bias_1 = self.dy_bias_type(rel_type_rep).transpose(1, 2) # [B, 1, T]
42 | bias_2 = self.dy_bias_rel(candidate_pair_rep) # [B, N, 1]
43 | return score + self.bias + bias_1 + bias_2
44 |
45 | elif self.scoring_type == "concat_proj":
46 | prod_features = candidate_pair_rep.unsqueeze(2) * rel_type_rep.unsqueeze(1) # [B, N, T, D]
47 | diff_features = candidate_pair_rep.unsqueeze(2) - rel_type_rep.unsqueeze(1) # [B, N, T, D]
48 | expanded_pair_rep = candidate_pair_rep.unsqueeze(2).repeat(1, 1, rel_type_rep.size(1), 1)
49 | expanded_rel_type_rep = rel_type_rep.unsqueeze(1).repeat(1, candidate_pair_rep.size(1), 1, 1)
50 | features = torch.cat([prod_features, diff_features, expanded_pair_rep, expanded_rel_type_rep],
51 |                                  dim=-1)  # [B, N, T, 4D]
52 | return self.proj(features).squeeze(-1)
53 |
--------------------------------------------------------------------------------
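The default `"dot"` scorer is just a batched dot product between pair representations and type representations; a quick shape check with made-up sizes:

```python
import torch

pair_rep = torch.randn(2, 6, 8)     # [B, N, D] candidate pair representations
type_rep = torch.randn(2, 3, 8)     # [B, T, D] relation type embeddings
scores = torch.einsum("bnd,btd->bnt", pair_rep, type_rep)
print(scores.shape)                 # torch.Size([2, 6, 3])
```
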
/modules/span_rep.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 |
5 |
6 | def create_projection_layer(hidden_size: int, dropout: float, out_dim: int = None) -> nn.Sequential:
7 | """
8 | Creates a projection layer with specified configurations.
9 | """
10 | if out_dim is None:
11 | out_dim = hidden_size
12 |
13 | return nn.Sequential(
14 | nn.Linear(hidden_size, int(out_dim * 3 / 2)),
15 | nn.ReLU(),
16 | nn.Dropout(dropout),
17 | nn.Linear(int(out_dim * 3 / 2), out_dim)
18 | )
19 |
20 |
21 | class SpanConvBlock(nn.Module):
22 | def __init__(self, hidden_size, kernel_size, span_mode='conv_normal'):
23 | super().__init__()
24 |
25 | if span_mode == 'conv_conv':
26 | self.conv = nn.Conv1d(hidden_size, hidden_size,
27 | kernel_size=kernel_size)
28 |
29 | # initialize the weights
30 | nn.init.kaiming_uniform_(self.conv.weight, nonlinearity='relu')
31 |
32 | elif span_mode == 'conv_max':
33 | self.conv = nn.MaxPool1d(kernel_size=kernel_size, stride=1)
34 | elif span_mode == 'conv_mean' or span_mode == 'conv_sum':
35 | self.conv = nn.AvgPool1d(kernel_size=kernel_size, stride=1)
36 |
37 | self.span_mode = span_mode
38 |
39 | self.pad = kernel_size - 1
40 |
41 | def forward(self, x):
42 |
43 | x = torch.einsum('bld->bdl', x)
44 |
45 | if self.pad > 0:
46 | x = F.pad(x, (0, self.pad), "constant", 0)
47 |
48 | x = self.conv(x)
49 |
50 | if self.span_mode == "conv_sum":
51 | x = x * (self.pad + 1)
52 |
53 | return torch.einsum('bdl->bld', x)
54 |
55 |
56 | class SpanConv(nn.Module):
57 | def __init__(self, hidden_size, max_width, span_mode):
58 | super().__init__()
59 |
60 | kernels = [i + 2 for i in range(max_width - 1)]
61 |
62 | self.convs = nn.ModuleList()
63 |
64 | for kernel in kernels:
65 | self.convs.append(SpanConvBlock(hidden_size, kernel, span_mode))
66 |
67 | self.project = nn.Sequential(
68 | nn.ReLU(),
69 | nn.Linear(hidden_size, hidden_size)
70 | )
71 |
72 | def forward(self, x, *args):
73 |
74 | span_reps = [x]
75 |
76 | for conv in self.convs:
77 | h = conv(x)
78 | span_reps.append(h)
79 |
80 | span_reps = torch.stack(span_reps, dim=-2)
81 |
82 | return self.project(span_reps)
83 |
84 |
85 | class ConvShare(nn.Module):
86 | def __init__(self, hidden_size, max_width):
87 | super().__init__()
88 |
89 | self.max_width = max_width
90 |
91 | self.conv_weight = nn.Parameter(
92 | torch.randn(hidden_size, hidden_size, max_width))
93 |
94 | nn.init.kaiming_uniform_(self.conv_weight, nonlinearity='relu')
95 |
96 | self.project = nn.Sequential(
97 | nn.ReLU(),
98 | nn.Linear(hidden_size, hidden_size)
99 | )
100 |
101 | def forward(self, x, *args):
102 | span_reps = []
103 |
104 | x = torch.einsum('bld->bdl', x)
105 |
106 | for i in range(self.max_width):
107 | pad = i
108 | x_i = F.pad(x, (0, pad), "constant", 0)
109 | conv_w = self.conv_weight[:, :, :i + 1]
110 | out_i = F.conv1d(x_i, conv_w)
111 | span_reps.append(out_i.transpose(-1, -2))
112 |
113 | out = torch.stack(span_reps, dim=-2)
114 |
115 | return self.project(out)
116 |
117 |
118 | def extract_elements(sequence, indices):
119 | B, L, D = sequence.shape
120 | K = indices.shape[1]
121 |
122 | # Expand indices to [B, K, D]
123 | expanded_indices = indices.unsqueeze(2).expand(-1, -1, D)
124 |
125 | # Gather the elements
126 | extracted_elements = torch.gather(sequence, 1, expanded_indices)
127 |
128 | return extracted_elements
129 |
130 |
131 | class SpanMarker(nn.Module):
132 |
133 | def __init__(self, hidden_size, max_width, dropout=0.4):
134 | super().__init__()
135 |
136 | self.max_width = max_width
137 |
138 | self.project_start = nn.Sequential(
139 | nn.Linear(hidden_size, hidden_size * 2, bias=True),
140 | nn.ReLU(),
141 | nn.Dropout(dropout),
142 | nn.Linear(hidden_size * 2, hidden_size, bias=True),
143 | )
144 |
145 | self.project_end = nn.Sequential(
146 | nn.Linear(hidden_size, hidden_size * 2, bias=True),
147 | nn.ReLU(),
148 | nn.Dropout(dropout),
149 | nn.Linear(hidden_size * 2, hidden_size, bias=True),
150 | )
151 |
152 | self.out_project = nn.Linear(hidden_size * 2, hidden_size, bias=True)
153 |
154 | def forward(self, h, span_idx):
155 | # h of shape [B, L, D]
156 | # query_seg of shape [D, max_width]
157 |
158 | B, L, D = h.size()
159 |
160 | # project start and end
161 | start_rep = self.project_start(h)
162 | end_rep = self.project_end(h)
163 |
164 | start_span_rep = extract_elements(start_rep, span_idx[:, :, 0])
165 | end_span_rep = extract_elements(end_rep, span_idx[:, :, 1])
166 |
167 | # concat start and end
168 | cat = torch.cat([start_span_rep, end_span_rep], dim=-1).relu()
169 |
170 | # project
171 | cat = self.out_project(cat)
172 |
173 | # reshape
174 | return cat.view(B, L, self.max_width, D)
175 |
176 |
177 | class SpanMarkerV0(nn.Module):
178 | """
179 | Marks and projects span endpoints using an MLP.
180 | """
181 |
182 | def __init__(self, hidden_size: int, max_width: int, dropout: float = 0.4):
183 | super().__init__()
184 | self.max_width = max_width
185 |
186 | self.project_start = create_projection_layer(hidden_size, dropout)
187 | self.project_end = create_projection_layer(hidden_size, dropout)
188 | self.out_project = create_projection_layer(hidden_size * 2, dropout, hidden_size)
189 |
190 | def forward(self, h: torch.Tensor, span_idx: torch.Tensor) -> torch.Tensor:
191 | B, L, D = h.size()
192 |
193 | start_rep = self.project_start(h)
194 | end_rep = self.project_end(h)
195 |
196 | start_span_rep = extract_elements(start_rep, span_idx[:, :, 0])
197 | end_span_rep = extract_elements(end_rep, span_idx[:, :, 1])
198 |
199 | cat = torch.cat([start_span_rep, end_span_rep], dim=-1).relu()
200 |
201 | return self.out_project(cat).view(B, L, self.max_width, D)
202 |
203 |
204 | class SpanRepLayer(nn.Module):
205 | """
206 | Various span representation approaches
207 | """
208 |
209 | def __init__(self, hidden_size, max_width, span_mode, **kwargs):
210 | super().__init__()
211 |
212 | if span_mode == 'marker':
213 | self.span_rep_layer = SpanMarker(hidden_size, max_width, **kwargs)
214 | elif span_mode == 'markerV0':
215 | self.span_rep_layer = SpanMarkerV0(hidden_size, max_width, **kwargs)
216 | elif span_mode == 'conv_conv':
217 | self.span_rep_layer = SpanConv(
218 | hidden_size, max_width, span_mode='conv_conv')
219 | elif span_mode == 'conv_max':
220 | self.span_rep_layer = SpanConv(
221 | hidden_size, max_width, span_mode='conv_max')
222 | elif span_mode == 'conv_mean':
223 | self.span_rep_layer = SpanConv(
224 | hidden_size, max_width, span_mode='conv_mean')
225 | elif span_mode == 'conv_sum':
226 | self.span_rep_layer = SpanConv(
227 | hidden_size, max_width, span_mode='conv_sum')
228 | elif span_mode == 'conv_share':
229 | self.span_rep_layer = ConvShare(hidden_size, max_width)
230 | else:
231 | raise ValueError(f'Unknown span mode {span_mode}')
232 |
233 | def forward(self, x, *args):
234 |
235 | return self.span_rep_layer(x, *args)
236 |
--------------------------------------------------------------------------------
/modules/token_rep.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import torch
4 | from flair.data import Sentence
5 | from flair.embeddings import TransformerWordEmbeddings
6 | from torch import nn
7 | from torch.nn.utils.rnn import pad_sequence
8 |
9 |
10 | class TokenRepLayer(nn.Module):
11 | def __init__(self, model_name: str = "bert-base-cased", fine_tune: bool = True, subtoken_pooling: str = "first",
12 | hidden_size: int = 768,
13 | add_tokens=["[SEP]", "[ENT]"]
14 | ):
15 | super().__init__()
16 |
17 | self.bert_layer = TransformerWordEmbeddings(
18 | model_name,
19 | fine_tune=fine_tune,
20 | subtoken_pooling=subtoken_pooling,
21 | allow_long_sentences=True
22 | )
23 |
24 | # add tokens to vocabulary
25 | self.bert_layer.tokenizer.add_tokens(add_tokens)
26 |
27 | # resize token embeddings
28 | self.bert_layer.model.resize_token_embeddings(len(self.bert_layer.tokenizer))
29 |
30 | bert_hidden_size = self.bert_layer.embedding_length
31 |
32 | if hidden_size != bert_hidden_size:
33 | self.projection = nn.Linear(bert_hidden_size, hidden_size)
34 |
35 | def forward(self, tokens: List[List[str]], lengths: torch.Tensor):
36 | token_embeddings = self.compute_word_embedding(tokens)
37 |
38 | if hasattr(self, "projection"):
39 | token_embeddings = self.projection(token_embeddings)
40 |
41 | B = len(lengths)
42 | max_length = lengths.max()
43 | mask = (torch.arange(max_length).view(1, -1).repeat(B, 1) < lengths.cpu().unsqueeze(1)).to(
44 | token_embeddings.device).long()
45 | return {"embeddings": token_embeddings, "mask": mask}
46 |
47 | def compute_word_embedding(self, tokens):
48 | sentences = [Sentence(i) for i in tokens]
49 | self.bert_layer.embed(sentences)
50 | token_embeddings = pad_sequence([torch.stack([t.embedding for t in k]) for k in sentences], batch_first=True)
51 | return token_embeddings
52 |
--------------------------------------------------------------------------------
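`TokenRepLayer` wraps flair's `TransformerWordEmbeddings` so the model works directly on pre-split words. A minimal sketch of its inputs and outputs (requires `flair`, and downloads `bert-base-cased` on first use; shapes follow the `forward` above):

```python
import torch
from modules.token_rep import TokenRepLayer

rep = TokenRepLayer(model_name="bert-base-cased", fine_tune=False,
                    subtoken_pooling="first", hidden_size=768)

tokens = [["Barack", "Obama", "visited", "Paris"], ["Hello", "world"]]
lengths = torch.tensor([len(s) for s in tokens])  # tensor([4, 2])

out = rep(tokens, lengths)
print(out["embeddings"].shape)  # torch.Size([2, 4, 768]) after padding
print(out["mask"])              # 1 for real tokens, 0 for padding positions
```

Since `hidden_size` matches the encoder's embedding length here, no projection layer is created; with a different `hidden_size` a `nn.Linear` bridge is added automatically.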
/modules/token_splitter.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import Iterator, Tuple
3 |
4 |
5 | class TokenSplitterBase:
6 |     """Base interface: subclasses yield (token, start_char, end_char) triples."""
7 |
8 |     def __call__(self, text: str) -> Iterator[Tuple[str, int, int]]:
9 |         raise NotImplementedError
10 |
11 |
12 | class WhitespaceTokenSplitter(TokenSplitterBase):
13 | def __init__(self):
14 |         self.whitespace_pattern = re.compile(r'\w+(?:[-_]\w+)*|\S')  # words (with -/_ compounds) or any single non-space char
15 |
16 | def __call__(self, text):
17 | for match in self.whitespace_pattern.finditer(text):
18 | yield match.group(), match.start(), match.end()
19 |
20 |
21 | class SpaCyTokenSplitter(TokenSplitterBase):
22 | def __init__(self, lang=None):
23 | try:
24 | import spacy # noqa
25 | except ModuleNotFoundError as error:
26 | raise error.__class__(
27 | "Please install spacy with: `pip install spacy`"
28 | )
29 | if lang is None:
30 | lang = 'en' # Default to English if no language is specified
31 | self.nlp = spacy.blank(lang)
32 |
33 | def __call__(self, text):
34 | doc = self.nlp(text)
35 | for token in doc:
36 | yield token.text, token.idx, token.idx + len(token.text)
37 |
38 |
39 | class MecabKoTokenSplitter(TokenSplitterBase):
40 | def __init__(self):
41 | try:
42 | import mecab # noqa
43 | except ModuleNotFoundError as error:
44 | raise error.__class__(
45 | "Please install python-mecab-ko with: `pip install python-mecab-ko`"
46 | )
47 | self.tagger = mecab.MeCab()
48 |
49 | def __call__(self, text):
50 | tokens = self.tagger.morphs(text)
51 |
52 | last_idx = 0
53 | for morph in tokens:
54 |             start_idx = text.find(morph, last_idx)  # recover char offsets in the original text
55 | end_idx = start_idx + len(morph)
56 | last_idx = end_idx
57 | yield morph, start_idx, end_idx
--------------------------------------------------------------------------------
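Each splitter yields `(token, start_char, end_char)` triples over the raw string, so character offsets are preserved for span annotation. For example, with the regex-based splitter:

```python
from modules.token_splitter import WhitespaceTokenSplitter

splitter = WhitespaceTokenSplitter()
print(list(splitter("state-of-the-art NER!")))
# [('state-of-the-art', 0, 16), ('NER', 17, 20), ('!', 20, 21)]
```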
/modules/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def decode_relations(id_to_rel, logits, pair_indices, max_pairs, span_indices, threshold=0.5, output_confidence=False):
6 | # Apply sigmoid function to logits
7 | probabilities = torch.sigmoid(logits)
8 |
9 | # Initialize list of relations
10 | relations = [[] for _ in range(len(logits))]
11 |
12 | # Get indices where probability is greater than threshold
13 | above_threshold_indices = (probabilities > threshold).nonzero(as_tuple=True)
14 |
15 | # Iterate over indices where probability is greater than threshold
16 | for batch_idx, position, class_idx in zip(*above_threshold_indices):
17 | # Get relation label
18 | label = id_to_rel[class_idx.item() + 1]
19 |
20 | # Get predicted pair index
21 | predicted_pair_idx = pair_indices[batch_idx, position].item()
22 |
23 | # Unravel predicted pair index into head and tail
24 | head_idx, tail_idx = np.unravel_index(predicted_pair_idx, (max_pairs, max_pairs))
25 |
26 | # Convert head and tail indices to tuples
27 | head = tuple(span_indices[batch_idx, head_idx].tolist())
28 | tail = tuple(span_indices[batch_idx, tail_idx].tolist())
29 |
30 | # Get confidence
31 | confidence = probabilities[batch_idx, position, class_idx].item()
32 |
33 | # Append relation to list
34 | if output_confidence:
35 | relations[batch_idx.item()].append((head, tail, label, confidence))
36 | else:
37 | relations[batch_idx.item()].append((head, tail, label))
38 |
39 | return relations
40 |
41 |
42 | def decode_entities(id_to_ent, logits, span_indices, threshold=0.5, output_confidence=False):
43 | # Apply sigmoid function to logits
44 | probabilities = torch.sigmoid(logits)
45 |
46 | # Initialize list of entities
47 | entities = []
48 |
49 | # Get indices where probability is greater than threshold
50 | above_threshold_indices = (probabilities > threshold).nonzero(as_tuple=True)
51 |
52 | # Iterate over indices where probability is greater than threshold
53 | for batch_idx, position, class_idx in zip(*above_threshold_indices):
54 | # Get entity label
55 | label = id_to_ent[class_idx.item() + 1]
56 |
57 | # Get confidence
58 | confidence = probabilities[batch_idx, position, class_idx].item()
59 |
60 | # Append entity to list
61 | if output_confidence:
62 | entities.append((tuple(span_indices[batch_idx, position].tolist()), label, confidence))
63 | else:
64 | entities.append((tuple(span_indices[batch_idx, position].tolist()), label))
65 |
66 | return entities
67 |
68 |
69 | def er_decoder(x, entity_logits, rel_logits, topk_pair_idx, max_top_k, candidate_spans_idx, threshold=0.5,
70 |                output_confidence=False, token_splitter=None):  # token_splitter is accepted but currently unused
71 | entities = decode_entities(x["id_to_ent"], entity_logits, candidate_spans_idx, threshold, output_confidence)
72 | relations = decode_relations(x["id_to_rel"], rel_logits, topk_pair_idx, max_top_k, candidate_spans_idx, threshold,
73 | output_confidence)
74 | return entities, relations
75 |
76 |
77 | def get_relation_with_span(x):
78 |     entities, relations = x['entities'], x['relations']
79 |     B = len(entities)
80 |     relation_with_span = [[] for _ in range(B)]
81 |     for i in range(B):
82 |         rel_i = relations[i]
83 |         ent_i = entities[i]
84 |         for rel in rel_i:
85 |             # map (head_ent_idx, tail_ent_idx, type) to (head_span, tail_span, type)
86 |             relation_with_span[i].append((ent_i[rel[0]], ent_i[rel[1]], rel[2]))
87 |     return relation_with_span
88 |
89 |
90 | def get_ground_truth_relations(x, candidate_spans_idx, candidate_span_label):
91 | B, max_top_k = candidate_span_label.shape
92 |
93 | relation_classes = torch.zeros((B, max_top_k, max_top_k), dtype=torch.long, device=candidate_spans_idx.device)
94 |
95 | # Populate relation classes
96 | for i in range(B):
97 | rel_i = x["relations"][i]
98 | ent_i = x["entities"][i]
99 |
100 | new_heads, new_tails, new_rel_type = [], [], []
101 |
102 | # Loop over the relations and entities to populate initial lists
103 | for k in rel_i:
104 | heads_i = [ent_i[k[0]][0], ent_i[k[0]][1]]
105 | tails_i = [ent_i[k[1]][0], ent_i[k[1]][1]]
106 | type_i = k[2]
107 | new_heads.append(heads_i)
108 | new_tails.append(tails_i)
109 | new_rel_type.append(type_i)
110 |
111 |         # Aliases used by the matching loop below
112 |         heads_, tails_, rel_type = new_heads, new_tails, new_rel_type
113 |
114 | # idx of candidate spans
115 | cand_i = candidate_spans_idx[i].tolist()
116 |
117 | for heads_i, tails_i, type_i in zip(heads_, tails_, rel_type):
118 |
119 | flag = False
120 | if isinstance(x["rel_to_id"], dict):
121 | if type_i in x["rel_to_id"]:
122 | flag = True
123 | elif isinstance(x["rel_to_id"], list):
124 | if type_i in x["rel_to_id"][i]:
125 | flag = True
126 |
127 | if heads_i in cand_i and tails_i in cand_i and flag:
128 | idx_head = cand_i.index(heads_i)
129 | idx_tail = cand_i.index(tails_i)
130 |
131 | if isinstance(x["rel_to_id"], list):
132 | relation_classes[i, idx_head, idx_tail] = x["rel_to_id"][i][type_i]
133 | elif isinstance(x["rel_to_id"], dict):
134 | relation_classes[i, idx_head, idx_tail] = x["rel_to_id"][type_i]
135 |
136 | # flat relation classes
137 | relation_classes = relation_classes.view(-1, max_top_k * max_top_k)
138 |
139 |     # set class to -1 wherever the head or tail candidate span is itself labeled -1 (ignored)
140 |     head_candidate_span_label = candidate_span_label.view(B, max_top_k, 1).repeat(1, 1, max_top_k).view(B, -1)
141 |     tail_candidate_span_label = candidate_span_label.view(B, 1, max_top_k).repeat(1, max_top_k, 1).view(B, -1)
142 |
143 |     relation_classes.masked_fill_(head_candidate_span_label == -1, -1)  # mask heads
144 |     relation_classes.masked_fill_(tail_candidate_span_label == -1, -1)  # mask tails
145 |
146 | return relation_classes
147 |
148 |
149 | def get_candidates(sorted_idx, tensor_elem, topk=10):
150 | # sorted_idx [B, num_spans]
151 | # tensor_elem [B, num_spans, D] or [B, num_spans]
152 |
153 | sorted_topk_idx = sorted_idx[:, :topk]
154 |
155 | if len(tensor_elem.shape) == 3:
156 | B, num_spans, D = tensor_elem.shape
157 | topk_tensor_elem = tensor_elem.gather(1, sorted_topk_idx.unsqueeze(-1).expand(-1, -1, D))
158 | else:
159 | # [B, topk]
160 | topk_tensor_elem = tensor_elem.gather(1, sorted_topk_idx)
161 |
162 | return topk_tensor_elem, sorted_topk_idx
163 |
--------------------------------------------------------------------------------
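Note the asymmetry between the two decoders: `decode_relations` groups its output per batch element, while `decode_entities` returns a single flat list. A toy call to `decode_entities` (the label ids and logits below are made up for illustration; the `+ 1` offset in both decoders suggests id 0 is reserved for a null class):

```python
import torch
from modules.utils import decode_entities

id_to_ent = {1: "PER", 2: "LOC"}                 # id 0 implicitly the null class
logits = torch.tensor([[[3.0, -2.0],             # span 0: sigmoid(3.0) > 0.5 -> class 1 "PER"
                        [-1.0, 2.5]]])           # span 1: sigmoid(2.5) > 0.5 -> class 2 "LOC"
span_indices = torch.tensor([[[0, 1], [3, 3]]])  # (start, end) token indices per span

print(decode_entities(id_to_ent, logits, span_indices, threshold=0.5))
# [((0, 1), 'PER'), ((3, 3), 'LOC')]
```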
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import re
5 | from types import SimpleNamespace
6 |
7 | from tqdm import tqdm
8 | from transformers import (
9 | get_cosine_schedule_with_warmup,
10 | get_linear_schedule_with_warmup,
11 | get_constant_schedule_with_warmup,
12 | get_polynomial_decay_schedule_with_warmup,
13 | get_inverse_sqrt_schedule,
14 | )
15 | import torch
16 | import torch.distributed as dist
17 | import torch.multiprocessing as mp
18 | from torch.nn.parallel import DistributedDataParallel as DDP
19 | from torch.utils.data import DataLoader, TensorDataset
20 | from torch.utils.data.distributed import DistributedSampler
21 |
22 | from model import GraphER
23 | from modules.base import load_config_as_namespace
24 | #from modules.run_evaluation import get_for_all_path
25 |
26 |
27 | def save_top_k_checkpoints(model: GraphER, save_path: str, checkpoint: str, top_k: int = 5):
28 |     """
29 |     Save the latest checkpoint of a model and prune older ones, keeping only the last top_k.
30 |     Parameters:
31 |         model (GraphER): The model to save.
32 |         save_path (str): The directory path to save the checkpoints.
33 |         checkpoint (str): Name of the checkpoint folder, e.g. "model_5000".
34 |         top_k (int): The number of checkpoints to keep. Defaults to 5.
35 |     """
36 | # Save the current model and tokenizer
37 |     if isinstance(model, DDP):
38 |         model.module.save_pretrained(os.path.join(save_path, checkpoint))
39 |     else:
40 |         model.save_pretrained(os.path.join(save_path, checkpoint))
41 |
42 | # List all files in the directory
43 | files = os.listdir(save_path)
44 |
45 | # Filter files to keep only the model checkpoints
46 | checkpoint_folders = [file for file in files if re.search(r'model_\d+', file)]
47 |
48 | # Sort checkpoint files by modification time (latest first)
49 | checkpoint_folders.sort(key=lambda x: os.path.getmtime(os.path.join(save_path, x)), reverse=True)
50 |
51 | # Keep only the top-k checkpoints
52 | for checkpoint_folder in checkpoint_folders[top_k:]:
53 | checkpoint_folder = os.path.join(save_path, checkpoint_folder)
54 | checkpoint_files = [os.path.join(checkpoint_folder, f) for f in os.listdir(checkpoint_folder)]
55 | for file in checkpoint_files:
56 | os.remove(file)
57 |         os.rmdir(checkpoint_folder)
58 |
59 |
60 | class Trainer:
61 | def __init__(self, config, allow_distributed, device='cuda'):
62 | self.config = config
63 | self.lr_encoder = float(self.config.lr_encoder)
64 | self.lr_others = float(self.config.lr_others)
65 |
66 | self.device = device
67 |
68 | if config.prev_path != "none": # fine-tuning mode
69 | self.model_config = SimpleNamespace(
70 | max_types=config.max_types,
71 | max_len=config.max_len,
72 | max_top_k=config.max_top_k,
73 | add_top_k=config.add_top_k,
74 | shuffle_types=config.shuffle_types,
75 | random_drop=config.random_drop,
76 | max_neg_type_ratio=config.max_neg_type_ratio,
77 |                 max_ent_types=config.max_ent_types,
78 | max_rel_types=config.max_rel_types,
79 | )
80 | else:
81 | self.model_config = SimpleNamespace(
82 | model_name=config.model_name,
83 | name=config.name,
84 | max_width=config.max_width,
85 | hidden_size=config.hidden_size,
86 | dropout=config.dropout,
87 | fine_tune=config.fine_tune,
88 | subtoken_pooling=config.subtoken_pooling,
89 | span_mode=config.span_mode,
90 | max_types=config.max_types,
91 | max_len=config.max_len,
92 | num_heads=config.num_heads,
93 | num_transformer_layers=config.num_transformer_layers,
94 | ffn_mul=config.ffn_mul,
95 | scorer=config.scorer,
96 | max_top_k=config.max_top_k,
97 | add_top_k=config.add_top_k,
98 | shuffle_types=config.shuffle_types,
99 | random_drop=config.random_drop,
100 | max_neg_type_ratio=config.max_neg_type_ratio,
101 |                 max_ent_types=config.max_ent_types,
102 | max_rel_types=config.max_rel_types,
103 | )
104 |
105 | self.allow_distributed = allow_distributed
106 |
107 | def setup_distributed(self, rank, world_size):
108 | os.environ['MASTER_ADDR'] = 'localhost'
109 | os.environ['MASTER_PORT'] = '12356'
110 | torch.cuda.set_device(rank)
111 | dist.init_process_group("nccl", rank=rank, world_size=world_size)
112 |
113 | def cleanup_distributed(self):
114 | dist.destroy_process_group()
115 |
116 | def setup_model_and_optimizer(self, rank=None, device=None):
117 | if device is None:
118 | device = self.device
119 | if self.config.prev_path != "none":
120 | model = GraphER.from_pretrained(self.config.prev_path).to(device)
121 |
122 | # some parameters of model.config are not overwritten by the config file
123 | # other than these are overwritten
124 | keep_params = ['model_name', 'name', 'max_width', 'hidden_size', 'dropout', 'subtoken_pooling', 'span_mode',
125 | "fine_tune", "max_types", "max_len", "num_heads", "num_transformer_layers", "ffn_mul",
126 | "scorer"]
127 |
128 | for param in keep_params:
129 | original_value = getattr(model.config, param)
130 | setattr(self.model_config, param, original_value)
131 |
132 | model.config = self.model_config
133 | else:
134 | model = GraphER(self.model_config).to(device)
135 |
136 | if rank is not None:
137 | model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=False)
138 | optimizer = model.module.get_optimizer(self.lr_encoder, self.lr_others, freeze_token_rep=self.config.freeze_token_rep)
139 | else:
140 | optimizer = model.get_optimizer(self.lr_encoder, self.lr_others, freeze_token_rep=self.config.freeze_token_rep)
141 | return model, optimizer
142 |
143 | def train_dist(self, rank, world_size, dataset):
144 | # Init distributed process group
145 | self.setup_distributed(rank, world_size)
146 |
147 | device = f'cuda:{rank}'
148 |
149 | model, optimizer = self.setup_model_and_optimizer(rank, device=device)
150 |
151 | sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False)
152 |
153 | train_loader = model.module.create_dataloader(dataset, batch_size=self.config.train_batch_size, shuffle=False,
154 | sampler=sampler)
155 |
156 | num_steps = self.config.num_steps // world_size
157 |
158 | self.train(model=model, optimizer=optimizer, train_loader=train_loader,
159 | num_steps=num_steps, device=device, rank=rank)
160 |
161 | self.cleanup_distributed()
162 |
163 | def init_scheduler(self, scheduler_type, optimizer, num_warmup_steps, num_steps):
164 | if scheduler_type == "cosine":
165 | scheduler = get_cosine_schedule_with_warmup(
166 | optimizer,
167 | num_warmup_steps=num_warmup_steps,
168 | num_training_steps=num_steps
169 | )
170 | elif scheduler_type == "linear":
171 | scheduler = get_linear_schedule_with_warmup(
172 | optimizer,
173 | num_warmup_steps=num_warmup_steps,
174 | num_training_steps=num_steps
175 | )
176 | elif scheduler_type == "constant":
177 | scheduler = get_constant_schedule_with_warmup(
178 | optimizer,
179 | num_warmup_steps=num_warmup_steps,
180 | )
181 | elif scheduler_type == "polynomial":
182 | scheduler = get_polynomial_decay_schedule_with_warmup(
183 | optimizer,
184 | num_warmup_steps=num_warmup_steps,
185 | num_training_steps=num_steps
186 | )
187 | elif scheduler_type == "inverse_sqrt":
188 | scheduler = get_inverse_sqrt_schedule(
189 | optimizer,
190 | num_warmup_steps=num_warmup_steps,
191 | )
192 | else:
193 | raise ValueError(
194 |                 f"Invalid scheduler_type '{scheduler_type}'. Supported types: 'cosine', 'linear', 'constant', 'polynomial', 'inverse_sqrt'"
195 | )
196 | return scheduler
197 |
198 | def train(self, model, optimizer, train_loader, num_steps, device='cuda', rank=None):
199 | model.train()
200 | pbar = tqdm(range(num_steps))
201 |
202 | warmup_ratio = self.config.warmup_ratio
203 | eval_every = self.config.eval_every
204 | save_total_limit = self.config.save_total_limit
205 | log_dir = self.config.log_dir
206 | val_data_dir = self.config.val_data_dir
207 |
208 | num_warmup_steps = int(num_steps * warmup_ratio) if warmup_ratio < 1 else int(warmup_ratio)
209 |
210 | scheduler = self.init_scheduler(self.config.scheduler_type, optimizer, num_warmup_steps, num_steps)
211 | iter_train_loader = iter(train_loader)
212 | scaler = torch.cuda.amp.GradScaler()
213 |
214 | for step in pbar:
215 | optimizer.zero_grad()
216 |
217 | try:
218 | x = next(iter_train_loader)
219 | except StopIteration:
220 | iter_train_loader = iter(train_loader)
221 | x = next(iter_train_loader)
222 |
223 | for k, v in x.items():
224 | if isinstance(v, torch.Tensor):
225 | x[k] = v.to(device)
226 |
227 | with torch.cuda.amp.autocast(dtype=torch.float16):
228 | loss = model(x)
229 |
230 | if torch.isnan(loss).any():
231 | print("Warning: NaN loss detected")
232 |                 continue  # skip the optimizer/scheduler update for this batch
233 |
234 | scaler.scale(loss).backward()
235 | scaler.step(optimizer)
236 | scaler.update()
237 | scheduler.step()
238 |
239 | description = f"step: {step} | epoch: {step // len(train_loader)} | loss: {loss.item():.2f}"
240 | pbar.set_description(description)
241 |
242 | if (step + 1) % eval_every == 0:
243 | if rank is None or rank == 0:
244 | checkpoint = f'model_{step + 1}'
245 | save_top_k_checkpoints(model, log_dir, checkpoint, save_total_limit)
246 | #if val_data_dir != "none":
247 | #get_for_all_path(model, step, log_dir, val_data_dir)
248 | model.train()
249 |
250 | def run(self):
251 | with open(self.config.train_data, 'r') as f:
252 | data = json.load(f)
253 |
254 | if torch.cuda.device_count() > 1 and self.allow_distributed:
255 | world_size = torch.cuda.device_count()
256 | mp.spawn(self.train_dist, args=(world_size, data), nprocs=world_size, join=True)
257 | else:
258 | model, optimizer = self.setup_model_and_optimizer()
259 |
260 | train_loader = model.create_dataloader(data, batch_size=self.config.train_batch_size, shuffle=True)
261 |
262 | self.train(model, optimizer, train_loader, num_steps=self.config.num_steps, device=self.device)
263 |
264 |
265 | def create_parser():
266 | parser = argparse.ArgumentParser(description="grapher")
267 | parser.add_argument("--config", type=str, default="config.yaml", help="Path to config file")
268 | parser.add_argument('--log_dir', type=str, default='logs', help='Path to the log directory')
269 |     parser.add_argument('--allow_distributed', action='store_true',
270 |                         help='Allow distributed training when more than one GPU is available')
271 | return parser
272 |
273 |
274 | if __name__ == "__main__":
275 | parser = create_parser()
276 | args = parser.parse_args()
277 | config = load_config_as_namespace(args.config)
278 | config.log_dir = args.log_dir
279 |
280 | trainer = Trainer(config, allow_distributed=args.allow_distributed,
281 | device='cuda' if torch.cuda.is_available() else 'cpu')
282 | trainer.run()
283 |
--------------------------------------------------------------------------------
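One subtlety in `Trainer.train`: `warmup_ratio` doubles as either a fraction of `num_steps` (when below 1) or an absolute warmup step count (when 1 or more). A quick check of that arithmetic with illustrative values:

```python
num_steps = 30000
for warmup_ratio in (0.1, 500):
    num_warmup_steps = int(num_steps * warmup_ratio) if warmup_ratio < 1 else int(warmup_ratio)
    print(warmup_ratio, "->", num_warmup_steps)
# 0.1 -> 3000
# 500 -> 500
```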