├── .gitignore ├── LICENSE ├── models ├── utils.py └── prism.py ├── data ├── DWIE │ ├── label_map.json │ └── rel_desc.json ├── DocRED │ ├── label_map.json │ ├── rel_info.json │ └── rel_desc.json └── Re-DocRED │ ├── label_map.json │ ├── rel_info.json │ └── rel_desc.json ├── utils.py ├── scripts ├── evaluate.sh └── train.sh ├── logger.py ├── README.md ├── config.py ├── evaluate.py ├── train.py ├── evaluation.py └── dataset.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .cache/ 3 | .empty/ 4 | .checkpoints/ 5 | .ipynb_checkpoints/ 6 | .logs/ 7 | data*/ 8 | wandb/ 9 | *.out 10 | *.png 11 | *.txt 12 | *.yaml -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Minseok Choi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoTokenizer, AutoModel 2 | from dataset import load_and_cache_relations 3 | from .prism import PRISM 4 | 5 | 6 | def load_config(args):  # Load the pretrained config (args.config_name, else args.model_name_or_path) with num_labels=args.num_labels. 7 | config = AutoConfig.from_pretrained( 8 | args.config_name if args.config_name else args.model_name_or_path, 9 | num_labels=args.num_labels, 10 | cache_dir=args.cache_dir if args.cache_dir else None) 11 | return config 12 | 13 | 14 | def load_tokenizer(args):  # Load the tokenizer (args.tokenizer_name, else args.model_name_or_path). 15 | tokenizer = AutoTokenizer.from_pretrained( 16 | args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, 17 | cache_dir=args.cache_dir if args.cache_dir else None) 18 | return tokenizer 19 | 20 | 21 | def load_model(args, config, tokenizer):  # Build PRISM from cached relation features and one or two pretrained encoders; return it on args.device. 22 | relation_features = load_and_cache_relations(args, config, tokenizer) 23 | 24 | encoder1 = AutoModel.from_pretrained( 25 | args.model_name_or_path, 26 | from_tf=bool(".ckpt" in args.model_name_or_path),  # load TF weights when the path names a .ckpt file 27 | cache_dir=args.cache_dir if args.cache_dir else None) 28 | 29 | if not args.share_params:  # a second, independent encoder is loaded only when parameters are not shared 30 | encoder2 = AutoModel.from_pretrained( 31 | args.model_name_or_path, 32 | from_tf=bool(".ckpt" in args.model_name_or_path), 33 | cache_dir=args.cache_dir if args.cache_dir else None) 34 | else: 35 | encoder2 = None  # NOTE(review): presumably PRISM falls back to encoder1 when encoder2 is None; confirm in models/prism.py 36 | 37 | model = PRISM(args, config, relation_features, encoder1, encoder2) 38 | 39 | return model.to(args.device) 40 | -------------------------------------------------------------------------------- /data/DWIE/label_map.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "None": 0, 3 | "institution_of": 1, 4 | "part_of": 2, 5 | "head_of": 3, 6 | "member_of": 4, 7 | "agent_of": 5, 8 | "citizen_of": 6, 9 | "citizen_of-x": 7, 10 | "head_of_state": 8, 11 | "head_of_state-x": 9, 12 | "head_of_gov": 10, 13 | "head_of_gov-x": 11, 14 | "gpe0": 12, 15 | "based_in0": 13, 16 | "based_in0-x": 14, 17 | "event_in0": 15, 18 | "minister_of": 16, 19 | "minister_of-x": 17, 20 | "in0": 18, 21 | "in0-x": 19, 22 | "based_in2": 20, 23 | "agency_of": 21, 24 | "agency_of-x": 22, 25 | "ministry_of": 23, 26 | "artifact_of": 24, 27 | "in1": 25, 28 | "agent_of-x": 26, 29 | "signed_by": 27, 30 | "appears_in": 28, 31 | "vs": 29, 32 | "won_vs": 30, 33 | "coach_of": 31, 34 | "player_of": 32, 35 | "is_meeting": 33, 36 | "created_by": 34, 37 | "spokesperson_of": 35, 38 | "event_in": 36, 39 | "product_of": 37, 40 | "in2": 38, 41 | "award_received": 39, 42 | "law_of": 40, 43 | "spouse_of": 41, 44 | "event_in2": 42, 45 | "royalty_of": 43, 46 | "gpe1": 44, 47 | "advisor_of": 45, 48 | "parent_of": 46, 49 | "child_of": 47, 50 | "based_in1": 48, 51 | "gpe2": 49, 52 | "directed_by": 50, 53 | "plays_in": 51, 54 | "character_in": 52, 55 | "present_in0": 53, 56 | "founder_of": 54, 57 | "mayor_of": 55, 58 | "based_in2-x": 56, 59 | "sanctions": 57, 60 | "sibling": 58, 61 | "brand_of": 59, 62 | "event_in0-x": 60, 63 | "played_by": 61, 64 | "artifact_of-x": 62, 65 | "based_in1-x": 63, 66 | "event_in1": 64, 67 | "publisher": 65 68 | } -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | 4 | import torch 5 | from transformers import get_linear_schedule_with_warmup, get_constant_schedule_with_warmup 6 | 7 | 8 | def set_seed(seed):  # Seed random, NumPy and torch (all CUDA devices when available) and force deterministic cuDNN. 9 | random.seed(seed) 10 | np.random.seed(seed) 11 | torch.manual_seed(seed) 12 | if 
torch.cuda.is_available(): 13 | torch.cuda.manual_seed_all(seed) 14 | torch.backends.cudnn.deterministic = True 15 | torch.backends.cudnn.benchmark = False  # disable cuDNN autotuning for reproducibility 16 | 17 | 18 | def seed_worker(worker_id):  # Re-seed NumPy and random from torch.initial_seed(); presumably used as a DataLoader worker_init_fn. 19 | worker_seed = torch.initial_seed() % 2**32 20 | np.random.seed(worker_seed) 21 | random.seed(worker_seed) 22 | 23 | 24 | def init_scaler(args):  # AMP gradient scaler; scaling is disabled when args.use_amp is False. 25 | return torch.cuda.amp.GradScaler(enabled=args.use_amp) 26 | 27 | 28 | def init_optimizer(args, model):  # AdamW with two parameter groups (encoder at args.lr, new layers at args.clf_lr). 29 | new_layer = ["extractor", "bilinear"]  # parameters whose names contain these substrings get args.clf_lr; all others get args.lr 30 | optimizer_grouped_parameters = [ 31 | {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in new_layer)],}, 32 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in new_layer)], "lr": args.clf_lr}, 33 | ] 34 | optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.lr) 35 | return optimizer 36 | 37 | 38 | def init_scheduler(args, optimizer):  # args.warmup_steps / args.total_steps are not defined in config.py; the caller must set them. 39 | if args.warmup_ratio > 0:  # warmup, then linear decay to 0 over args.total_steps; otherwise constant LR after warmup 40 | scheduler = get_linear_schedule_with_warmup(optimizer, 41 | num_warmup_steps=args.warmup_steps, 42 | num_training_steps=args.total_steps) 43 | else: 44 | scheduler = get_constant_schedule_with_warmup(optimizer, 45 | num_warmup_steps=args.warmup_steps) 46 | return scheduler 47 | -------------------------------------------------------------------------------- /scripts/evaluate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_VISIBLE_DEVICES=0 3 | 4 | model_name_or_path="bert-base-cased" 5 | data_dir="data/DocRED" 6 | # Keep in sync with scripts/train.sh: these values rebuild train_output_dir below. 7 | num_train_ratio=1 # 0.0327(3%), 0.1(10%), 1(100%) 8 | train_batch_size=4 9 | gradient_accumulation_steps=1 10 | dist_fn="cosine" # inner, cosine 11 | ent_pooler="logsumexp" # max, sum, avg, logsumexp 12 | rel_pooler="cls" # pooler, cls 13 | lr=3e-5 14 | clf_lr=1e-4 15 | temperature=0.1 16 | warmup_ratio=0.06 17 | seed=42 18 | share_params=1 # 0(false) or 1(true) 19 | log_to_file=0 # 0(false) or 1(true) 20 | long_seq=0 # 0(false) or 1(true) 21 | 22 | batch_size=$((train_batch_size * 
gradient_accumulation_steps)) 23 | IFS='/' read -ra x <<< $data_dir && dataset_name=${x[1]} # data_dir.split("/")[1] 24 | IFS='-' read -ra x <<< $model_name_or_path && model_type=${x[0]} # model_name_or_path.split("-")[0] 25 | if [ $share_params == 1 ] ; then enc="share" ; else enc="sep" ; fi 26 | if [ $long_seq == 1 ] ; then long="_long" ; else long="" ; fi 27 | 28 | exp="BS${batch_size}_LR${lr}_W${warmup_ratio}_T${temperature}_S${seed}${long}" 29 | train_output_dir=".checkpoints/${dataset_name}/${model_name_or_path}/${enc}/${dist_fn}/${ent_pooler}/${rel_pooler}/N${num_train_ratio}/${exp}"  # must match the path built in scripts/train.sh so best_valid_f1.pt is found 30 | 31 | # HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 \ 32 | python evaluate.py \ 33 | --model_name_or_path ${model_name_or_path} \ 34 | --model_type ${model_type} \ 35 | --data_dir ${data_dir} \ 36 | --dataset_name ${dataset_name} \ 37 | --temperature ${temperature} \ 38 | --num_train_ratio ${num_train_ratio} \ 39 | --train_output_dir ${train_output_dir} \ 40 | --dist_fn ${dist_fn} \ 41 | --ent_pooler ${ent_pooler} \ 42 | --rel_pooler ${rel_pooler} \ 43 | --share_params ${share_params} \ 44 | --log_to_file ${log_to_file} \ 45 | --long_seq ${long_seq} 46 | -------------------------------------------------------------------------------- /data/DocRED/label_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "None": 0, 3 | "P6": 30, 4 | "P17": 2, 5 | "P19": 7, 6 | "P20": 25, 7 | "P22": 35, 8 | "P25": 34, 9 | "P26": 33, 10 | "P27": 5, 11 | "P30": 12, 12 | "P31": 83, 13 | "P35": 77, 14 | "P36": 32, 15 | "P37": 37, 16 | "P39": 57, 17 | "P40": 36, 18 | "P50": 70, 19 | "P54": 69, 20 | "P57": 48, 21 | "P58": 49, 22 | "P69": 26, 23 | "P86": 63, 24 | "P102": 43, 25 | "P108": 28, 26 | "P112": 85, 27 | "P118": 56, 28 | "P123": 86, 29 | "P127": 55, 30 | "P131": 3, 31 | "P136": 72, 32 | "P137": 84, 33 | "P140": 78, 34 | "P150": 4, 35 | "P155": 51, 36 | "P156": 52, 37 | "P159": 1, 38 | "P161": 23, 39 | "P162": 75, 40 | "P166": 29, 41 | "P170": 
61, 42 | "P171": 91, 43 | "P172": 8, 44 | "P175": 21, 45 | "P176": 87, 46 | "P178": 40, 47 | "P179": 59, 48 | "P190": 96, 49 | "P194": 53, 50 | "P205": 73, 51 | "P206": 15, 52 | "P241": 54, 53 | "P264": 18, 54 | "P272": 50, 55 | "P276": 13, 56 | "P279": 82, 57 | "P355": 89, 58 | "P361": 31, 59 | "P364": 80, 60 | "P400": 41, 61 | "P403": 24, 62 | "P449": 62, 63 | "P463": 20, 64 | "P488": 64, 65 | "P495": 16, 66 | "P527": 19, 67 | "P551": 17, 68 | "P569": 6, 69 | "P570": 27, 70 | "P571": 9, 71 | "P576": 10, 72 | "P577": 22, 73 | "P580": 66, 74 | "P582": 67, 75 | "P585": 44, 76 | "P607": 11, 77 | "P674": 58, 78 | "P676": 68, 79 | "P706": 74, 80 | "P710": 76, 81 | "P737": 81, 82 | "P740": 45, 83 | "P749": 88, 84 | "P800": 39, 85 | "P807": 95, 86 | "P840": 71, 87 | "P937": 42, 88 | "P1001": 47, 89 | "P1056": 92, 90 | "P1198": 90, 91 | "P1336": 79, 92 | "P1344": 65, 93 | "P1365": 94, 94 | "P1366": 93, 95 | "P1376": 14, 96 | "P1412": 38, 97 | "P1441": 60, 98 | "P3373": 46 99 | } -------------------------------------------------------------------------------- /data/Re-DocRED/label_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "None": 0, 3 | "P6": 30, 4 | "P17": 2, 5 | "P19": 7, 6 | "P20": 25, 7 | "P22": 35, 8 | "P25": 34, 9 | "P26": 33, 10 | "P27": 5, 11 | "P30": 12, 12 | "P31": 83, 13 | "P35": 77, 14 | "P36": 32, 15 | "P37": 37, 16 | "P39": 57, 17 | "P40": 36, 18 | "P50": 70, 19 | "P54": 69, 20 | "P57": 48, 21 | "P58": 49, 22 | "P69": 26, 23 | "P86": 63, 24 | "P102": 43, 25 | "P108": 28, 26 | "P112": 85, 27 | "P118": 56, 28 | "P123": 86, 29 | "P127": 55, 30 | "P131": 3, 31 | "P136": 72, 32 | "P137": 84, 33 | "P140": 78, 34 | "P150": 4, 35 | "P155": 51, 36 | "P156": 52, 37 | "P159": 1, 38 | "P161": 23, 39 | "P162": 75, 40 | "P166": 29, 41 | "P170": 61, 42 | "P171": 91, 43 | "P172": 8, 44 | "P175": 21, 45 | "P176": 87, 46 | "P178": 40, 47 | "P179": 59, 48 | "P190": 96, 49 | "P194": 53, 50 | "P205": 73, 51 | "P206": 15, 
52 | "P241": 54, 53 | "P264": 18, 54 | "P272": 50, 55 | "P276": 13, 56 | "P279": 82, 57 | "P355": 89, 58 | "P361": 31, 59 | "P364": 80, 60 | "P400": 41, 61 | "P403": 24, 62 | "P449": 62, 63 | "P463": 20, 64 | "P488": 64, 65 | "P495": 16, 66 | "P527": 19, 67 | "P551": 17, 68 | "P569": 6, 69 | "P570": 27, 70 | "P571": 9, 71 | "P576": 10, 72 | "P577": 22, 73 | "P580": 66, 74 | "P582": 67, 75 | "P585": 44, 76 | "P607": 11, 77 | "P674": 58, 78 | "P676": 68, 79 | "P706": 74, 80 | "P710": 76, 81 | "P737": 81, 82 | "P740": 45, 83 | "P749": 88, 84 | "P800": 39, 85 | "P807": 95, 86 | "P840": 71, 87 | "P937": 42, 88 | "P1001": 47, 89 | "P1056": 92, 90 | "P1198": 90, 91 | "P1336": 79, 92 | "P1344": 65, 93 | "P1365": 94, 94 | "P1366": 93, 95 | "P1376": 14, 96 | "P1412": 38, 97 | "P1441": 60, 98 | "P3373": 46 99 | } -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | 5 | class FileLogger: 6 | 7 | def __init__(self, output_dir, is_master=False, is_rank0=False, log_to_file=False): 8 | self.output_dir = output_dir 9 | os.makedirs(output_dir, exist_ok=True) 10 | 11 | # Log to console if rank 0, Log to console and file if master 12 | if not is_rank0: 13 | self.logger = NoOp() 14 | else: 15 | self.logger = self.get_logger(output_dir, log_to_file=(is_master and log_to_file)) 16 | 17 | 18 | def get_logger(self, output_dir, log_to_file=True): 19 | logger = logging.getLogger() 20 | logger.setLevel(logging.DEBUG) 21 | logging.getLogger("urllib3").setLevel(logging.WARNING) 22 | formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(module)s: %(message)s', 23 | datefmt='%m/%d/%Y %H:%M:%S') 24 | 25 | if log_to_file: 26 | vlog = logging.FileHandler(output_dir+'/verbose.log') 27 | vlog.setLevel(logging.INFO) 28 | vlog.setFormatter(formatter) 29 | logger.addHandler(vlog) 30 | 31 | eventlog = 
logging.FileHandler(output_dir+'/event.log') 32 | eventlog.setLevel(logging.WARN) 33 | eventlog.setFormatter(formatter) 34 | logger.addHandler(eventlog) 35 | 36 | debuglog = logging.FileHandler(output_dir+'/debug.log') 37 | debuglog.setLevel(logging.DEBUG) 38 | debuglog.setFormatter(formatter) 39 | logger.addHandler(debuglog) 40 | 41 | console = logging.StreamHandler() 42 | console.setFormatter(formatter) 43 | console.setLevel(logging.DEBUG) 44 | logger.addHandler(console) 45 | return logger 46 | 47 | def console(self, *args): 48 | self.logger.debug(*args) 49 | 50 | def event(self, *args): 51 | self.logger.warn(*args) 52 | 53 | def verbose(self, *args): 54 | self.logger.info(*args) 55 | 56 | 57 | # no_op method/object that accept every signature 58 | class NoOp: 59 | def __getattr__(self, *args): 60 | def no_op(*args, **kwargs): pass 61 | return no_op 62 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_VISIBLE_DEVICES=0 3 | 4 | model_name_or_path="bert-base-cased" 5 | data_dir="data/DocRED" 6 | 7 | num_train_ratio=1 # 0.0327(3%), 0.1(10%), 1(100%) 8 | train_batch_size=4 9 | eval_batch_size=8 10 | gradient_accumulation_steps=1 11 | dist_fn="cosine" # inner, cosine 12 | ent_pooler="logsumexp" # max, sum, avg, logsumexp 13 | rel_pooler="cls" # pooler, cls 14 | lr=3e-5 15 | clf_lr=1e-4 16 | temperature=0.1 17 | warmup_ratio=0.06 18 | epochs=30 19 | seed=42 20 | share_params=1 # 0(false) or 1(true) 21 | wandb_on=0 # " 22 | log_to_file=1 # " 23 | long_seq=0 # " 24 | 25 | batch_size=$((train_batch_size * gradient_accumulation_steps)) 26 | IFS='/' read -ra x <<< $data_dir && dataset_name=${x[1]} # data_dir.split("/")[1] 27 | IFS='-' read -ra x <<< $model_name_or_path && model_type=${x[0]} # model_name_or_path.split("-")[0] 28 | if [ $share_params == 1 ] ; then enc="share" ; else enc="sep" ; fi 29 | if [ 
$long_seq == 1 ] ; then long="_long" ; else long="" ; fi 30 | 31 | exp="BS${batch_size}_LR${lr}_W${warmup_ratio}_T${temperature}_S${seed}${long}" 32 | train_output_dir=".checkpoints/${dataset_name}/${model_name_or_path}/${enc}/${dist_fn}/${ent_pooler}/${rel_pooler}/N${num_train_ratio}/${exp}" 33 | 34 | # HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 \ 35 | python train.py \ 36 | --model_name_or_path ${model_name_or_path} \ 37 | --model_type ${model_type} \ 38 | --data_dir ${data_dir} \ 39 | --dataset_name ${dataset_name} \ 40 | --num_train_ratio ${num_train_ratio} \ 41 | --temperature ${temperature} \ 42 | --train_output_dir ${train_output_dir} \ 43 | --train_batch_size ${train_batch_size} \ 44 | --eval_batch_size ${eval_batch_size} \ 45 | --gradient_accumulation_steps ${gradient_accumulation_steps} \ 46 | --dist_fn ${dist_fn} \ 47 | --ent_pooler ${ent_pooler} \ 48 | --rel_pooler ${rel_pooler} \ 49 | --lr ${lr} \ 50 | --clf_lr ${clf_lr} \ 51 | --warmup_ratio ${warmup_ratio} \ 52 | --epochs ${epochs} \ 53 | --seed ${seed} \ 54 | --share_params ${share_params} \ 55 | --wandb_on ${wandb_on} \ 56 | --log_to_file ${log_to_file} \ 57 | --long_seq ${long_seq} 58 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PRiSM 2 | 3 | Source code for our *Findings of IJCNLP-AACL 2023* paper [PRiSM: Enhancing Low-Resource Document-Level Relation Extraction with Relation-Aware Score Calibration](https://arxiv.org/abs/2309.13869). 
4 | 5 | ## Requirements 6 | 7 | - Python (tested on 3.8.16) 8 | - CUDA (tested on 11.7) 9 | - PyTorch (tested on 1.13.1) 10 | - Transformers (tested on 4.30.0) 11 | - numpy (tested on 1.22.4) 12 | - wandb 13 | - tqdm 14 | 15 | ## Datasets 16 | 17 | Datasets can be downloaded here: [DocRED](https://drive.google.com/drive/folders/1c5-0YwnoJx8NS6CV2f-NoTHR__BdkNqw), [Re-DocRED](https://github.com/tonytan48/Re-DocRED), [DWIE](https://github.com/klimzaporojets/DWIE). The expected structure of files is: 18 | 19 | ``` 20 | [working directory] 21 | |-- data 22 | | |-- DocRED 23 | | | |-- train_distant.json 24 | | | |-- train.json 25 | | | |-- dev.json 26 | | | |-- test.json 27 | | | |-- label_map.json 28 | | | |-- rel_info.json 29 | | | |-- rel_desc.json 30 | | |-- Re-DocRED 31 | | | |-- train_distant.json 32 | | | |-- train.json 33 | | | |-- dev.json 34 | | | |-- test.json 35 | | | |-- label_map.json 36 | | | |-- rel_info.json 37 | | | |-- rel_desc.json 38 | | |-- DWIE 39 | | | |-- train/ 40 | | | |-- dev/ 41 | | | |-- test/ 42 | | | |-- label_map.json 43 | | | |-- rel_desc.json 44 | ``` 45 | 46 | ## Training and Evaluation 47 | 48 | Train the model with the following command: 49 | 50 | ```bash 51 | >> bash scripts/train.sh 52 | ``` 53 | 54 | Evaluate the model with the following command: 55 | 56 | ```bash 57 | >> bash scripts/evaluate.sh 58 | ``` 59 | 60 | ## Citation 61 | 62 | If you make use of this code in your work, please kindly cite our paper: 63 | 64 | ```bibtex 65 | @inproceedings{choi2023prism, 66 | author={Choi, Minseok and Lim, Hyesu and Choo, Jaegul}, 67 | title={P{R}i{S}{M}: Enhancing Low-Resource Document-Level Relation Extraction with Relation-Aware Score Calibration}, 68 | booktitle={Proceedings of the 13th International Joint Conference on Natural Language Processing and the 3rd Conference of the Asia-Pacific Chapter of the Association for Computational Linguistics}, 69 | month={November}, 70 | year={2023}, 71 | address={Nusa Dua, Bali}, 72 | 
publisher={Association for Computational Linguistics}, 73 | pages={39--47}, 74 | url={https://aclanthology.org/2023.findings-ijcnlp.4} 75 | } 76 | ``` 77 | 78 | ## Acknowledgements 79 | 80 | This work was supported by Institute of Information & communications Technology Planning & Evaluation (IITP) grant funded by the Korea government (MSIT) (No.2019-0-00075, Artificial Intelligence Graduate School Program (KAIST)), the National Supercomputing Center with supercomputing resources including technical support (KSC-2022-CRE-0312), and Samsung Electronics Co., Ltd. 81 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Configurations. 3 | """ 4 | 5 | def model_args(parser): 6 | parser.add_argument("--model_type", default="bert", type=str) 7 | parser.add_argument("--model_name_or_path", default="bert-base-cased", type=str) 8 | parser.add_argument("--config_name", default="", type=str) 9 | parser.add_argument("--tokenizer_name", default="", type=str) 10 | parser.add_argument("--cache_dir", default=".cache/", type=str) 11 | parser.add_argument("--ent_pooler", default="logsumexp", type=str) 12 | parser.add_argument("--rel_pooler", default="cls", type=str) 13 | parser.add_argument("--dist_fn", default="", type=str) 14 | parser.add_argument("--temperature", default=0.1, type=float) 15 | parser.add_argument("--embedding_size", default=768, type=int) 16 | parser.add_argument("--block_size", default=64, type=int) 17 | parser.add_argument("--group_bilinear", default=True, type=bool) 18 | parser.add_argument("--share_params", default=1, type=int) 19 | parser.add_argument("--long_seq", default=0, type=int) 20 | 21 | 22 | def data_args(parser): 23 | parser.add_argument("--log_dir", type=str, default=".logs/") 24 | parser.add_argument("--dataset_name", default="DocRED", type=str) 25 | parser.add_argument("--data_dir", default="data/DocRED", 
type=str) 26 | parser.add_argument("--train_output_dir", default=".checkpoints/", type=str) 27 | parser.add_argument("--test_output_dir", default=".checkpoints/", type=str) 28 | parser.add_argument("--max_seq_length", default=512, type=int) 29 | parser.add_argument("--num_workers", default=8, type=int) 30 | parser.add_argument("--num_labels", default=97, type=int) 31 | parser.add_argument("--num_train_ratio", default=1.0, type=float) 32 | parser.add_argument("--mark_entities", default=True, type=bool) 33 | parser.add_argument("--log_to_file", default=0, type=int) 34 | 35 | 36 | def train_args(parser): 37 | parser.add_argument("--train_batch_size", default=4, type=int) 38 | parser.add_argument("--eval_batch_size", default=8, type=int) 39 | parser.add_argument("--gradient_accumulation_steps", default=1, type=int) 40 | parser.add_argument("--logging_steps", default=100, type=str) 41 | parser.add_argument("--lr", default=3e-5, type=float) 42 | parser.add_argument("--clf_lr", default=1e-4, type=float) 43 | parser.add_argument("--max_grad_norm", default=1.0, type=float) 44 | parser.add_argument("--warmup_ratio", default=0.06, type=float) 45 | parser.add_argument("--max_tolerance", default=5, type=int) 46 | parser.add_argument("--epochs", default=30, type=int) 47 | parser.add_argument("--seed", default=42, type=int) 48 | parser.add_argument("--use_amp", default=True, type=bool) 49 | parser.add_argument("--hide_tqdm", default=False, type=bool) 50 | parser.add_argument("--wandb_on", default=0, type=int) 51 | parser.add_argument("--resume", action="store_true") 52 | 53 | 54 | def predict_args(parser): 55 | parser.add_argument("--theta", default=0.5, type=float) 56 | parser.add_argument("--eval_batch_size", default=20, type=int) 57 | parser.add_argument("--test_batch_size", default=20, type=int) 58 | -------------------------------------------------------------------------------- /data/DocRED/rel_info.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "P6": "head of government", 3 | "P17": "country", 4 | "P19": "place of birth", 5 | "P20": "place of death", 6 | "P22": "father", 7 | "P25": "mother", 8 | "P26": "spouse", 9 | "P27": "country of citizenship", 10 | "P30": "continent", 11 | "P31": "instance of", 12 | "P35": "head of state", 13 | "P36": "capital", 14 | "P37": "official language", 15 | "P39": "position held", 16 | "P40": "child", 17 | "P50": "author", 18 | "P54": "member of sports team", 19 | "P57": "director", 20 | "P58": "screenwriter", 21 | "P69": "educated at", 22 | "P86": "composer", 23 | "P102": "member of political party", 24 | "P108": "employer", 25 | "P112": "founded by", 26 | "P118": "league", 27 | "P123": "publisher", 28 | "P127": "owned by", 29 | "P131": "located in the administrative territorial entity", 30 | "P136": "genre", 31 | "P137": "operator", 32 | "P140": "religion", 33 | "P150": "contains administrative territorial entity", 34 | "P155": "follows", 35 | "P156": "followed by", 36 | "P159": "headquarters location", 37 | "P161": "cast member", 38 | "P162": "producer", 39 | "P166": "award received", 40 | "P170": "creator", 41 | "P171": "parent taxon", 42 | "P172": "ethnic group", 43 | "P175": "performer", 44 | "P176": "manufacturer", 45 | "P178": "developer", 46 | "P179": "series", 47 | "P190": "sister city", 48 | "P194": "legislative body", 49 | "P205": "basin country", 50 | "P206": "located in or next to body of water", 51 | "P241": "military branch", 52 | "P264": "record label", 53 | "P272": "production company", 54 | "P276": "location", 55 | "P279": "subclass of", 56 | "P355": "subsidiary", 57 | "P361": "part of", 58 | "P364": "original language of work", 59 | "P400": "platform", 60 | "P403": "mouth of the watercourse", 61 | "P449": "original network", 62 | "P463": "member of", 63 | "P488": "chairperson", 64 | "P495": "country of origin", 65 | "P527": "has part", 66 | "P551": "residence", 67 | 
"P569": "date of birth", 68 | "P570": "date of death", 69 | "P571": "inception", 70 | "P576": "dissolved, abolished or demolished", 71 | "P577": "publication date", 72 | "P580": "start time", 73 | "P582": "end time", 74 | "P585": "point in time", 75 | "P607": "conflict", 76 | "P674": "characters", 77 | "P676": "lyrics by", 78 | "P706": "located on terrain feature", 79 | "P710": "participant", 80 | "P737": "influenced by", 81 | "P740": "location of formation", 82 | "P749": "parent organization", 83 | "P800": "notable work", 84 | "P807": "separated from", 85 | "P840": "narrative location", 86 | "P937": "work location", 87 | "P1001": "applies to jurisdiction", 88 | "P1056": "product or material produced", 89 | "P1198": "unemployment rate", 90 | "P1336": "territory claimed by", 91 | "P1344": "participant of", 92 | "P1365": "replaces", 93 | "P1366": "replaced by", 94 | "P1376": "capital of", 95 | "P1412": "languages spoken, written or signed", 96 | "P1441": "present in work", 97 | "P3373": "sibling" 98 | } -------------------------------------------------------------------------------- /data/Re-DocRED/rel_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "P6": "head of government", 3 | "P17": "country", 4 | "P19": "place of birth", 5 | "P20": "place of death", 6 | "P22": "father", 7 | "P25": "mother", 8 | "P26": "spouse", 9 | "P27": "country of citizenship", 10 | "P30": "continent", 11 | "P31": "instance of", 12 | "P35": "head of state", 13 | "P36": "capital", 14 | "P37": "official language", 15 | "P39": "position held", 16 | "P40": "child", 17 | "P50": "author", 18 | "P54": "member of sports team", 19 | "P57": "director", 20 | "P58": "screenwriter", 21 | "P69": "educated at", 22 | "P86": "composer", 23 | "P102": "member of political party", 24 | "P108": "employer", 25 | "P112": "founded by", 26 | "P118": "league", 27 | "P123": "publisher", 28 | "P127": "owned by", 29 | "P131": "located in the administrative 
territorial entity", 30 | "P136": "genre", 31 | "P137": "operator", 32 | "P140": "religion", 33 | "P150": "contains administrative territorial entity", 34 | "P155": "follows", 35 | "P156": "followed by", 36 | "P159": "headquarters location", 37 | "P161": "cast member", 38 | "P162": "producer", 39 | "P166": "award received", 40 | "P170": "creator", 41 | "P171": "parent taxon", 42 | "P172": "ethnic group", 43 | "P175": "performer", 44 | "P176": "manufacturer", 45 | "P178": "developer", 46 | "P179": "series", 47 | "P190": "sister city", 48 | "P194": "legislative body", 49 | "P205": "basin country", 50 | "P206": "located in or next to body of water", 51 | "P241": "military branch", 52 | "P264": "record label", 53 | "P272": "production company", 54 | "P276": "location", 55 | "P279": "subclass of", 56 | "P355": "subsidiary", 57 | "P361": "part of", 58 | "P364": "original language of work", 59 | "P400": "platform", 60 | "P403": "mouth of the watercourse", 61 | "P449": "original network", 62 | "P463": "member of", 63 | "P488": "chairperson", 64 | "P495": "country of origin", 65 | "P527": "has part", 66 | "P551": "residence", 67 | "P569": "date of birth", 68 | "P570": "date of death", 69 | "P571": "inception", 70 | "P576": "dissolved, abolished or demolished", 71 | "P577": "publication date", 72 | "P580": "start time", 73 | "P582": "end time", 74 | "P585": "point in time", 75 | "P607": "conflict", 76 | "P674": "characters", 77 | "P676": "lyrics by", 78 | "P706": "located on terrain feature", 79 | "P710": "participant", 80 | "P737": "influenced by", 81 | "P740": "location of formation", 82 | "P749": "parent organization", 83 | "P800": "notable work", 84 | "P807": "separated from", 85 | "P840": "narrative location", 86 | "P937": "work location", 87 | "P1001": "applies to jurisdiction", 88 | "P1056": "product or material produced", 89 | "P1198": "unemployment rate", 90 | "P1336": "territory claimed by", 91 | "P1344": "participant of", 92 | "P1365": "replaces", 93 | "P1366": 
"replaced by", 94 | "P1376": "capital of", 95 | "P1412": "languages spoken, written or signed", 96 | "P1441": "present in work", 97 | "P3373": "sibling" 98 | } -------------------------------------------------------------------------------- /data/DWIE/rel_desc.json: -------------------------------------------------------------------------------- 1 | { 2 | "None": "subject is not related to object", 3 | "institution_of": "associated with an institution", 4 | "part_of": "component of another entity", 5 | "head_of": "leading position in an organization", 6 | "member_of": "person and an organization he is member of", 7 | "agent_of": "person officially working for a government", 8 | "citizen_of": "person and the country it is citizen of", 9 | "citizen_of-x": "individual and a specific country or nation", 10 | "head_of_state": "head of a state", 11 | "head_of_state-x": "highest constitutional position in a country", 12 | "head_of_gov": "head of a government", 13 | "head_of_gov-x": "highest executive authority of a government", 14 | "gpe0": "countries and former countries", 15 | "based_in0": "organization operating in a country", 16 | "based_in0-x": "country where an organization has its primary operations", 17 | "event_in0": "event and the country it is taking place in", 18 | "minister_of": "politician and the government where he is a minister of", 19 | "minister_of-x": "government official overseeing a specific department of government", 20 | "in0": "physically located in a country", 21 | "in0-x": "located within the boundaries of a country", 22 | "based_in2": "organization operating in a city", 23 | "agency_of": "governmental organization and its geopolitical entity", 24 | "agency_of-x": "organization carrying out specific tasks on behalf of another government", 25 | "ministry_of": "ministry and the geopolitical entity it belongs to", 26 | "artifact_of": "object owned by a person or company", 27 | "in1": "physically located in a state or province", 28 | "agent_of-x": 
"acting on behalf of another entity", 29 | "signed_by": "document formally endorsed by a specific entity", 30 | "appears_in": "sport team or athlete participating in a competition event", 31 | "vs": "athlete or sport team has competed, or will compete, against an opponent", 32 | "won_vs": "athlete or sport team won vs another athlete or team", 33 | "coach_of": "coach and the team, or athlete, who is being coached", 34 | "player_of": "participant in a particular sport, game, or performance", 35 | "is_meeting": "two people are meeting", 36 | "created_by": "work-of-art and the artist who created it", 37 | "spokesperson_of": "official representative for a particular organization", 38 | "event_in": "event taking place within the context associated with another entity", 39 | "product_of": "product and the company who is producing it", 40 | "in2": "physically located in a city", 41 | "award_received": "person or a work of art receiving an award", 42 | "law_of": "legal document and the regulations it contains", 43 | "spouse_of": "one person is married, or is about to marry, to another", 44 | "event_in2": "event and the city it is taking place in", 45 | "royalty_of": "royalty and the country he is a royalty of", 46 | "gpe1": "geopolitical administrative regions", 47 | "advisor_of": "person giving advice to another person or organization", 48 | "parent_of": "parent and the son (or daughter)", 49 | "child_of": "son (or daughter) and the parent", 50 | "based_in1": "organization operating in a state or province", 51 | "gpe2": "cities, towns, municipalities, sub municipalities", 52 | "directed_by": "movie and the director who directed it", 53 | "plays_in": "actor and the movie or play he performs in", 54 | "character_in": "fictional character that is a part of a story", 55 | "present_in0": "existence of an entity within a specific context", 56 | "founder_of": "organization and the person who has founded it", 57 | "mayor_of": "elected leader of a municipality or city", 58 | 
"based_in2-x": "city of an organization", 59 | "sanctions": "one country imposes sanctions on another country", 60 | "sibling": "two people are brothers or sisters", 61 | "brand_of": "product and the specific brand that produces it", 62 | "event_in0-x": "event taking place within the context of another country", 63 | "played_by": "character and the actor who plays it", 64 | "artifact_of-x": "product of the work of a particular individual or group", 65 | "based_in1-x": "state or province of an organization", 66 | "event_in1": "event and the state or province it is taking place in", 67 | "publisher": "entity or organization producing and distributing a particular piece of content" 68 | } -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import csv 4 | import os 5 | os.environ["OMP_NUM_THREADS"] = "1" # restraints the model to 1 cpu 6 | import os.path as osp 7 | 8 | import torch 9 | 10 | from tqdm import tqdm 11 | 12 | import config 13 | from dataset import load_data 14 | from models.utils import load_config, load_tokenizer, load_model 15 | from logger import FileLogger 16 | from evaluation import * 17 | from utils import * 18 | 19 | 20 | class Evaluator: 21 | 22 | def __init__(self): 23 | ### Load config / tokenizer / model ### 24 | self.config = load_config(args) 25 | self.tokenizer = load_tokenizer(args) 26 | 27 | ### Load data ### 28 | self.valid_loader, self.valid_features = load_data(args, self.config, self.tokenizer, split="dev") 29 | self.test_loader, self.test_features = load_data(args, self.config, self.tokenizer, split="test") 30 | self.theta = args.theta 31 | 32 | ### Load trained parameter weights ### 33 | ckpt_model_path = osp.join(args.train_output_dir, "best_valid_f1.pt") 34 | if osp.exists(ckpt_model_path): 35 | log.console(f"Loading model checkpoint from {ckpt_model_path}...") 36 | ckpt 
= torch.load(ckpt_model_path) 37 | log.console(f"Validation loss was {ckpt['loss']:.4f}") 38 | log.console(f"Validation avg theta was {ckpt['theta']:.4f}") 39 | log.console(f"Validation F1 was {ckpt['f1']:.4f}") 40 | pretrained_dict = {key.replace("module.", ""): value for key, value in ckpt['model_state_dict'].items()} 41 | self.theta = ckpt['theta'] 42 | self.model = load_model(args, self.config, self.tokenizer) 43 | self.model.load_state_dict(pretrained_dict) 44 | else: 45 | log.event("Predicting with untrained model!") 46 | self.model = load_model(args, self.config, self.tokenizer) 47 | 48 | 49 | @torch.no_grad() 50 | def evaluate(self, split="dev"): 51 | self.model.eval() 52 | dataloader = self.valid_loader if split == "dev" else self.test_loader 53 | features = self.valid_features if split == "dev" else self.test_features 54 | total = len(dataloader) 55 | logits, labels = [], [] 56 | 57 | with tqdm(desc="Evaluating", total=total, ncols=100) as pbar: 58 | for step, inputs in enumerate(dataloader, 1): 59 | inputs["input_ids"] = inputs["input_ids"].to(args.device) 60 | inputs["attention_mask"] = inputs["attention_mask"].to(args.device) 61 | 62 | ### Forward pass ### 63 | outputs = self.model(**inputs) 64 | _, logit, label = outputs 65 | logits.append(logit) 66 | labels.append(label) 67 | 68 | pbar.update(1) 69 | del outputs 70 | 71 | logits = torch.cat(logits, dim=0) 72 | labels = torch.cat(labels, dim=0) 73 | 74 | # Remove "no relation" label (idx=0) b/c it was a "fake" label => should not be counted in F1 75 | logits_eval = logits[:,1:] 76 | labels_eval = labels[:,1:] 77 | 78 | score_dict = unofficial_evaluate(logits_eval, labels_eval, dataset_name=args.dataset_name) 79 | if split == "dev": 80 | self.theta = score_dict["theta"] 81 | best_f1 = score_dict["F1"] 82 | 83 | ece, ace, prob_true, prob_pred = calibrate(logits, labels) 84 | log.console(f"ECE: {ece}, ACE: {ace}") 85 | 86 | if args.dataset_name in {"DocRED", "Re-DocRED"}: 87 | ans = 
to_official(logits_eval, features, self.theta) 88 | best_f1, _, best_f1_ign, _ = official_evaluate(ans, args.data_dir, split=split) 89 | 90 | with open(osp.join(args.train_output_dir, "evaluation.txt"), "a") as f: 91 | f.write(f"{split} F1: {best_f1}\n") 92 | if args.dataset_name in {"DocRED", "Re-DocRED"}: 93 | f.write(f"{split} Ign F1: {best_f1_ign}\n") 94 | f.write(f"{split} Macro F1: {score_dict['macro_F1']}\n") 95 | f.write(f"{split} Macro F1@500: {score_dict['macro_F1_at_500']}\n") 96 | f.write(f"{split} Macro F1@200: {score_dict['macro_F1_at_200']}\n") 97 | f.write(f"{split} Macro F1@100: {score_dict['macro_F1_at_100']}\n") 98 | f.write(f"{split} ECE: {ece}\n") 99 | f.write(f"{split} ACE: {ace}\n") 100 | f.write(f"{split} F1 Per Class: {score_dict['F1_per_class']}\n") 101 | 102 | with open(osp.join(args.train_output_dir, f"calibration_curve_data.csv"), "a") as f: 103 | writer = csv.writer(f) 104 | writer.writerow(prob_true.tolist()) 105 | writer.writerow(prob_pred.tolist()) 106 | 107 | 108 | @torch.no_grad() 109 | def report(self): 110 | self.model.eval() 111 | total = len(self.test_loader) 112 | preds = [] 113 | 114 | with tqdm(desc="Evaluating", total=total, ncols=100) as pbar: 115 | for step, inputs in enumerate(self.test_loader, 1): 116 | inputs["input_ids"] = inputs["input_ids"].to(args.device) 117 | inputs["attention_mask"] = inputs["attention_mask"].to(args.device) 118 | 119 | ### Forward pass ### 120 | outputs = self.model(**inputs) 121 | _, pred, _ = outputs 122 | preds.append(pred) 123 | 124 | pbar.update(1) 125 | del outputs 126 | 127 | preds = torch.cat(preds, dim=0)[:,1:] 128 | ans = to_official(preds, self.test_features, self.theta) 129 | 130 | with open(osp.join(args.train_output_dir, "result.json"), "w") as f: 131 | json.dump(ans, f) 132 | 133 | 134 | if __name__ == "__main__": 135 | parser = argparse.ArgumentParser(description="evaluate.py") 136 | config.model_args(parser) 137 | config.data_args(parser) 138 | config.predict_args(parser) 139 
| args = parser.parse_args() 140 | args.n_gpu = torch.cuda.device_count() 141 | args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 142 | 143 | with open(osp.join(args.data_dir, "label_map.json"), "r") as f: 144 | label_map = json.load(f) 145 | args.num_labels = len(label_map) 146 | 147 | os.makedirs(args.train_output_dir, exist_ok=True) 148 | os.makedirs(args.cache_dir, exist_ok=True) 149 | 150 | log = FileLogger(args.train_output_dir, is_master=True, is_rank0=True, log_to_file=args.log_to_file) 151 | log.console(args) 152 | 153 | evaluator = Evaluator() 154 | evaluator.evaluate(split="dev") 155 | 156 | if args.dataset_name in {"Re-DocRED", "DWIE"}: 157 | evaluator.evaluate(split="test") 158 | elif args.dataset_name == "DocRED": 159 | evaluator.report() 160 | -------------------------------------------------------------------------------- /data/DocRED/rel_desc.json: -------------------------------------------------------------------------------- 1 | { 2 | "None": "subject is not related to object", 3 | "P6": "head of the executive power of this town, city, municipality, state, country, or other governmental body", 4 | "P17": "sovereign state of this item", 5 | "P19": "most specific known (e.g. city instead of country, or hospital instead of city) birth location of a person, animal or fictional character", 6 | "P20": "most specific known (e.g. 
city instead of country, or hospital instead of city) death location of a person, animal or fictional character", 7 | "P22": "male parent of the subject", 8 | "P25": "female parent of the subject", 9 | "P26": "the subject has the object as their spouse (husband, wife, partner, etc.)", 10 | "P27": "the object is a country that recognizes the subject as its citizen", 11 | "P30": "continent of which the subject is a part", 12 | "P31": "that class of which this subject is a particular example and member", 13 | "P35": "official with the highest formal authority in a country/state", 14 | "P36": "primary city of a country, province, state or other type of administrative territorial entity", 15 | "P37": "language designated as official by this item", 16 | "P39": "subject currently or formerly holds the object position or public office", 17 | "P40": "subject has object as biological, foster, and/or adoptive child", 18 | "P50": "main creator(s) of a written work", 19 | "P54": "sports teams or clubs that the subject currently represents or formerly represented", 20 | "P57": "director(s) of film, TV-series, stageplay, video game or similar", 21 | "P58": "person(s) who wrote the script for subject item", 22 | "P69": "educational institution attended by subject", 23 | "P86": "person(s) who wrote the music", 24 | "P102": "the political party of which this politician is or has been a member", 25 | "P108": "person or organization for which the subject works or worked", 26 | "P112": "founder or co-founder of this organization, religion or place", 27 | "P118": "league in which team or player plays or has played in", 28 | "P123": "organization or person responsible for publishing books, periodicals, games or software", 29 | "P127": "owner of the subject", 30 | "P131": "the item is located on the territory of the following administrative entity.", 31 | "P136": "creative work's genre or an artist's field of work", 32 | "P137": "person, profession, or organization that operates the 
equipment, facility, or service", 33 | "P140": "religion of a person, organization or religious building, or associated with this subject", 34 | "P150": "(list of) direct subdivisions of an administrative territorial entity", 35 | "P155": "immediately prior item in a series of which the subject is a part", 36 | "P156": "immediately following item in a series of which the subject is a part", 37 | "P159": "specific location where an organization's headquarters is or has been situated", 38 | "P161": "actor in the subject production", 39 | "P162": "person(s) who produced the film, musical work, theatrical production, etc. (for film, this does not include executive producers, associate producers, etc.)", 40 | "P166": "award or recognition received by a person, organisation or creative work", 41 | "P170": "maker of this creative work or other object (where no more specific property exists)", 42 | "P171": "closest parent taxon of the taxon in question", 43 | "P172": "subject's ethnicity (consensus is that a VERY high standard of proof is needed for this field to be used. 
In general this means 1) the subject claims it him/herself, or 2) it is widely agreed on by scholars, or 3) is fictional and portrayed as such).", 44 | "P175": "actor, musician, band or other performer associated with this role or musical work", 45 | "P176": "manufacturer or producer of this product", 46 | "P178": "organisation or person that developed the item", 47 | "P179": "series which contains the subject", 48 | "P190": "twin towns, sister cities, twinned municipalities and other localities that have a partnership or cooperative agreement, either legally or informally acknowledged by their governments", 49 | "P194": "legislative body governing this entity; political institution with elected representatives, such as a parliament/legislature or council", 50 | "P205": "country that have drainage to/from or border the body of water", 51 | "P206": "sea, lake or river", 52 | "P241": "branch to which this military unit, award, office, or person belongs, e.g. Royal Navy", 53 | "P264": "brand and trademark associated with the marketing of subject music recordings and music videos", 54 | "P272": "company that produced this film, audio or performing arts work", 55 | "P276": "location of the item, physical object or event is within", 56 | "P279": "all instances of these items are instances of those items; this item is a class (subset) of that item", 57 | "P355": "subsidiary of a company or organization, opposite of parent organization", 58 | "P361": "object of which the subject is a part (if this subject is already part of object A which is a part of object B, then please only make the subject part of object A). Inverse property of \"has part\".", 59 | "P364": "language in which a film or a performance work was originally created. 
Deprecated for written works", 60 | "P400": "platform for which a work was developed or released, or the specific platform version of a software product", 61 | "P403": "the body of water to which the watercourse drains", 62 | "P449": "network(s) or service(s) that originally broadcasted a radio or television program", 63 | "P463": "organization, musical group, or club to which the subject belongs", 64 | "P488": "presiding member of an organization, group or body", 65 | "P495": "country of origin of this item (creative work, food, phrase, product, etc.)", 66 | "P527": "part of this subject", 67 | "P551": "the place where the person is or has been, resident", 68 | "P569": "date on which the subject was born", 69 | "P570": "date on which the subject died", 70 | "P571": "date or point in time when the subject came into existence as defined", 71 | "P576": "point in time at which the subject (organisation, building) ceased to exist", 72 | "P577": "date or point in time when a work was first published or released", 73 | "P580": "time an item begins to exist or a statement starts being valid", 74 | "P582": "time an item ceases to exist or a statement stops being valid", 75 | "P585": "time and date something took place, existed or a statement was true", 76 | "P607": "battles, wars or other military engagements in which the person or item participated", 77 | "P674": "characters which appear in this item (like plays, operas, operettas, books, comics, films, TV series, video games)", 78 | "P676": "author of song lyrics", 79 | "P706": "located on the specified landform", 80 | "P710": "person, group of people or organization (object) that actively takes/took part in an event or process (subject)", 81 | "P737": "this person, idea, etc. is informed by that other person, idea, etc., e.g. 
\"Heidegger was influenced by Aristotle\".", 82 | "P740": "location where a group or organization was formed", 83 | "P749": "parent organization of an organization, opposite of subsidiaries", 84 | "P800": "notable scientific, artistic or literary work, or other work of significance among subject's works", 85 | "P807": "subject was founded or started by separating from identified object", 86 | "P840": "the narrative of the work is set in this location", 87 | "P937": "location where persons were active", 88 | "P1001": "the item (an institution, law, public office ...) or statement belongs to or has power over or applies to the value (a territorial jurisdiction: a country, state, municipality, ...)", 89 | "P1056": "material or product produced by a government agency, business, industry, facility, or process", 90 | "P1198": "portion of a workforce population that is not employed", 91 | "P1336": "administrative divisions that claim control of a given area", 92 | "P1344": "event a person or an organization was/is a participant in", 93 | "P1365": "person or item replaced", 94 | "P1366": "other person or item which continues the item by replacing it in its role", 95 | "P1376": "country, state, department, canton or other administrative division of which the municipality is the governmental seat", 96 | "P1412": "language(s) that a person speaks, writes or signs, including the native language(s)", 97 | "P1441": "this (fictional) entity or person is present in the story of that work", 98 | "P3373": "the subject has the object as their sibling (brother, sister, etc.)" 99 | } -------------------------------------------------------------------------------- /data/Re-DocRED/rel_desc.json: -------------------------------------------------------------------------------- 1 | { 2 | "None": "subject is not related to object", 3 | "P6": "head of the executive power of this town, city, municipality, state, country, or other governmental body", 4 | "P17": "sovereign state of this item", 5 
| "P19": "most specific known (e.g. city instead of country, or hospital instead of city) birth location of a person, animal or fictional character", 6 | "P20": "most specific known (e.g. city instead of country, or hospital instead of city) death location of a person, animal or fictional character", 7 | "P22": "male parent of the subject", 8 | "P25": "female parent of the subject", 9 | "P26": "the subject has the object as their spouse (husband, wife, partner, etc.)", 10 | "P27": "the object is a country that recognizes the subject as its citizen", 11 | "P30": "continent of which the subject is a part", 12 | "P31": "that class of which this subject is a particular example and member", 13 | "P35": "official with the highest formal authority in a country/state", 14 | "P36": "primary city of a country, province, state or other type of administrative territorial entity", 15 | "P37": "language designated as official by this item", 16 | "P39": "subject currently or formerly holds the object position or public office", 17 | "P40": "subject has object as biological, foster, and/or adoptive child", 18 | "P50": "main creator(s) of a written work", 19 | "P54": "sports teams or clubs that the subject currently represents or formerly represented", 20 | "P57": "director(s) of film, TV-series, stageplay, video game or similar", 21 | "P58": "person(s) who wrote the script for subject item", 22 | "P69": "educational institution attended by subject", 23 | "P86": "person(s) who wrote the music", 24 | "P102": "the political party of which this politician is or has been a member", 25 | "P108": "person or organization for which the subject works or worked", 26 | "P112": "founder or co-founder of this organization, religion or place", 27 | "P118": "league in which team or player plays or has played in", 28 | "P123": "organization or person responsible for publishing books, periodicals, games or software", 29 | "P127": "owner of the subject", 30 | "P131": "the item is located on the 
territory of the following administrative entity.", 31 | "P136": "creative work's genre or an artist's field of work", 32 | "P137": "person, profession, or organization that operates the equipment, facility, or service", 33 | "P140": "religion of a person, organization or religious building, or associated with this subject", 34 | "P150": "(list of) direct subdivisions of an administrative territorial entity", 35 | "P155": "immediately prior item in a series of which the subject is a part", 36 | "P156": "immediately following item in a series of which the subject is a part", 37 | "P159": "specific location where an organization's headquarters is or has been situated", 38 | "P161": "actor in the subject production", 39 | "P162": "person(s) who produced the film, musical work, theatrical production, etc. (for film, this does not include executive producers, associate producers, etc.)", 40 | "P166": "award or recognition received by a person, organisation or creative work", 41 | "P170": "maker of this creative work or other object (where no more specific property exists)", 42 | "P171": "closest parent taxon of the taxon in question", 43 | "P172": "subject's ethnicity (consensus is that a VERY high standard of proof is needed for this field to be used. 
In general this means 1) the subject claims it him/herself, or 2) it is widely agreed on by scholars, or 3) is fictional and portrayed as such).", 44 | "P175": "actor, musician, band or other performer associated with this role or musical work", 45 | "P176": "manufacturer or producer of this product", 46 | "P178": "organisation or person that developed the item", 47 | "P179": "series which contains the subject", 48 | "P190": "twin towns, sister cities, twinned municipalities and other localities that have a partnership or cooperative agreement, either legally or informally acknowledged by their governments", 49 | "P194": "legislative body governing this entity; political institution with elected representatives, such as a parliament/legislature or council", 50 | "P205": "country that have drainage to/from or border the body of water", 51 | "P206": "sea, lake or river", 52 | "P241": "branch to which this military unit, award, office, or person belongs, e.g. Royal Navy", 53 | "P264": "brand and trademark associated with the marketing of subject music recordings and music videos", 54 | "P272": "company that produced this film, audio or performing arts work", 55 | "P276": "location of the item, physical object or event is within", 56 | "P279": "all instances of these items are instances of those items; this item is a class (subset) of that item", 57 | "P355": "subsidiary of a company or organization, opposite of parent organization", 58 | "P361": "object of which the subject is a part (if this subject is already part of object A which is a part of object B, then please only make the subject part of object A). Inverse property of \"has part\".", 59 | "P364": "language in which a film or a performance work was originally created. 
Deprecated for written works", 60 | "P400": "platform for which a work was developed or released, or the specific platform version of a software product", 61 | "P403": "the body of water to which the watercourse drains", 62 | "P449": "network(s) or service(s) that originally broadcasted a radio or television program", 63 | "P463": "organization, musical group, or club to which the subject belongs", 64 | "P488": "presiding member of an organization, group or body", 65 | "P495": "country of origin of this item (creative work, food, phrase, product, etc.)", 66 | "P527": "part of this subject", 67 | "P551": "the place where the person is or has been, resident", 68 | "P569": "date on which the subject was born", 69 | "P570": "date on which the subject died", 70 | "P571": "date or point in time when the subject came into existence as defined", 71 | "P576": "point in time at which the subject (organisation, building) ceased to exist", 72 | "P577": "date or point in time when a work was first published or released", 73 | "P580": "time an item begins to exist or a statement starts being valid", 74 | "P582": "time an item ceases to exist or a statement stops being valid", 75 | "P585": "time and date something took place, existed or a statement was true", 76 | "P607": "battles, wars or other military engagements in which the person or item participated", 77 | "P674": "characters which appear in this item (like plays, operas, operettas, books, comics, films, TV series, video games)", 78 | "P676": "author of song lyrics", 79 | "P706": "located on the specified landform", 80 | "P710": "person, group of people or organization (object) that actively takes/took part in an event or process (subject)", 81 | "P737": "this person, idea, etc. is informed by that other person, idea, etc., e.g. 
\"Heidegger was influenced by Aristotle\".", 82 | "P740": "location where a group or organization was formed", 83 | "P749": "parent organization of an organization, opposite of subsidiaries", 84 | "P800": "notable scientific, artistic or literary work, or other work of significance among subject's works", 85 | "P807": "subject was founded or started by separating from identified object", 86 | "P840": "the narrative of the work is set in this location", 87 | "P937": "location where persons were active", 88 | "P1001": "the item (an institution, law, public office ...) or statement belongs to or has power over or applies to the value (a territorial jurisdiction: a country, state, municipality, ...)", 89 | "P1056": "material or product produced by a government agency, business, industry, facility, or process", 90 | "P1198": "portion of a workforce population that is not employed", 91 | "P1336": "administrative divisions that claim control of a given area", 92 | "P1344": "event a person or an organization was/is a participant in", 93 | "P1365": "person or item replaced", 94 | "P1366": "other person or item which continues the item by replacing it in its role", 95 | "P1376": "country, state, department, canton or other administrative division of which the municipality is the governmental seat", 96 | "P1412": "language(s) that a person speaks, writes or signs, including the native language(s)", 97 | "P1441": "this (fictional) entity or person is present in the story of that work", 98 | "P3373": "the subject has the object as their sibling (brother, sister, etc.)" 99 | } -------------------------------------------------------------------------------- /models/prism.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class PRISM(nn.Module): # relation classifier: head/tail bilinear logits plus pair-vs-relation-embedding similarity logits 9 | 10 | def __init__(self, args, config, relation_features, encoder1, encoder2): # encoder2 is only used when args.share_params is false 11 |
super().__init__() 12 | self.args = args 13 | self.config = config 14 | self.encoder1 = encoder1 15 | self.encoder2 = encoder2 16 | self.relation_features = relation_features # presumably the tokenized relation descriptions; forward reads its "input_ids" and "attention_mask" 17 | args.embedding_size = config.hidden_size # 768 for BERT, 1024 for RoBERTa; NOTE: this writes into the caller's args namespace 18 | 19 | self.head_extractor = nn.Linear(config.hidden_size, args.embedding_size) 20 | self.tail_extractor = nn.Linear(config.hidden_size, args.embedding_size) 21 | self.pair_extractor = nn.Linear(2 * args.embedding_size, config.hidden_size) 22 | self.loss_fnt = nn.BCEWithLogitsLoss(reduction="mean") # multi-label: an independent sigmoid per relation class 23 | 24 | if args.group_bilinear: 25 | self.bilinear = nn.Linear(args.embedding_size * args.block_size, args.num_labels) # grouped bilinear; embedding_size must be divisible by block_size (see __compute_ht_scores), not validated here 26 | else: 27 | self.bilinear = nn.Bilinear(args.embedding_size, args.embedding_size, args.num_labels) 28 | 29 | 30 | def __process_long_input(self, input_ids, attention_mask): # encodes inputs longer than max_seq_length as independent fixed-size chunks (raw slices: no added special tokens, no overlap) and re-joins their hidden states 31 | N, T = input_ids.shape 32 | new_input_ids, new_attention_mask = [], [] 33 | num_chunks = math.ceil(T / self.args.max_seq_length) 34 | # Split into chunks for training 35 | for i in range(N): 36 | for c in range(num_chunks): 37 | start = c * self.args.max_seq_length 38 | end = (c+1) * self.args.max_seq_length 39 | # pad the last chunk 40 | if len(input_ids[i, start:end]) < self.args.max_seq_length: 41 | to_pad = self.args.max_seq_length - len(input_ids[i, start:end]) 42 | new_input_ids.append(F.pad(input_ids[i, start:end], (0, to_pad), value=self.config.pad_token_id)) 43 | new_attention_mask.append(F.pad(attention_mask[i, start:end], (0, to_pad), value=0)) 44 | else: 45 | new_input_ids.append(input_ids[i, start:end]) 46 | new_attention_mask.append(attention_mask[i, start:end]) 47 | 48 | 
input_ids = torch.stack(new_input_ids, dim=0) 49 | attention_mask = torch.stack(new_attention_mask, dim=0) 50 | out = self.encoder1(input_ids, attention_mask, output_attentions=True) 51 | out.last_hidden_state = out.last_hidden_state.reshape(N, num_chunks * self.args.max_seq_length, -1) # only last_hidden_state is re-joined; out.attentions stay per-chunk 52 | return out 53 | 54 | 55 | def
__get_entity_embeddings(self, output, ent_pos, ent_pairs): # pools mention embeddings into one vector per entity, then gathers head/tail vectors for every pair in ent_pairs 56 | offset = 1 # cls token shifts ent_pos by 1 57 | last_hidden = output.last_hidden_state 58 | N, T, D = last_hidden.shape 59 | head_embeddings, tail_embeddings = [], [] 60 | 61 | for i in range(N): 62 | ent_embs = [] 63 | for ent in ent_pos[i]: 64 | if len(ent) == 0: 65 | e_emb = last_hidden.new_zeros(D) # entity without mentions -> zero vector 66 | elif len(ent) > 1: # more than 1 mention 67 | e_emb = [] 68 | for start, end in ent: 69 | if start + offset < T: # In case the entity mention is truncated due to limited max seq length 70 | if self.args.mark_entities: 71 | e_emb.append(last_hidden[i, start+offset]) 72 | else: # max-pool all token embeddings to represent entity embedding 73 | m_emb, _ = torch.max(last_hidden[i, start+offset:end+offset], dim=0) # NOTE(review): assumes end > start; an empty span would make torch.max raise 74 | e_emb.append(m_emb) 75 | 76 | if len(e_emb) > 0: 77 | e_emb = torch.stack(e_emb, dim=0) 78 | if self.args.ent_pooler == "logsumexp": 79 | e_emb = torch.logsumexp(e_emb, dim=0) 80 | elif self.args.ent_pooler == "max": 81 | e_emb, _ = torch.max(e_emb, dim=0) 82 | elif self.args.ent_pooler == "sum": 83 | e_emb = torch.sum(e_emb, dim=0) 84 | elif self.args.ent_pooler == "avg": 85 | e_emb = torch.mean(e_emb, dim=0) 86 | else: 87 | raise ValueError("Supported pooling operations: logsumexp, max, sum, avg") 88 | else: 89 | e_emb = last_hidden.new_zeros(D) # every mention truncated -> zero vector 90 | else: 91 | start, end = ent[0] 92 | if start + offset < T: 93 | if self.args.mark_entities: 94 | e_emb = last_hidden[i, start+offset] 95 | else: # max-pool all token embeddings to represent entity embedding 96 | e_emb, _ = torch.max(last_hidden[i, start+offset:end+offset], dim=0) 97 | else: 98 | e_emb = last_hidden.new_zeros(D) # single mention truncated -> zero vector 99 | 100 | ent_embs.append(e_emb) 101 | 102 | ent_embs = torch.stack(ent_embs, dim=0) # (num_ents, D) 103 | 104 | # Get embeddings of 
all possible entity pairs 105 | ent_pairs_i = torch.tensor(ent_pairs[i], dtype=torch.long, device=last_hidden.device) # indexed as (num_pairs, 2): column 0 = head entity index, column 1 = tail entity index 106 | head_embs = torch.index_select(ent_embs, dim=0,
index=ent_pairs_i[:, 0]) 107 | tail_embs = torch.index_select(ent_embs, dim=0, index=ent_pairs_i[:, 1]) 108 | 109 | head_embeddings.append(head_embs) 110 | tail_embeddings.append(tail_embs) 111 | 112 | head_embeddings = torch.cat(head_embeddings, dim=0) 113 | tail_embeddings = torch.cat(tail_embeddings, dim=0) 114 | 115 | return head_embeddings, tail_embeddings 116 | 117 | 118 | def __get_relation_embeddings(self, r_out): # one vector per relation: the encoder's pooler output or its first-token hidden state 119 | if self.args.rel_pooler == "pooler": 120 | return r_out.pooler_output 121 | elif self.args.rel_pooler == "cls": 122 | return r_out.last_hidden_state[:,0,:] 123 | else: 124 | raise ValueError("Supported pooling operations: pooler, cls.") 125 | 126 | 127 | def __compute_ht_scores(self, head_embs, tail_embs): # head/tail bilinear logits, one column per label 128 | if self.args.group_bilinear: 129 | b1 = head_embs.view(-1, self.args.embedding_size // self.args.block_size, self.args.block_size).unsqueeze(3) # per-group outer products over embedding_size // block_size groups of block_size dims 130 | b2 = tail_embs.view(-1, self.args.embedding_size // self.args.block_size, self.args.block_size).unsqueeze(2) 131 | bl = (b1 * b2).view(-1, self.args.embedding_size * self.args.block_size) 132 | scores = self.bilinear(bl) 133 | else: 134 | scores = self.bilinear(head_embs, tail_embs) 135 | return scores 136 | 137 | 138 | def __compute_pr_scores(self, pair_embs, rel_embs): # returns (dot-product scores, cosine-similarity scores) between every pair and every relation 139 | scores = pair_embs @ rel_embs.T 140 | normalized_pair_embs = F.normalize(pair_embs, p=2, dim=-1) 141 | normalized_rel_embs = F.normalize(rel_embs, p=2, dim=-1) 142 | normalized_scores = normalized_pair_embs @ normalized_rel_embs.T 143 | return scores, normalized_scores 144 | 145 | 146 | def forward( # returns (probs,) without labels, or (loss, probs, flattened labels) with labels 147 | self, 148 | input_ids=None, 149 | attention_mask=None, 150 | ent_pos=None, 151 | ent_pairs=None, 152 | labels=None, 153 | ): 154 | # 
multi-label classification 155 | N, T = input_ids.shape 156 | if self.args.long_seq and T > self.args.max_seq_length: 157 | out = self.__process_long_input(input_ids, attention_mask) 158 | else: 159 | out = self.encoder1(input_ids, attention_mask, output_attentions=True) 160 | 161 |
h_embs, t_embs = self.__get_entity_embeddings(out, ent_pos, ent_pairs) 162 | h_embs = torch.tanh(self.head_extractor(h_embs)) 163 | t_embs = torch.tanh(self.tail_extractor(t_embs)) 164 | logits = self.__compute_ht_scores(h_embs, t_embs) 165 | 166 | # pair-relation similarity 167 | if self.args.share_params: # relation features are re-encoded on every forward pass; .to(input_ids) matches input_ids' device and dtype 168 | r_out = self.encoder1(self.relation_features["input_ids"].to(input_ids), 169 | self.relation_features["attention_mask"].to(input_ids)) 170 | else: 171 | r_out = self.encoder2(self.relation_features["input_ids"].to(input_ids), 172 | self.relation_features["attention_mask"].to(input_ids)) 173 | 174 | r_embs = self.__get_relation_embeddings(r_out) 175 | p_embs = torch.cat([h_embs, t_embs], dim=-1) 176 | p_embs = torch.tanh(self.pair_extractor(p_embs)) 177 | pr_logits, normalized_pr_logits = self.__compute_pr_scores(p_embs, r_embs) 178 | 179 | if self.args.dist_fn == "inner": 180 | logits = logits + pr_logits 181 | elif self.args.dist_fn == "cosine": 182 | logits = logits + (normalized_pr_logits / self.args.temperature) # NOTE(review): any other dist_fn value silently drops the pair-relation term; confirm this is intended 183 | 184 | model_output = (torch.sigmoid(logits),) # probabilities, not raw logits 185 | 186 | if labels is not None: 187 | labels = [torch.tensor(label) for label in labels] 188 | labels = torch.cat(labels, dim=0).to(logits) # per-document label lists concatenated to line up with the flattened pair logits 189 | loss = self.loss_fnt(logits, labels) 190 | model_output = (loss,) + model_output + (labels,) 191 | 192 | return model_output 193 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | os.environ["OMP_NUM_THREADS"] = "1" # limit OpenMP to a single thread; set before torch is imported so the runtime picks it up 4 | import os.path as osp 5 | import time 6 | 7 | import torch 8 | 9 | from tqdm import tqdm 10 | import wandb 11 | 12 | import 
config 13 | from dataset import load_data 14 | from models.utils import load_config, load_tokenizer, load_model 15 | from logger import FileLogger 16 | from evaluation import * 17 | from utils import * 18 | 19 | 20 |
class Trainer:
    """Train PRISM and keep the checkpoint with the best validation F1.

    Relies on two module-level globals created in the ``__main__`` section below:
    ``args`` (the parsed command-line namespace) and ``log`` (a ``FileLogger``).
    """

    def __init__(self):
        """Build tokenizer, loaders, model, optimizer state; optionally resume.

        Side effect: sets ``args.total_steps`` and ``args.warmup_steps``.
        NOTE(review): these are presumably consumed by ``init_scheduler``; confirm.
        """
        ### Load config / tokenizer / model ###
        self.config = load_config(args)
        self.tokenizer = load_tokenizer(args)

        ### Load data ###
        self.train_loader, _ = load_data(args, self.config, self.tokenizer, split="train")
        self.valid_loader, self.valid_features = load_data(args, self.config, self.tokenizer, split="dev")

        self.model = load_model(args, self.config, self.tokenizer)

        ### Calculate steps ###
        args.total_steps = int(len(self.train_loader) * args.epochs // args.gradient_accumulation_steps)
        args.warmup_steps = int(args.total_steps * args.warmup_ratio)
        log.console(f"warmup steps: {args.warmup_steps}, total steps: {args.total_steps}")

        ### scaler / optimizer / scheduler ###
        self.scaler = init_scaler(args)
        self.optimizer = init_optimizer(args, self.model)
        self.scheduler = init_scheduler(args, self.optimizer)

        self.best_valid_loss = float("inf")
        self.best_valid_f1 = float("-inf")
        self.start_epoch = 0
        self.tolerance = 0
        self.global_step = 0

        ### Resume training ###
        ckpt_model_path = osp.join(args.train_output_dir, "best_valid_f1.pt")
        if args.resume and osp.exists(ckpt_model_path):
            log.console(f"Loading model checkpoint from {ckpt_model_path}...")
            # map_location lets a checkpoint written on one device be resumed on another.
            ckpt = torch.load(ckpt_model_path, map_location=args.device)
            self.best_valid_loss = ckpt["loss"]
            self.best_valid_f1 = ckpt["f1"]
            # `epoch` is the 0-based index of the epoch that produced the checkpoint.
            # That epoch has already finished, so training resumes with the next one.
            self.start_epoch = ckpt["epoch"] + 1
            self.global_step = ckpt["steps"]
            self.model.load_state_dict(ckpt['model_state_dict'])
            self.optimizer.load_state_dict(ckpt['optimizer_state_dict'])
            self.scheduler.load_state_dict(ckpt['scheduler_state_dict'])
            log.console(f"Validation loss was {ckpt['loss']:.4f}")
            log.console(f"Validation F1 was {ckpt['f1']:.4f}")
        else:
            log.console(f"Training model from scratch")


    def train(self):
        """Run epochs from ``start_epoch`` to ``args.epochs``, validating after each.

        Saves ``best_valid_f1.pt`` and ``hyparams.txt`` whenever validation F1
        improves. Stops early once F1 has failed to improve for
        ``args.max_tolerance`` consecutive epochs.
        """
        for epoch in range(self.start_epoch, args.epochs):
            avg_train_loss = self.__epoch_train(epoch)
            avg_valid_loss, score_dict = self.__epoch_valid()

            log.console(f"epoch: {epoch+1}, " +
                        f"steps: {self.global_step}, " +
                        f"current lr: {self.optimizer.param_groups[0]['lr']:.8f}, " +
                        f"train loss: {avg_train_loss:.4f}, " +
                        f"valid loss: {avg_valid_loss:.4f}, " +
                        f"best theta: {score_dict['theta']}")
            log.console(f"P ({score_dict['num_matches']}/{score_dict['num_preds']}): {score_dict['P']:.5f}, " +
                        f"R ({score_dict['num_matches']}/{score_dict['num_labels']}): {score_dict['R']:.5f}, " +
                        f"F1: {score_dict['F1']:.5f}")

            if args.wandb_on:
                wandb.log({"Train Loss": avg_train_loss, "Validation Loss": avg_valid_loss,
                           "Precision": score_dict['P'], "Recall": score_dict['R'], "F1": score_dict['F1']})

            if score_dict["F1"] > self.best_valid_f1:
                self.tolerance = 0
                self.best_valid_f1 = score_dict["F1"]
                log.console(f"Saving best valid F1 checkpoint to {args.train_output_dir}...")
                torch.save({'epoch': epoch,
                            'steps': self.global_step,
                            'loss': avg_valid_loss,
                            'p': score_dict['P'],
                            'r': score_dict['R'],
                            'f1': score_dict['F1'],
                            'theta': score_dict['theta'],
                            'model_state_dict': self.model.state_dict(),
                            'optimizer_state_dict': self.optimizer.state_dict(),
                            'scheduler_state_dict': self.scheduler.state_dict()
                            }, osp.join(args.train_output_dir, "best_valid_f1.pt"))
                with open(osp.join(args.train_output_dir, "hyparams.txt"), "w") as f:
                    f.write(f"Epoch: {epoch}\n" +
                            f"Total Steps: {self.global_step}\n" +
                            f"Train Loss: {avg_train_loss}\n" +
                            f"Valid Loss: {avg_valid_loss}\n" +
                            f"Theta: {score_dict['theta']}\n" +
                            f"Precision: {score_dict['P']}\n" +
                            f"Recall: {score_dict['R']}\n" +
                            f"F1: {score_dict['F1']}")
            else:
                self.tolerance += 1
                log.console(f"F1 did not improve, patience: {self.tolerance}/{args.max_tolerance}")

            if self.tolerance == args.max_tolerance: break


    def __epoch_train(self, epoch):
        """Run one pass over the training loader and return the mean micro-batch loss.

        Gradients from every micro-batch are accumulated. The optimizer and scheduler
        step once per ``args.gradient_accumulation_steps`` micro-batches.
        """
        self.model.train()
        train_loss = 0.
        total = len(self.train_loader)
        accum_steps = args.gradient_accumulation_steps
        # Start from clean gradients. An incomplete accumulation window left over
        # from the previous epoch is discarded.
        self.optimizer.zero_grad()

        with tqdm(desc="Training", total=total, ncols=100, disable=args.hide_tqdm) as pbar:
            for step, inputs in enumerate(self.train_loader, 1):
                inputs["input_ids"] = inputs["input_ids"].to(args.device)
                inputs["attention_mask"] = inputs["attention_mask"].to(args.device)

                ### Forward pass ###
                with torch.cuda.amp.autocast(enabled=args.use_amp):
                    loss, _, _ = self.model(**inputs)

                # Record the unscaled loss so the reported value does not depend on
                # the accumulation factor (and stays comparable to the valid loss).
                train_loss += loss.item()
                if accum_steps > 1:
                    loss = loss / accum_steps

                ### Backward pass ###
                # Every micro-batch must contribute gradients. Previously backward()
                # ran only on the step that also updated the weights.
                self.scaler.scale(loss).backward()

                if step % accum_steps == 0:
                    self.scaler.unscale_(self.optimizer)
                    if args.max_grad_norm > 0:
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), args.max_grad_norm)
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                    self.scheduler.step()
                    self.optimizer.zero_grad()
                    self.global_step += 1

                    if self.global_step == 1 or self.global_step % args.logging_steps == 0:
                        log.console(f"epoch: {epoch+1}, " +
                                    f"steps: {self.global_step}, " +
                                    f"current lr: {self.optimizer.param_groups[0]['lr']:.8f}, " +
                                    f"train loss: {(train_loss / step):.4f}")

                pbar.update(1)
                del loss

        return train_loss / total


    @torch.no_grad()
    def __epoch_valid(self):
        """Evaluate on the dev loader and return ``(mean loss, unofficial score dict)``."""
        self.model.eval()
        valid_loss = 0.
        total = len(self.valid_loader)
        preds, labels = [], []

        with tqdm(desc="Evaluating", total=total, ncols=100, disable=args.hide_tqdm) as pbar:
            for step, inputs in enumerate(self.valid_loader, 1):
                inputs["input_ids"] = inputs["input_ids"].to(args.device)
                inputs["attention_mask"] = inputs["attention_mask"].to(args.device)

                ### Forward pass ###
                outputs = self.model(**inputs)
                loss, pred, label = outputs
                preds.append(pred)
                labels.append(label)
                valid_loss += loss.item()

                pbar.update(1)
                del outputs

        # Remove "no relation" label (idx=0) b/c it was a "fake" label => should not be counted in F1
        preds = torch.cat(preds, dim=0)[:,1:]
        labels = torch.cat(labels, dim=0)[:,1:]

        score_dict = unofficial_evaluate(preds, labels, dataset_name=args.dataset_name)

        return valid_loss / total, score_dict


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="train.py")
    config.model_args(parser)
    config.data_args(parser)
    config.train_args(parser)
    args = parser.parse_args()
    args.n_gpu = torch.cuda.device_count()
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    with open(osp.join(args.data_dir, "label_map.json"), "r") as f:
        label_map = json.load(f)
    args.num_labels = len(label_map)

    os.makedirs(args.train_output_dir, exist_ok=True)
    os.makedirs(args.cache_dir, exist_ok=True)

    log = FileLogger(args.train_output_dir, is_master=True, is_rank0=True, log_to_file=args.log_to_file)
    log.console(args)
    if args.wandb_on:
        project_name = f"PRiSM-{args.dataset_name}"
        run_name = "/".join(args.train_output_dir.split("/")[2:])
        wandb.init(project=project_name, name=run_name)

    set_seed(args.seed)

    trainer = Trainer()
    start_time = time.time()
    trainer.train()
    log.console(f"Time for training: {time.time() - start_time:.1f} seconds")

# --------------------------------------------------------------------------------
# /evaluation.py
# --------------------------------------------------------------------------------
import os
import json
import numpy as np
import torch

from tqdm import tqdm


def _load_json(path):
    """Parse the JSON file at ``path``, closing the file handle afterwards."""
    with open(path, "r") as f:
        return json.load(f)


# NOTE(review): these lookups are always built from DocRED's files, whatever
# dataset is being evaluated. to_official() relies on id2rel.
rel2id = _load_json('data/DocRED/label_map.json')
id2rel = {value: key for key, value in rel2id.items()}
rel2name = _load_json('data/DocRED/rel_info.json')


def to_official(preds, features, threshold):
    """Convert pair-level predictions into the DocRED submission format.

    Args:
        preds: 2-D tensor with one row per entity pair, in the order the pairs
            appear across ``features``, and one column per relation. The
            "no relation" column is already removed (column ``p`` maps to label
            id ``p + 1``).
        features: dataset features, each providing ``ent_pairs`` and ``title``.
        threshold: score cut-off. ``-1`` means every non-zero entry of ``preds``
            is kept.

    Returns:
        A list of ``{"title", "h_idx", "t_idx", "r"}`` dicts.
    """
    h_idx, t_idx, title = [], [], []

    for f in features:
        pairs = f["ent_pairs"]
        h_idx += [pair[0] for pair in pairs]
        t_idx += [pair[1] for pair in pairs]
        title += [f["title"] for _ in pairs]

    res = []
    with tqdm(desc="Converting to submission format", total=len(preds), ncols=100) as pbar:
        for i in range(preds.shape[0]):
            pred = preds[i]
            if threshold == -1:  # assume ATLOP-style 0/1 predictions if threshold is -1
                pred = torch.nonzero(pred, as_tuple=True)[0].tolist()
            else:
                pred = (pred >= threshold).nonzero(as_tuple=True)[0].tolist()
            for p in pred:
                res.append({
                    'title': title[i],
                    'h_idx': h_idx[i],
                    't_idx': t_idx[i],
                    'r': id2rel[p+1],  # need to skip "no relation" label
                })
            pbar.update(1)

    return res


def gen_train_facts(data_file_name, truth_dir):
    """Return the set of ``(head_name, tail_name, relation)`` facts in a training file.

    The set is cached as a ``.fact`` JSON file in ``truth_dir``. Later calls read
    it back from there.
    """
    fact_file_name = data_file_name[data_file_name.find("train"):]
    fact_file_name = os.path.join(truth_dir, fact_file_name.replace(".json", ".fact"))

    if os.path.exists(fact_file_name):
        return {tuple(x) for x in _load_json(fact_file_name)}

    fact_in_train = set()
    for data in _load_json(data_file_name):
        vertexSet = data['vertexSet']
        for label in data['labels']:
            rel = label['r']
            for n1 in vertexSet[label['h']]:
                for n2 in vertexSet[label['t']]:
                    fact_in_train.add((n1['name'], n2['name'], rel))

    with open(fact_file_name, "w") as f:
        json.dump(list(fact_in_train), f)

    return fact_in_train


def official_evaluate(tmp, data_dir, split):
    """Score submission-format predictions with the official DocRED metrics.

    Args:
        tmp: list of prediction dicts (see ``to_official``). It is sorted in place.
        data_dir: directory with ``train.json``, ``train_distant.json`` and
            ``{split}.json``. A ``ref`` sub-directory is created for cached facts.
        split: name of the gold file to score against.

    Returns:
        ``(re_f1, evi_f1, re_f1_ignore_train_annotated, re_f1_ignore_train)``.
        All four are 0 when ``tmp`` is empty.
    """
    truth_dir = os.path.join(data_dir, 'ref')
    os.makedirs(truth_dir, exist_ok=True)

    fact_in_train_annotated = gen_train_facts(os.path.join(data_dir, "train.json"), truth_dir)
    fact_in_train_distant = gen_train_facts(os.path.join(data_dir, "train_distant.json"), truth_dir)

    truth = _load_json(os.path.join(data_dir, f"{split}.json"))

    std = {}
    tot_evidences = 0
    titleset = set([])

    title2vertexSet = {}

    for x in truth:
        title = x['title']
        titleset.add(title)

        vertexSet = x['vertexSet']
        title2vertexSet[title] = vertexSet

        for label in x['labels']:
            r = label['r']
            h_idx = label['h']
            t_idx = label['t']
            std[(title, r, h_idx, t_idx)] = set(label['evidence'])
            tot_evidences += len(label['evidence'])

    tot_relations = len(std)

    if not tmp:
        # No predictions: precision is undefined and nothing can match, so every
        # score is 0 (the code below would raise IndexError on tmp[0]).
        return 0, 0, 0, 0

    tmp.sort(key=lambda x: (x['title'], x['h_idx'], x['t_idx'], x['r']))
    submission_answer = [tmp[0]]
    for i in range(1, len(tmp)):
        x = tmp[i]
        y = tmp[i - 1]
        if (x['title'], x['h_idx'], x['t_idx'], x['r']) != (y['title'], y['h_idx'], y['t_idx'], y['r']):
            submission_answer.append(tmp[i])

    correct_re = 0
    correct_evidence = 0
    pred_evi = 0

    correct_in_train_annotated = 0
    correct_in_train_distant = 0
    titleset2 = set([])

    # tqdm's first positional argument is the iterable, so the label must go in `desc`.
    with tqdm(desc="Calculating official scores", total=len(submission_answer), ncols=100) as pbar:
        for x in submission_answer:
            title = x['title']
            h_idx = x['h_idx']
            t_idx = x['t_idx']
            r = x['r']
            titleset2.add(title)
            if title not in title2vertexSet:
                continue
            vertexSet = title2vertexSet[title]

            if 'evidence' in x:
                evi = set(x['evidence'])
            else:
                evi = set([])
            pred_evi += len(evi)

            if (title, r, h_idx, t_idx) in std:
                correct_re += 1
                stdevi = std[(title, r, h_idx, t_idx)]
                correct_evidence += len(stdevi & evi)
                in_train_annotated = in_train_distant = False
                for n1 in vertexSet[h_idx]:
                    for n2 in vertexSet[t_idx]:
                        if (n1['name'], n2['name'], r) in fact_in_train_annotated:
                            in_train_annotated = True
                        if (n1['name'], n2['name'], r) in fact_in_train_distant:
                            in_train_distant = True

                if in_train_annotated:
                    correct_in_train_annotated += 1
                if in_train_distant:
                    correct_in_train_distant += 1
            pbar.update(1)

    re_p = 1.0 * correct_re / len(submission_answer)
    re_r = 1.0 * correct_re / tot_relations if tot_relations > 0 else 0
    if re_p + re_r == 0:
        re_f1 = 0
    else:
        re_f1 = 2.0 * re_p * re_r / (re_p + re_r)

    evi_p = 1.0 * correct_evidence / pred_evi if pred_evi > 0 else 0
    evi_r = 1.0 * correct_evidence / tot_evidences if tot_evidences > 0 else 0
    if evi_p + evi_r == 0:
        evi_f1 = 0
    else:
        evi_f1 = 2.0 * evi_p * evi_r / (evi_p + evi_r)

    re_p_ignore_train_annotated = 1.0 * (correct_re - correct_in_train_annotated) / (len(submission_answer) - correct_in_train_annotated + 1e-5)
    re_p_ignore_train = 1.0 * (correct_re - correct_in_train_distant) / (len(submission_answer) - correct_in_train_distant + 1e-5)

    if re_p_ignore_train_annotated + re_r == 0:
        re_f1_ignore_train_annotated = 0
    else:
        re_f1_ignore_train_annotated = 2.0 * re_p_ignore_train_annotated * re_r / (re_p_ignore_train_annotated + re_r)

    if re_p_ignore_train + re_r == 0:
        re_f1_ignore_train = 0
    else:
        re_f1_ignore_train = 2.0 * re_p_ignore_train * re_r / (re_p_ignore_train + re_r)

    return re_f1, evi_f1, re_f1_ignore_train_annotated, re_f1_ignore_train


def unofficial_evaluate(preds, labels, dataset_name="DocRED"):
    """Compute micro and macro scores at the threshold that maximises micro F1.

    Args:
        preds: 2-D score tensor (rows: entity pairs, columns: relations).
        labels: 0/1 tensor of the same shape.
        dataset_name: "DocRED", "Re-DocRED" or "DWIE". It selects the hard-coded
            per-class frequency table used for the frequency-bucketed macro F1
            (classes with < 500 / < 200 / < 100 occurrences).

    Returns:
        A dict with P, R, F1, macro F1 variants, per-class F1, counts and ``theta``
        (the chosen threshold).

    Raises:
        ValueError: if ``dataset_name`` has no frequency table.
    """
    score_dict = {}
    best_theta = -1

    # need to find optimal threshold where f1 is maximized
    sorted_logits, sorted_idxes = preds.flatten().sort(descending=True)
    sorted_labels = torch.gather(labels.flatten(), dim=0, index=sorted_idxes)
    # Treating the top-k scores as positive for every k: the cumulative sums give
    # precision/recall at every candidate threshold in a single pass.
    predictions = torch.ones_like(sorted_logits)
    num_preds = predictions.cumsum(0)
    num_labels = labels.sum()
    num_matches = (predictions * sorted_labels).cumsum(0)
    precisions = num_matches / num_preds
    recalls = num_matches / num_labels
    f1s = 2 * precisions * recalls / (precisions + recalls + 1e-20)

    f1, best_f1_pos = f1s.max(0)
    precision = precisions[best_f1_pos]
    recall = recalls[best_f1_pos]
    num_matches = num_matches[best_f1_pos]
    num_preds = num_preds[best_f1_pos]
    best_theta = sorted_logits[best_f1_pos].item()

    num_preds_per_class = (preds >= best_theta).sum(0)
    num_matches_per_class = (labels * (preds >= best_theta)).sum(0)
    num_labels_per_class = labels.sum(0)

    # Calculate macro F1
    precision_per_class = num_matches_per_class / (num_preds_per_class + 1e-20)
    recall_per_class = num_matches_per_class / (num_labels_per_class + 1e-20)
    f1_per_class = 2 * precision_per_class * recall_per_class / (precision_per_class + recall_per_class + 1e-20)
    macro_f1 = f1_per_class.mean()

    # class frequency
    if dataset_name == "DocRED":
        label_freq = torch.tensor([264, 8921, 4193, 2004, 2689, 1044, 511, 79, 475, 79, 275, 356, 172, 76, 194, 539, 35, 583, 632, 414, 1052, 1142, 621, 95, 203, 316, 805, 196, 173, 210, 596, 85, 303, 74, 273, 360, 119, 155, 150, 238, 304, 104, 406, 96, 62, 335, 298, 246, 156, 82, 188, 192, 166, 108, 208, 185, 23, 163, 144, 299, 231, 152, 79, 63, 223, 110, 51, 36, 379, 320, 48, 111, 85, 137, 119, 191, 140, 144, 33, 66, 9, 77, 103, 95, 100, 172, 83, 92, 92, 2, 75, 36, 36, 18, 2, 4]).to(f1_per_class)
    elif dataset_name == "Re-DocRED":
        label_freq = torch.tensor([263, 14401, 20402, 3369, 4665, 1172, 692, 155, 868, 181, 575, 761, 336, 178, 431, 948, 66, 923, 2313, 1299, 1773, 1621, 919, 200, 281, 503, 1000, 421, 340, 368, 2112, 178, 640, 168, 466, 703, 281, 366, 3055, 402, 460, 204, 403, 191, 102, 712, 1207, 341, 237, 152, 506, 506, 305, 191, 389, 356, 49, 370, 245, 669, 410, 264, 171, 145, 1168, 222, 105, 79, 379, 489, 83, 239, 174, 293, 249, 1168, 292, 357, 59, 107, 22, 152, 225, 192, 204, 298, 144, 230, 230, 2, 117, 65, 96, 96, 8, 8]).to(f1_per_class)
    elif dataset_name == "DWIE":
        label_freq = torch.tensor([83, 133, 470, 1403, 751, 1572, 1518, 307, 291, 211, 193, 1255, 2005, 1597, 137, 184, 170, 1703, 1206, 158, 361, 326, 68, 5, 11, 99, 18, 242, 253, 51, 57, 367, 32, 123, 30, 4, 21, 126, 87, 16, 43, 25, 7, 27, 6, 16, 16, 16, 11, 43, 30, 12, 7, 5, 9, 2, 2, 3, 0, 2, 1, 0, 0, 1, 1]).to(f1_per_class)
    else:
        # Previously fell through to a NameError on `label_freq`.
        raise ValueError(f"No class-frequency table for dataset '{dataset_name}'. "
                         "Expected DocRED, Re-DocRED, or DWIE.")

    macro_f1_at_500 = f1_per_class[label_freq < 500].mean()
    macro_f1_at_200 = f1_per_class[label_freq < 200].mean()
    macro_f1_at_100 = f1_per_class[label_freq < 100].mean()

    score_dict["P"] = precision.item()
    score_dict["R"] = recall.item()
    score_dict["F1"] = f1.item()
    score_dict["macro_F1"] = macro_f1.item()
    score_dict["macro_F1_at_500"] = macro_f1_at_500.item()
    score_dict["macro_F1_at_200"] = macro_f1_at_200.item()
    score_dict["macro_F1_at_100"] = macro_f1_at_100.item()
    score_dict["F1_per_class"] = f1_per_class.tolist()
    score_dict["num_matches"] = num_matches.long().item()
    score_dict["num_preds"] = num_preds.long().item()
    score_dict["num_labels"] = num_labels.long().item()
    score_dict["theta"] = best_theta

    return score_dict


def calibrate(logits, labels, preds=None):
    """Compute calibration errors for multi-label probability outputs.

    ECE uses 10 equal-width bins over all flattened scores. ACE is the per-class
    mean gap over 10 equal-mass (percentile) bins, averaged over classes. When
    ``preds`` is given, it replaces ``logits`` as the predicted value inside the
    ACE bins (the binning itself still uses ``logits``).

    Returns:
        ``(ece, ace, prob_true, prob_pred)``, where the last two are the
        reliability-diagram points of the ECE bins.
    """
    _logits = logits.flatten().cpu().numpy()
    _labels = labels.flatten().cpu().numpy()

    N = len(_logits)  # total sample size
    _, num_labels = logits.shape
    n_bins = 10
    bins = np.linspace(0.0, 1.0, n_bins + 1)

    # ECE & reliability diagram for ALL
    binids = np.searchsorted(bins[1:-1], _logits)
    bin_sums = np.bincount(binids, weights=_logits, minlength=len(bins))
    bin_true = np.bincount(binids, weights=_labels, minlength=len(bins))
    bin_total = np.bincount(binids, minlength=len(bins))
    nonzero = bin_total != 0
    prob_true = bin_true[nonzero] / bin_total[nonzero]
    prob_pred = bin_sums[nonzero] / bin_total[nonzero]
    ece = ((bin_total[nonzero] / N) * abs(prob_true - prob_pred)).sum()

    # ACE
    ace_list = []
    for k in range(num_labels):
        _class_logits = logits[:, k].cpu().numpy()
        _class_labels = labels[:, k].cpu().numpy()
        if preds is not None:
            _class_preds = preds[:, k].cpu().numpy()

        even_bins = np.percentile(_class_logits, bins * 100)
        _class_binids_even = np.searchsorted(even_bins[1:-1], _class_logits)

        if preds is not None:
            _class_bin_sums_even = np.bincount(_class_binids_even, weights=_class_preds, minlength=len(even_bins))
        else:
            _class_bin_sums_even = np.bincount(_class_binids_even, weights=_class_logits, minlength=len(even_bins))
        _class_bin_true_even = np.bincount(_class_binids_even, weights=_class_labels, minlength=len(even_bins))
        _class_bin_total_even = np.bincount(_class_binids_even, minlength=len(even_bins))
        _nonzero_even = _class_bin_total_even != 0
        _class_prob_true_even = _class_bin_true_even[_nonzero_even] / _class_bin_total_even[_nonzero_even]
        _class_prob_pred_even = _class_bin_sums_even[_nonzero_even] / _class_bin_total_even[_nonzero_even]
        _class_ace_score = abs(_class_prob_true_even - _class_prob_pred_even).mean()
        ace_list.append(_class_ace_score)

    ace = sum(ace_list) / len(ace_list)

    return ece, ace, prob_true, prob_pred

# --------------------------------------------------------------------------------
# /dataset.py
# --------------------------------------------------------------------------------
import logging
import os
import os.path as osp
import json
import random
from collections import defaultdict

import torch
from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm

from utils import seed_worker


def _cache_path(save_dir, filename):
    """Return ``osp.join(save_dir, filename)`` after creating its parent directory.

    ``filename`` embeds ``args.model_name_or_path``, which may contain "/" (e.g. an
    ``org/model`` hub id). Creating only ``save_dir`` then made ``torch.save``
    fail on the missing sub-directory.
    """
    path = osp.join(save_dir, filename)
    os.makedirs(osp.dirname(path) or ".", exist_ok=True)
    return path


def _sample_like_full_distribution(features, num_train, label_freq):
    """Randomly draw ``num_train`` features whose label distribution matches ``label_freq``.

    Re-samples until every class proportion is within ``atol=1e-3`` of the
    normalised ``label_freq``. NOTE(review): there is no attempt limit, so a
    very small ``num_train`` may never satisfy the tolerance.
    """
    label_dist = torch.tensor(label_freq) / sum(label_freq)
    while True:
        sampled_features = random.sample(features, num_train)
        sampled_freq = torch.stack([torch.tensor(x["labels"]).sum(0) for x in sampled_features]).sum(0)
        sampled_dist = sampled_freq / sampled_freq.sum()
        if torch.allclose(label_dist, sampled_dist, atol=1e-03):
            return sampled_features


def _collate(samples, pad_id):
    """Right-pad ``input_ids`` to the longest sample and batch the per-document fields."""
    max_len = max(len(x["input_ids"]) for x in samples)
    input_ids = [x["input_ids"] + [pad_id] * (max_len - len(x["input_ids"])) for x in samples]
    attention_mask = [[1] * len(x["input_ids"]) + [0] * (max_len - len(x["input_ids"])) for x in samples]

    return {"input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "ent_pos": [x["ent_pos"] for x in samples],
            "ent_pairs": [x["ent_pairs"] for x in samples],
            "labels": [x["labels"] for x in samples]}


def load_data(args, config, tokenizer, split="train"):
    """Build the dataset for ``split`` and wrap it in a DataLoader.

    Returns:
        ``(dataloader, dataset.features)``.

    Raises:
        ValueError: for an unknown ``args.dataset_name`` or ``split``.
    """
    if args.dataset_name in {"DocRED", "Re-DocRED"}:
        dataset = DocREDDataset(args, config, tokenizer, split)
    elif args.dataset_name == "DWIE":
        dataset = DWIEDataset(args, config, tokenizer, split)
    else:
        raise ValueError("Dataset must be DocRED, Re-DocRED, or DWIE.")

    if split == "train":
        dataloader = DataLoader(dataset,
                                batch_size=args.train_batch_size,
                                collate_fn=dataset.collate_fn,
                                worker_init_fn=seed_worker,
                                num_workers=args.num_workers,
                                shuffle=True,
                                drop_last=True,
                                pin_memory=True)
    elif split == "dev":
        dataloader = DataLoader(dataset,
                                batch_size=args.eval_batch_size,
                                collate_fn=dataset.collate_fn,
                                shuffle=False,
                                drop_last=False,
                                pin_memory=True)
    elif split == "test":
        dataloader = DataLoader(dataset,
                                batch_size=args.test_batch_size,
                                collate_fn=dataset.collate_fn,
                                shuffle=False,
                                drop_last=False)
    else:
        raise ValueError("Data split must be either train/dev/test.")

    return dataloader, dataset.features


def load_and_cache_relations(args, config, tokenizer):
    """Tokenise every relation description, pad to a batch, and cache it.

    Returns:
        ``{"input_ids", "attention_mask"}`` long tensors with one row per label
        index in ``label_map.json``.

    Raises:
        ValueError: if ``rel_desc.json`` lacks a description for some label index.
    """
    save_dir = osp.join(args.data_dir, "cached")
    os.makedirs(save_dir, exist_ok=True)
    save_path = _cache_path(save_dir, f"{args.model_name_or_path}_reldesc{args.num_labels}.pt")

    if osp.exists(save_path):
        logging.info(f"Loading relation features from {save_path}")
        return torch.load(save_path)

    with open(osp.join(args.data_dir, f"rel_desc.json")) as f:
        relations = json.load(f)

    with open(osp.join(args.data_dir, f"label_map.json")) as f:
        label_map = json.load(f)

    relation_features = [None] * args.num_labels
    for rel_id, relation in relations.items():
        input_ids = tokenizer.encode(relation)
        attention_mask = [1] * len(input_ids)
        relation_features[label_map[rel_id]] = {"input_ids": input_ids, "attention_mask": attention_mask}

    # A missing description used to surface as an opaque TypeError while padding.
    missing = [idx for idx, feat in enumerate(relation_features) if feat is None]
    if missing:
        raise ValueError(f"rel_desc.json has no description for label indices {missing}")

    # Collate
    PAD = config.pad_token_id
    max_len = max([len(r["input_ids"]) for r in relation_features])
    input_ids = [r["input_ids"] + [PAD] * (max_len - len(r["input_ids"])) for r in relation_features]
    attention_mask = [r["attention_mask"] + [0] * (max_len - len(r["attention_mask"])) for r in relation_features]
    input_ids = torch.tensor(input_ids, dtype=torch.long)
    attention_mask = torch.tensor(attention_mask, dtype=torch.long)

    relation_features = {"input_ids": input_ids, "attention_mask": attention_mask}
    logging.info(f"Saving relation features to {save_path}")
    torch.save(relation_features, save_path)

    return relation_features


class DocREDDataset(Dataset):
    """DocRED / Re-DocRED documents converted to per-document feature dicts.

    Each feature holds ``input_ids``, per-entity mention spans ``ent_pos``, every
    ordered entity pair ``ent_pairs``, the document ``title`` and one multi-hot
    ``labels`` vector per pair (index 0 marks "no relation").
    """

    def __init__(self, args, config, tokenizer, split="train"):
        self.args = args
        self.config = config
        self.tokenizer = tokenizer
        self.split = split
        self.features = []

        self.ent_marked = "_entmarked" if args.mark_entities else ""
        self.save_dir = osp.join(args.data_dir, "cached")
        os.makedirs(self.save_dir, exist_ok=True)
        self.save_path = _cache_path(self.save_dir, f"{split}_{args.model_name_or_path}{self.ent_marked}.pt")

        self.ner_map = {'PAD':0, 'ORG':1, 'LOC':2, 'NUM':3, 'TIME':4, 'MISC':5, 'PER':6}
        with open(osp.join(args.data_dir, "label_map.json"), "r") as f:
            self.label_map = json.load(f)

        self.__load_and_cache_examples()

        # Set up resource-constrained setting
        if self.split == "train" and args.num_train_ratio < 1:
            num_train = round(len(self.features) * self.args.num_train_ratio)
            # keep random sampling until label distribution resembles that of the full data
            if args.dataset_name == "DocRED":
                label_freq = [1163035, 264, 8921, 4193, 2004, 2689, 1044, 511, 79, 475, 79, 275, 356, 172, 76, 194, 539, 35, 583, 632, 414, 1052, 1142, 621, 95, 203, 316, 805, 196, 173, 210, 596, 85, 303, 74, 273, 360, 119, 155, 150, 238, 304, 104, 406, 96, 62, 335, 298, 246, 156, 82, 188, 192, 166, 108, 208, 185, 23, 163, 144, 299, 231, 152, 79, 63, 223, 110, 51, 36, 379, 320, 48, 111, 85, 137, 119, 191, 140, 144, 33, 66, 9, 77, 103, 95, 100, 172, 83, 92, 92, 2, 75, 36, 36, 18, 2, 4]
            elif args.dataset_name == "Re-DocRED":
                label_freq = [1125284, 263, 14401, 20402, 3369, 4665, 1172, 692, 155, 868, 181, 575, 761, 336, 178, 431, 948, 66, 923, 2313, 1299, 1773, 1621, 919, 200, 281, 503, 1000, 421, 340, 368, 2112, 178, 640, 168, 466, 703, 281, 366, 3055, 402, 460, 204, 403, 191, 102, 712, 1207, 341, 237, 152, 506, 506, 305, 191, 389, 356, 49, 370, 245, 669, 410, 264, 171, 145, 1168, 222, 105, 79, 379, 489, 83, 239, 174, 293, 249, 1168, 292, 357, 59, 107, 22, 152, 225, 192, 204, 298, 144, 230, 230, 2, 117, 65, 96, 96, 8, 8]
            self.features = _sample_like_full_distribution(self.features, num_train, label_freq)


    def __load_and_cache_examples(self):
        """Fill ``self.features`` from the cache file, or by tokenising ``{split}.json``."""
        if osp.exists(self.save_path):
            logging.info(f"Loading features from {self.save_path}")
            self.features = torch.load(self.save_path)
            return

        logging.info(f"Creating features to {self.save_path}")
        with open(osp.join(self.args.data_dir, f"{self.split}.json")) as f:
            examples = json.load(f)

        num_pos_samples, num_neg_samples = 0, 0

        for ex in tqdm(examples, desc="Converting examples to features"):
            ents = ex["vertexSet"]

            # Locate start & end of entity mention for entity marking
            ent_start, ent_end = set(), set()
            if self.args.mark_entities:
                for ent in ents:
                    for ment in ent:
                        ent_start.add((ment["sent_id"], ment["pos"][0]))
                        ent_end.add((ment["sent_id"], ment["pos"][1]-1))

            # Map each word idx to subword idx
            input_tokens = []
            token_idx_map = []
            for sent_idx, sent in enumerate(ex["sents"]):
                idx_map = {}
                for word_idx, word in enumerate(sent):
                    tokens = self.tokenizer.tokenize(word)
                    if (sent_idx, word_idx) in ent_start:
                        tokens = ["*"] + tokens
                    if (sent_idx, word_idx) in ent_end:
                        tokens = tokens + ["*"]
                    idx_map[word_idx] = len(input_tokens)
                    input_tokens += tokens
                # Sentinel so an exclusive end position (pos[1]) also resolves.
                idx_map[word_idx+1] = len(input_tokens)
                token_idx_map.append(idx_map)

            input_tokens = input_tokens[:self.args.max_seq_length-2]  # truncate to max sequence length
            input_ids = self.tokenizer.convert_tokens_to_ids(input_tokens)  # convert tokens to ids
            input_ids = self.tokenizer.build_inputs_with_special_tokens(input_ids)  # add [CLS] & [SEP]

            # Locate spans of each entity mention
            ent_pos = []
            for ent in ents:
                ent_pos.append([])
                for ment in ent:
                    # Get subword idx of the entity mention
                    token_start_pos = token_idx_map[ment["sent_id"]][ment["pos"][0]]
                    token_end_pos = token_idx_map[ment["sent_id"]][ment["pos"][1]]
                    ent_pos[-1].append((token_start_pos, token_end_pos))

            ground_truth_triples = defaultdict(list)
            if ex.get("labels"):  # test file does not have "labels"
                for label in ex["labels"]:
                    rel_id = self.label_map[label["r"]]
                    ground_truth_triples[(label["h"], label["t"])].append({"relation": rel_id, "evidence": label["evidence"]})

            # Create positive pairs
            ent_pairs, rel_vectors = [], []
            for (h, t), instances in ground_truth_triples.items():
                rel_vector = [0] * len(self.label_map)
                for instance in instances:
                    rel_vector[instance["relation"]] = 1
                rel_vectors.append(rel_vector)
                ent_pairs.append((h, t))
                num_pos_samples += 1

            # Create negative pairs (set lookup instead of scanning the growing list)
            positive_pairs = set(ent_pairs)
            for h in range(len(ents)):
                for t in range(len(ents)):
                    if h != t and (h, t) not in positive_pairs:
                        rel_vector = [1] + [0] * (len(self.label_map)-1)
                        rel_vectors.append(rel_vector)
                        ent_pairs.append((h, t))
                        num_neg_samples += 1

            assert len(rel_vectors) == len(ent_pairs) == (len(ents) * (len(ents)-1))

            self.features.append({
                "input_ids": input_ids,
                "ent_pos": ent_pos,
                "ent_pairs": ent_pairs,
                "title": ex["title"],  # needed for test submission
                "labels": rel_vectors,
            })

        logging.info(f"# of documents: {len(self.features)}")
        logging.info(f"# of positive pairs {num_pos_samples}")
        logging.info(f"# of negative pairs {num_neg_samples}")
        logging.info(f"Saving features to {self.save_path}")
        torch.save(self.features, self.save_path)


    def collate_fn(self, samples):
        """Batch features, padding ``input_ids`` with the model's pad token."""
        return _collate(samples, self.config.pad_token_id)

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx]


class DWIEDataset(Dataset):
    """DWIE documents (one JSON file each) converted to DocRED-style feature dicts."""

    def __init__(self, args, config, tokenizer, split="train"):
        self.args = args
        self.config = config
        self.tokenizer = tokenizer
        self.split = split
        self.features = []

        self.ent_marked = "_entmarked" if args.mark_entities else ""
        self.long = "_long" if args.long_seq == 1 else ""
        self.save_dir = osp.join(args.data_dir, "cached")
        os.makedirs(self.save_dir, exist_ok=True)
        self.save_path = _cache_path(self.save_dir, f"{split}_{args.model_name_or_path}{self.ent_marked}{self.long}.pt")

        with open(osp.join(args.data_dir, "label_map.json"), "r") as f:
            self.label_map = json.load(f)

        self.__load_and_cache_examples()

        # Set up resource-constrained setting. Only subsample when a fraction is
        # requested, as DocREDDataset does. With ratio 1 this used to shuffle the
        # full set and could loop forever if the hard-coded frequencies drifted
        # from the cached features.
        if self.split == "train" and args.num_train_ratio < 1:
            num_train = round(len(self.features) * self.args.num_train_ratio)
            # keep random sampling until label distribution resembles that of the full data
            label_freq = [601051, 83, 133, 470, 1403, 751, 1572, 1518, 307, 291, 211, 193, 1255, 2005, 1597, 137, 184, 170, 1703, 1206, 158, 361, 326, 68, 5, 11, 99, 18, 242, 253, 51, 57, 367, 32, 123, 30, 4, 21, 126, 87, 16, 43, 25, 7, 27, 6, 16, 16, 16, 11, 43, 30, 12, 7, 5, 9, 2, 2, 3, 0, 2, 1, 0, 0, 1, 1]
            self.features = _sample_like_full_distribution(self.features, num_train, label_freq)


    def __load_and_cache_examples(self):
        """Fill ``self.features`` from the cache file, or by tokenising ``data_dir/split/*``."""
        if osp.exists(self.save_path):
            logging.info(f"Loading features from {self.save_path}")
            self.features = torch.load(self.save_path)
            return

        num_pos_samples, num_neg_samples = 0, 0

        logging.info(f"Creating features to {self.save_path}")
        split_dir = osp.join(self.args.data_dir, self.split)
        for filename in tqdm(os.listdir(split_dir), desc="Converting examples to features"):
            doc_path = osp.join(split_dir, filename)
            if not osp.isfile(doc_path):
                continue
            with open(doc_path, 'r') as f:
                ex = json.load(f)

            start = 0
            input_tokens = []
            token_idx_map = defaultdict(list)
            for ment in ex["mentions"]:
                # tokenize text up to entity mention
                end = ment["begin"]
                words = ex["content"][start:end].strip()
                before_tokens = self.tokenizer.tokenize(words)

                # tokenize entity mention
                start, end = ment["begin"], ment["end"]
                ment_word = ex["content"][start:end].strip()
                ment_tokens = self.tokenizer.tokenize(ment_word)
                if self.args.mark_entities:
                    ment_tokens = ["*"] + ment_tokens + ["*"]

                # For each entity, store the token position (start, end) of mention
                token_start_pos = len(input_tokens) + len(before_tokens)
                token_end_pos = token_start_pos + len(ment_tokens)
                token_idx_map[ment["concept"]].append((token_start_pos, token_end_pos))

                input_tokens += before_tokens + ment_tokens
                start = ment["end"]

            # Finish tokenizing the text
            after_tokens = self.tokenizer.tokenize(ex["content"][start:].strip())
            input_tokens += after_tokens
            if not self.args.long_seq:
                input_tokens = input_tokens[:self.args.max_seq_length-2]  # truncate to max sequence length
            input_ids = self.tokenizer.convert_tokens_to_ids(input_tokens)  # convert tokens to ids
            input_ids = self.tokenizer.build_inputs_with_special_tokens(input_ids)  # add [CLS] & [SEP]

            # Convert to ent_pos format (following DocRED)
            ent_pos = []
            for i in range(len(ex["concepts"])):
                if token_idx_map.get(i):
                    ent_pos.append(token_idx_map[i])
                else:  # there are annotated entities that do not exist in the document
                    ent_pos.append([])

            ground_truth_triples = defaultdict(list)
            if ex.get("relations"):
                for label in ex["relations"]:
                    rel_id = self.label_map[label["p"]]
                    # Remove relations where entities do not exist in the document
                    if len(ent_pos[label["s"]]) != 0 and len(ent_pos[label["o"]]) != 0:
                        ground_truth_triples[(label["s"], label["o"])].append(rel_id)

            # Create positive pairs
            ent_pairs, rel_vectors = [], []
            for (h, t), relations in ground_truth_triples.items():
                rel_vector = [0] * len(self.label_map)
                for r in relations:
                    rel_vector[r] = 1
                rel_vectors.append(rel_vector)
                ent_pairs.append((h, t))
                num_pos_samples += 1

            # Create negative pairs (only between entities mentioned in the text)
            positive_pairs = set(ent_pairs)
            for h in range(len(ex["concepts"])):
                for t in range(len(ex["concepts"])):
                    if h != t and (h, t) not in positive_pairs and len(ent_pos[h]) != 0 and len(ent_pos[t]) != 0:
                        rel_vector = [1] + [0] * (len(self.label_map)-1)
                        rel_vectors.append(rel_vector)
                        ent_pairs.append((h, t))
                        num_neg_samples += 1

            assert len(rel_vectors) == len(ent_pairs)

            self.features.append({
                "input_ids": input_ids,
                "ent_pos": ent_pos,
                "ent_pairs": ent_pairs,
                "labels": rel_vectors,
            })

        logging.info(f"# of documents: {len(self.features)}")
        logging.info(f"# of positive pairs {num_pos_samples}")
        logging.info(f"# of negative pairs {num_neg_samples}")
        logging.info(f"Saving features to {self.save_path}")
        torch.save(self.features, self.save_path)


    def collate_fn(self, samples):
        """Batch features, padding ``input_ids`` with the model's pad token."""
        return _collate(samples, self.config.pad_token_id)

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx]