├── arl ├── __init__.py ├── gadgets │ ├── __init__.py │ └── my_metrics.py ├── utils │ └── __init__.py ├── modules │ ├── language_encoders │ │ └── __init__.py │ ├── vision_encoders │ │ └── __init__.py │ ├── __init__.py │ ├── position_embeddings.py │ ├── prediction_heads.py │ └── dist_utils.py ├── transforms │ ├── __init__.py │ ├── utils.py │ ├── transform.py │ └── randaug.py ├── datasets │ ├── __init__.py │ ├── pretraining_roco_dataset.py │ ├── irtr_roco_dataset.py │ ├── pretraining_medicat_dataset.py │ ├── cls_melinda_dataset.py │ ├── vqa_slack_dataset.py │ ├── vqa_vqa_rad_dataset.py │ ├── vqa_medvqa_2019_dataset.py │ ├── vqa_medvqa_2021_dataset.py │ ├── utils.py │ └── pretraining_mimic_cxr_dataset.py └── datamodules │ ├── pretraining_roco_datamodule.py │ ├── irtr_roco_datamodule.py │ ├── pretraining_medicat_datamodule.py │ ├── cls_melinda_datamodule.py │ ├── pretraining_mimic_cxr_datamodule.py │ ├── __init__.py │ ├── vqa_slack_datamodule.py │ ├── vqa_vqa_rad_datamodule.py │ ├── vqa_medvqa_2019_datamodule.py │ ├── vqa_medvqa_2021_datamodule.py │ ├── multitask_datamodule.py │ └── base_datamodule.py ├── prepro ├── OpenKE │ ├── openke │ │ ├── make.sh │ │ ├── __init__.py │ │ ├── module │ │ │ ├── loss │ │ │ │ ├── Loss.py │ │ │ │ ├── __init__.py │ │ │ │ ├── SigmoidLoss.py │ │ │ │ ├── SoftplusLoss.py │ │ │ │ └── MarginLoss.py │ │ │ ├── strategy │ │ │ │ ├── Strategy.py │ │ │ │ ├── __init__.py │ │ │ │ └── NegativeSampling.py │ │ │ ├── __init__.py │ │ │ ├── model │ │ │ │ ├── Model.py │ │ │ │ ├── __init__.py │ │ │ │ ├── RESCAL.py │ │ │ │ ├── SimplE.py │ │ │ │ ├── DistMult.py │ │ │ │ ├── ComplEx.py │ │ │ │ ├── Analogy.py │ │ │ │ ├── TransE.py │ │ │ │ ├── TransR.py │ │ │ │ ├── HolE.py │ │ │ │ ├── RotatE.py │ │ │ │ ├── TransH.py │ │ │ │ └── TransD.py │ │ │ └── BaseModule.py │ │ ├── config │ │ │ ├── __init__.py │ │ │ ├── Trainer.py │ │ │ └── Tester.py │ │ └── base │ │ │ ├── Triple.h │ │ │ ├── Random.h │ │ │ ├── Setting.h │ │ │ ├── Corrupt.h │ │ │ ├── Base.cpp │ │ │ └── Reader.h │ 
├── train_transe.py │ ├── examples │ │ ├── train_rescal_FB15K237.py │ │ ├── train_hole_WN18RR.py │ │ ├── train_transe_FB15K237.py │ │ ├── train_transh_FB15K237.py │ │ ├── train_analogy_WN18RR.py │ │ ├── train_complex_WN18RR.py │ │ ├── train_simple_WN18RR.py │ │ ├── train_transd_FB15K237.py │ │ ├── train_distmult_WN18RR.py │ │ ├── train_rotate_WN18RR_adv.py │ │ ├── train_distmult_WN18RR_adv.py │ │ ├── train_transe_WN18_adv_sigmoidloss.py │ │ └── train_transr_FB15K237.py │ └── README.md └── glossary.py ├── run_scripts ├── pretrain_arl.sh └── finetune_arl.sh ├── requirements.txt ├── .gitignore ├── main.py └── README.md /arl/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /arl/gadgets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /arl/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /arl/modules/language_encoders/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /arl/modules/vision_encoders/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /arl/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .arl_module import ARLTransformerSS 2 | -------------------------------------------------------------------------------- /prepro/OpenKE/openke/make.sh: -------------------------------------------------------------------------------- 1 | mkdir release 2 | g++ ./base/Base.cpp 
-fPIC -shared -o ./release/Base.so -pthread -O3 -march=native 3 | -------------------------------------------------------------------------------- /prepro/OpenKE/openke/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function -------------------------------------------------------------------------------- /prepro/OpenKE/openke/module/loss/Loss.py: -------------------------------------------------------------------------------- 1 | from ..BaseModule import BaseModule 2 | 3 | class Loss(BaseModule): 4 | 5 | def __init__(self): 6 | super(Loss, self).__init__() -------------------------------------------------------------------------------- /prepro/OpenKE/openke/module/strategy/Strategy.py: -------------------------------------------------------------------------------- 1 | from ..BaseModule import BaseModule 2 | 3 | class Strategy(BaseModule): 4 | 5 | def __init__(self): 6 | super(Strategy, self).__init__() -------------------------------------------------------------------------------- /prepro/OpenKE/openke/module/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .BaseModule import BaseModule 6 | 7 | __init__ = { 8 | 'BaseModule' 9 | } -------------------------------------------------------------------------------- /prepro/OpenKE/openke/config/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .Trainer import Trainer 6 | from .Tester import Tester 7 | 8 | __all__ = [ 9 | 'Trainer', 10 | 'Tester' 11 | ] 12 | 
-------------------------------------------------------------------------------- /run_scripts/pretrain_arl.sh: -------------------------------------------------------------------------------- 1 | python main.py with data_root=data/pretrain_arrows_umls/ \ 2 | num_gpus=4 num_nodes=1 \ 3 | task_pretrain_arl \ 4 | per_gpu_batchsize=16 \ 5 | clip16 text_roberta \ 6 | image_size=288 max_text_len=64 max_num_ents=24 \ 7 | tokenizer=downloaded/roberta-base \ 8 | load_path=downloaded/meter.ckpt -------------------------------------------------------------------------------- /prepro/OpenKE/openke/module/strategy/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .Strategy import Strategy 6 | from .NegativeSampling import NegativeSampling 7 | 8 | __all__ = [ 9 | 'Strategy', 10 | 'NegativeSampling', 11 | ] -------------------------------------------------------------------------------- /prepro/OpenKE/openke/module/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .Loss import Loss 6 | from .MarginLoss import MarginLoss 7 | from .SoftplusLoss import SoftplusLoss 8 | from .SigmoidLoss import SigmoidLoss 9 | 10 | __all__ = [ 11 | 'Loss', 12 | 'MarginLoss', 13 | 'SoftplusLoss', 14 | 'SigmoidLoss', 15 | ] -------------------------------------------------------------------------------- /prepro/OpenKE/openke/module/model/Model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from ..BaseModule import BaseModule 4 | 5 | 6 | class Model(BaseModule): 7 | 8 | def __init__(self, ent_tot, rel_tot): 9 | super(Model, self).__init__() 10 | self.ent_tot = 
ent_tot 11 | self.rel_tot = rel_tot 12 | 13 | def forward(self): 14 | raise NotImplementedError 15 | 16 | def predict(self): 17 | raise NotImplementedError -------------------------------------------------------------------------------- /arl/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .transform import ( 2 | clip_transform, 3 | clip_transform_randaug, 4 | clip_transform_resizedcrop 5 | ) 6 | 7 | _transforms = { 8 | "clip": clip_transform, 9 | "clip_randaug": clip_transform_randaug, 10 | "clip_resizedcrop": clip_transform_resizedcrop 11 | } 12 | 13 | 14 | def keys_to_transforms(keys: list, size=224): 15 | return [_transforms[key](size=size) for key in keys] 16 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.9.0 2 | torchvision==0.10.0 3 | pytorch-lightning==1.3.2 4 | transformers==4.6.0 5 | pyarrow==2.0.0 6 | timm==0.4.12 7 | Pillow==8.1.0 8 | tqdm==4.56.0 9 | ipdb==0.13.4 10 | numpy==1.19.5 11 | einops==0.3.0 12 | sacred==0.8.2 13 | pandas==1.1.5 14 | ftfy==6.0.3 15 | torchmetrics==0.7.2 16 | scikit-learn==1.0.2 17 | torch-cluster==1.5.9 18 | torch-geometric==2.0.2 19 | torch-scatter==2.0.9 20 | torch-sparse==0.6.12 21 | torch-spline-conv==1.2.1 22 | scispacy==0.5.0 -------------------------------------------------------------------------------- /arl/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .cls_melinda_dataset import CLSMELINDADataset 2 | from .irtr_roco_dataset import IRTRROCODataset 3 | from .pretraining_medicat_dataset import MedicatDataset 4 | from .pretraining_mimic_cxr_dataset import MIMICCXRDataset 5 | from .pretraining_roco_dataset import ROCODataset 6 | from .vqa_medvqa_2019_dataset import VQAMEDVQA2019Dataset 7 | from .vqa_medvqa_2021_dataset import VQAMEDVQA2021Dataset 
8 | from .vqa_slack_dataset import VQASLACKDataset 9 | from .vqa_vqa_rad_dataset import VQAVQARADDataset 10 | -------------------------------------------------------------------------------- /arl/datamodules/pretraining_roco_datamodule.py: -------------------------------------------------------------------------------- 1 | from .base_datamodule import BaseDataModule 2 | from ..datasets import ROCODataset 3 | 4 | 5 | class ROCODataModule(BaseDataModule): 6 | def __init__(self, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | 9 | @property 10 | def dataset_cls(self): 11 | return ROCODataset 12 | 13 | @property 14 | def dataset_cls_no_false(self): 15 | return ROCODataset 16 | 17 | @property 18 | def dataset_name(self): 19 | return "roco" 20 | -------------------------------------------------------------------------------- /arl/datamodules/irtr_roco_datamodule.py: -------------------------------------------------------------------------------- 1 | from .base_datamodule import BaseDataModule 2 | from ..datasets import IRTRROCODataset 3 | 4 | 5 | class IRTRROCODataModule(BaseDataModule): 6 | def __init__(self, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | 9 | @property 10 | def dataset_cls(self): 11 | return IRTRROCODataset 12 | 13 | @property 14 | def dataset_cls_no_false(self): 15 | return IRTRROCODataset 16 | 17 | @property 18 | def dataset_name(self): 19 | return "irtr_roco" 20 | -------------------------------------------------------------------------------- /arl/datamodules/pretraining_medicat_datamodule.py: -------------------------------------------------------------------------------- 1 | from .base_datamodule import BaseDataModule 2 | from ..datasets import MedicatDataset 3 | 4 | 5 | class MedicatDataModule(BaseDataModule): 6 | def __init__(self, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | 9 | @property 10 | def dataset_cls(self): 11 | return MedicatDataset 12 | 13 | @property 14 | def dataset_cls_no_false(self): 15 | 
return MedicatDataset 16 | 17 | @property 18 | def dataset_name(self): 19 | return "medicat" 20 | -------------------------------------------------------------------------------- /arl/datamodules/cls_melinda_datamodule.py: -------------------------------------------------------------------------------- 1 | from .base_datamodule import BaseDataModule 2 | from ..datasets import CLSMELINDADataset 3 | 4 | 5 | class CLSMELINDADataModule(BaseDataModule): 6 | def __init__(self, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | 9 | @property 10 | def dataset_cls(self): 11 | return CLSMELINDADataset 12 | 13 | @property 14 | def dataset_cls_no_false(self): 15 | return CLSMELINDADataset 16 | 17 | @property 18 | def dataset_name(self): 19 | return "cls_melinda" 20 | -------------------------------------------------------------------------------- /arl/datamodules/pretraining_mimic_cxr_datamodule.py: -------------------------------------------------------------------------------- 1 | from .base_datamodule import BaseDataModule 2 | from ..datasets import MIMICCXRDataset 3 | 4 | 5 | class MIMICCXRDataModule(BaseDataModule): 6 | def __init__(self, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | 9 | @property 10 | def dataset_cls(self): 11 | return MIMICCXRDataset 12 | 13 | @property 14 | def dataset_cls_no_false(self): 15 | return MIMICCXRDataset 16 | 17 | @property 18 | def dataset_name(self): 19 | return "mimic_cxr" 20 | -------------------------------------------------------------------------------- /arl/datasets/pretraining_roco_dataset.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import BaseDataset 2 | 3 | 4 | class ROCODataset(BaseDataset): 5 | def __init__(self, *args, split="", **kwargs): 6 | assert split in ["train", "val", "test"] 7 | self.split = split 8 | 9 | if split == "train": 10 | names = ["roco_train"] 11 | elif split == "val": 12 | names = ["roco_val"] 13 | elif split == "test": 
14 | names = ["roco_test"] 15 | else: 16 | raise ValueError 17 | 18 | super().__init__(*args, **kwargs, names=names, text_column_name="caption") 19 | 20 | def __getitem__(self, index): 21 | return self.get_suite(index) 22 | -------------------------------------------------------------------------------- /arl/datasets/irtr_roco_dataset.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import BaseDataset 2 | 3 | 4 | class IRTRROCODataset(BaseDataset): 5 | def __init__(self, *args, split="", **kwargs): 6 | assert split in ["train", "val", "test"] 7 | self.split = split 8 | 9 | if split == "train": 10 | names = ["irtr_roco_train"] 11 | elif split == "val": 12 | names = ["irtr_roco_val"] 13 | elif split == "test": 14 | names = ["irtr_roco_test"] 15 | else: 16 | raise ValueError 17 | 18 | super().__init__(*args, **kwargs, names=names, text_column_name="caption") 19 | 20 | def __getitem__(self, index): 21 | return self.get_suite(index) 22 | -------------------------------------------------------------------------------- /arl/datasets/pretraining_medicat_dataset.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import BaseDataset 2 | 3 | 4 | class MedicatDataset(BaseDataset): 5 | def __init__(self, *args, split="", **kwargs): 6 | assert split in ["train", "val", "test"] 7 | self.split = split 8 | 9 | if split == "train": 10 | names = ["medicat_train"] 11 | elif split == "val": 12 | names = ["medicat_val"] 13 | elif split == "test": 14 | names = ["medicat_test"] 15 | else: 16 | raise ValueError 17 | 18 | super().__init__(*args, **kwargs, names=names, text_column_name="caption") 19 | 20 | def __getitem__(self, index): 21 | return self.get_suite(index) 22 | -------------------------------------------------------------------------------- /prepro/OpenKE/openke/module/model/__init__.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .Model import Model 6 | from .TransE import TransE 7 | from .TransD import TransD 8 | from .TransR import TransR 9 | from .TransH import TransH 10 | from .DistMult import DistMult 11 | from .ComplEx import ComplEx 12 | from .RESCAL import RESCAL 13 | from .Analogy import Analogy 14 | from .SimplE import SimplE 15 | from .RotatE import RotatE 16 | 17 | __all__ = [ 18 | 'Model', 19 | 'TransE', 20 | 'TransD', 21 | 'TransR', 22 | 'TransH', 23 | 'DistMult', 24 | 'ComplEx', 25 | 'RESCAL', 26 | 'Analogy', 27 | 'SimplE', 28 | 'RotatE' 29 | ] -------------------------------------------------------------------------------- /prepro/OpenKE/openke/base/Triple.h: -------------------------------------------------------------------------------- 1 | #ifndef TRIPLE_H 2 | #define TRIPLE_H 3 | #include "Setting.h" 4 | 5 | struct Triple { 6 | 7 | INT h, r, t; 8 | 9 | static bool cmp_head(const Triple &a, const Triple &b) { 10 | return (a.h < b.h)||(a.h == b.h && a.r < b.r)||(a.h == b.h && a.r == b.r && a.t < b.t); 11 | } 12 | 13 | static bool cmp_tail(const Triple &a, const Triple &b) { 14 | return (a.t < b.t)||(a.t == b.t && a.r < b.r)||(a.t == b.t && a.r == b.r && a.h < b.h); 15 | } 16 | 17 | static bool cmp_rel(const Triple &a, const Triple &b) { 18 | return (a.h < b.h)||(a.h == b.h && a.t < b.t)||(a.h == b.h && a.t == b.t && a.r < b.r); 19 | } 20 | 21 | static bool cmp_rel2(const Triple &a, const Triple &b) { 22 | return (a.r < b.r)||(a.r == b.r && a.h < b.h)||(a.r == b.r && a.h == b.h && a.t < b.t); 23 | } 24 | 25 | }; 26 | 27 | #endif 28 | -------------------------------------------------------------------------------- /arl/datamodules/__init__.py: -------------------------------------------------------------------------------- 1 | from .cls_melinda_datamodule import 
CLSMELINDADataModule 2 | from .irtr_roco_datamodule import IRTRROCODataModule 3 | from .pretraining_medicat_datamodule import MedicatDataModule 4 | from .pretraining_mimic_cxr_datamodule import MIMICCXRDataModule 5 | from .pretraining_roco_datamodule import ROCODataModule 6 | from .vqa_medvqa_2019_datamodule import VQAMEDVQA2019DataModule 7 | from .vqa_medvqa_2021_datamodule import VQAMEDVQA2021DataModule 8 | from .vqa_slack_datamodule import VQASLACKDataModule 9 | from .vqa_vqa_rad_datamodule import VQAVQARADDataModule 10 | 11 | _datamodules = { 12 | "medicat": MedicatDataModule, 13 | "roco": ROCODataModule, 14 | "mimic_cxr": MIMICCXRDataModule, 15 | "vqa_vqa_rad": VQAVQARADDataModule, 16 | "vqa_slack": VQASLACKDataModule, 17 | "vqa_medvqa_2019": VQAMEDVQA2019DataModule, 18 | "vqa_medvqa_2021": VQAMEDVQA2021DataModule, 19 | "cls_melinda": CLSMELINDADataModule, 20 | "irtr_roco": IRTRROCODataModule 21 | } 22 | -------------------------------------------------------------------------------- /prepro/OpenKE/openke/base/Random.h: -------------------------------------------------------------------------------- 1 | #ifndef RANDOM_H 2 | #define RANDOM_H 3 | #include "Setting.h" 4 | #include 5 | 6 | // the random seeds for all threads. 7 | unsigned long long *next_random; 8 | 9 | // reset the random seeds for all threads 10 | extern "C" 11 | void randReset() { 12 | next_random = (unsigned long long *)calloc(workThreads, sizeof(unsigned long long)); 13 | for (INT i = 0; i < workThreads; i++) 14 | next_random[i] = rand(); 15 | } 16 | 17 | // get a random interger for the id-th thread with the corresponding random seed. 18 | unsigned long long randd(INT id) { 19 | next_random[id] = next_random[id] * (unsigned long long)(25214903917) + 11; 20 | return next_random[id]; 21 | } 22 | 23 | // get a random interger from the range [0,x) for the id-th thread. 
24 | INT rand_max(INT id, INT x) { 25 | INT res = randd(id) % x; 26 | while (res < 0) 27 | res += x; 28 | return res; 29 | } 30 | 31 | // get a random interger from the range [a,b) for the id-th thread. 32 | INT rand(INT a, INT b){ 33 | return (rand() % (b-a))+ a; 34 | } 35 | #endif 36 | -------------------------------------------------------------------------------- /prepro/OpenKE/openke/module/loss/SigmoidLoss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from .Loss import Loss 6 | 7 | class SigmoidLoss(Loss): 8 | 9 | def __init__(self, adv_temperature = None): 10 | super(SigmoidLoss, self).__init__() 11 | self.criterion = nn.LogSigmoid() 12 | if adv_temperature != None: 13 | self.adv_temperature = nn.Parameter(torch.Tensor([adv_temperature])) 14 | self.adv_temperature.requires_grad = False 15 | self.adv_flag = True 16 | else: 17 | self.adv_flag = False 18 | 19 | def get_weights(self, n_score): 20 | return F.softmax(n_score * self.adv_temperature, dim = -1).detach() 21 | 22 | def forward(self, p_score, n_score): 23 | if self.adv_flag: 24 | return -(self.criterion(p_score).mean() + (self.get_weights(n_score) * self.criterion(-n_score)).sum(dim = -1).mean()) / 2 25 | else: 26 | return -(self.criterion(p_score).mean() + self.criterion(-n_score).mean()) / 2 27 | 28 | def predict(self, p_score, n_score): 29 | score = self.forward(p_score, n_score) 30 | return score.cpu().data.numpy() -------------------------------------------------------------------------------- /prepro/OpenKE/openke/module/loss/SoftplusLoss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from .Loss import Loss 6 | 7 | class SoftplusLoss(Loss): 8 | 9 | def __init__(self, adv_temperature = None): 10 | 
super(SoftplusLoss, self).__init__() 11 | self.criterion = nn.Softplus() 12 | if adv_temperature != None: 13 | self.adv_temperature = nn.Parameter(torch.Tensor([adv_temperature])) 14 | self.adv_temperature.requires_grad = False 15 | self.adv_flag = True 16 | else: 17 | self.adv_flag = False 18 | 19 | def get_weights(self, n_score): 20 | return F.softmax(n_score * self.adv_temperature, dim = -1).detach() 21 | 22 | def forward(self, p_score, n_score): 23 | if self.adv_flag: 24 | return (self.criterion(-p_score).mean() + (self.get_weights(n_score) * self.criterion(n_score)).sum(dim = -1).mean()) / 2 25 | else: 26 | return (self.criterion(-p_score).mean() + self.criterion(n_score).mean()) / 2 27 | 28 | 29 | def predict(self, p_score, n_score): 30 | score = self.forward(p_score, n_score) 31 | return score.cpu().data.numpy() -------------------------------------------------------------------------------- /prepro/OpenKE/openke/module/loss/MarginLoss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import numpy as np 6 | from .Loss import Loss 7 | 8 | class MarginLoss(Loss): 9 | 10 | def __init__(self, adv_temperature = None, margin = 6.0): 11 | super(MarginLoss, self).__init__() 12 | self.margin = nn.Parameter(torch.Tensor([margin])) 13 | self.margin.requires_grad = False 14 | if adv_temperature != None: 15 | self.adv_temperature = nn.Parameter(torch.Tensor([adv_temperature])) 16 | self.adv_temperature.requires_grad = False 17 | self.adv_flag = True 18 | else: 19 | self.adv_flag = False 20 | 21 | def get_weights(self, n_score): 22 | return F.softmax(-n_score * self.adv_temperature, dim = -1).detach() 23 | 24 | def forward(self, p_score, n_score): 25 | if self.adv_flag: 26 | return (self.get_weights(n_score) * torch.max(p_score - n_score, -self.margin)).sum(dim = -1).mean() + self.margin 27 | else: 28 | return 
(torch.max(p_score - n_score, -self.margin)).mean() + self.margin 29 | 30 | 31 | def predict(self, p_score, n_score): 32 | score = self.forward(p_score, n_score) 33 | return score.cpu().data.numpy() -------------------------------------------------------------------------------- /prepro/OpenKE/openke/module/strategy/NegativeSampling.py: -------------------------------------------------------------------------------- 1 | from .Strategy import Strategy 2 | 3 | class NegativeSampling(Strategy): 4 | 5 | def __init__(self, model = None, loss = None, batch_size = 256, regul_rate = 0.0, l3_regul_rate = 0.0): 6 | super(NegativeSampling, self).__init__() 7 | self.model = model 8 | self.loss = loss 9 | self.batch_size = batch_size 10 | self.regul_rate = regul_rate 11 | self.l3_regul_rate = l3_regul_rate 12 | 13 | def _get_positive_score(self, score): 14 | positive_score = score[:self.batch_size] 15 | positive_score = positive_score.view(-1, self.batch_size).permute(1, 0) 16 | return positive_score 17 | 18 | def _get_negative_score(self, score): 19 | negative_score = score[self.batch_size:] 20 | negative_score = negative_score.view(-1, self.batch_size).permute(1, 0) 21 | return negative_score 22 | 23 | def forward(self, data): 24 | score = self.model(data) 25 | p_score = self._get_positive_score(score) 26 | n_score = self._get_negative_score(score) 27 | loss_res = self.loss(p_score, n_score) 28 | if self.regul_rate != 0: 29 | loss_res += self.regul_rate * self.model.regularization(data) 30 | if self.l3_regul_rate != 0: 31 | loss_res += self.l3_regul_rate * self.model.l3_regularization() 32 | return loss_res -------------------------------------------------------------------------------- /prepro/OpenKE/train_transe.py: -------------------------------------------------------------------------------- 1 | from OpenKE.openke.config import Trainer 2 | from OpenKE.openke.data import TrainDataLoader 3 | from OpenKE.openke.module.loss import MarginLoss 4 | from 
OpenKE.openke.module.model import TransE 5 | from OpenKE.openke.module.strategy import NegativeSampling 6 | 7 | 8 | def train_transe(): 9 | # dataloader for training 10 | train_dataloader = TrainDataLoader( 11 | in_path="data/knowledge/", 12 | nbatches=100, 13 | threads=8, 14 | sampling_mode="normal", 15 | bern_flag=1, 16 | filter_flag=1, 17 | neg_ent=25, 18 | neg_rel=0) 19 | 20 | # define the model 21 | transe = TransE( 22 | ent_tot=train_dataloader.get_ent_tot(), 23 | rel_tot=train_dataloader.get_rel_tot(), 24 | dim=256, 25 | p_norm=1, 26 | norm_flag=True) 27 | 28 | # define the loss function 29 | model = NegativeSampling( 30 | model=transe, 31 | loss=MarginLoss(margin=5.0), 32 | batch_size=train_dataloader.get_batch_size() 33 | ) 34 | 35 | # train the model 36 | trainer = Trainer(model=model, data_loader=train_dataloader, train_times=1000, alpha=1.0, use_gpu=True) 37 | trainer.run() 38 | transe.save_checkpoint('data/knowledge/ent_embeddings.ckpt') 39 | -------------------------------------------------------------------------------- /arl/datasets/cls_melinda_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .base_dataset import BaseDataset 4 | 5 | 6 | class CLSMELINDADataset(BaseDataset): 7 | def __init__(self, *args, split="", **kwargs): 8 | assert split in ["train", "val", "test"] 9 | self.split = split 10 | 11 | if split == "train": 12 | names = ["cls_melinda_train"] 13 | elif split == "val": 14 | names = ["cls_melinda_val"] 15 | elif split == "test": 16 | names = ["cls_melinda_test"] 17 | else: 18 | raise ValueError 19 | 20 | super().__init__(*args, **kwargs, names=names, text_column_name="caption") 21 | self.label_column_name = self.label_column_name 22 | self.labels = self.table[self.label_column_name].to_pandas().tolist() 23 | assert len(self.labels) == len(self.table) 24 | 25 | def __getitem__(self, index): 26 | return self.get_suite(index) 27 | 28 | def get_suite(self, index): 29 | 
ret = super(CLSMELINDADataset, self).get_suite(index) 30 | img_index, cap_index = self.index_mapper[index] 31 | ret["cls_labels"] = self.labels[img_index][cap_index] 32 | return ret 33 | 34 | def collate(self, batch, mlm_collator): 35 | dict_batch = super(CLSMELINDADataset, self).collate(batch, mlm_collator) 36 | 37 | dict_batch["cls_labels"] = torch.tensor([sample["cls_labels"] for sample in batch]) 38 | return dict_batch 39 | -------------------------------------------------------------------------------- /prepro/OpenKE/examples/train_rescal_FB15K237.py: -------------------------------------------------------------------------------- 1 | import openke 2 | from openke.config import Trainer, Tester 3 | from openke.module.model import RESCAL 4 | from openke.module.loss import MarginLoss 5 | from openke.module.strategy import NegativeSampling 6 | from openke.data import TrainDataLoader, TestDataLoader 7 | 8 | # dataloader for training 9 | train_dataloader = TrainDataLoader( 10 | in_path = "./benchmarks/FB15K237/", 11 | nbatches = 100, 12 | threads = 8, 13 | sampling_mode = "normal", 14 | bern_flag = 1, 15 | filter_flag = 1, 16 | neg_ent = 25, 17 | neg_rel = 0 18 | ) 19 | 20 | # dataloader for test 21 | test_dataloader = TestDataLoader("./benchmarks/FB15K237/", "link") 22 | 23 | # define the model 24 | rescal = RESCAL( 25 | ent_tot = train_dataloader.get_ent_tot(), 26 | rel_tot = train_dataloader.get_rel_tot(), 27 | dim = 50 28 | ) 29 | 30 | # define the loss function 31 | model = NegativeSampling( 32 | model = rescal, 33 | loss = MarginLoss(margin = 1.0), 34 | batch_size = train_dataloader.get_batch_size(), 35 | ) 36 | 37 | # train the model 38 | trainer = Trainer(model = model, data_loader = train_dataloader, train_times = 1000, alpha = 0.1, use_gpu = True, opt_method = "adagrad") 39 | trainer.run() 40 | rescal.save_checkpoint('./checkpoint/rescal.ckpt') 41 | 42 | # test the model 43 | rescal.load_checkpoint('./checkpoint/rescal.ckpt') 44 | tester = Tester(model = 
rescal, data_loader = test_dataloader, use_gpu = True) 45 | tester.run_link_prediction(type_constrain = False) -------------------------------------------------------------------------------- /prepro/OpenKE/examples/train_hole_WN18RR.py: -------------------------------------------------------------------------------- 1 | import openke 2 | from openke.config import Trainer, Tester 3 | from openke.module.model import HolE 4 | from openke.module.loss import SoftplusLoss 5 | from openke.module.strategy import NegativeSampling 6 | from openke.data import TrainDataLoader, TestDataLoader 7 | 8 | # dataloader for training 9 | train_dataloader = TrainDataLoader( 10 | in_path = "./benchmarks/WN18RR/", 11 | nbatches = 100, 12 | threads = 8, 13 | sampling_mode = "normal", 14 | bern_flag = 1, 15 | filter_flag = 1, 16 | neg_ent = 25, 17 | neg_rel = 0 18 | ) 19 | 20 | # dataloader for test 21 | test_dataloader = TestDataLoader("./benchmarks/WN18RR/", "link") 22 | 23 | # define the model 24 | hole = HolE( 25 | ent_tot = train_dataloader.get_ent_tot(), 26 | rel_tot = train_dataloader.get_rel_tot(), 27 | dim = 100 28 | ) 29 | 30 | # define the loss function 31 | model = NegativeSampling( 32 | model = hole, 33 | loss = SoftplusLoss(), 34 | batch_size = train_dataloader.get_batch_size(), 35 | regul_rate = 1.0 36 | ) 37 | 38 | 39 | # train the model 40 | trainer = Trainer(model = model, data_loader = train_dataloader, train_times = 1000, alpha = 0.5, use_gpu = True, opt_method = "adagrad") 41 | trainer.run() 42 | hole.save_checkpoint('./checkpoint/hole.ckpt') 43 | 44 | # test the model 45 | hole.load_checkpoint('./checkpoint/hole.ckpt') 46 | tester = Tester(model = hole, data_loader = test_dataloader, use_gpu = True) 47 | tester.run_link_prediction(type_constrain = False) 48 | -------------------------------------------------------------------------------- /prepro/OpenKE/examples/train_transe_FB15K237.py: -------------------------------------------------------------------------------- 
1 | import openke 2 | from openke.config import Trainer, Tester 3 | from openke.module.model import TransE 4 | from openke.module.loss import MarginLoss 5 | from openke.module.strategy import NegativeSampling 6 | from openke.data import TrainDataLoader, TestDataLoader 7 | 8 | # dataloader for training 9 | train_dataloader = TrainDataLoader( 10 | in_path = "./benchmarks/FB15K237/", 11 | nbatches = 100, 12 | threads = 8, 13 | sampling_mode = "normal", 14 | bern_flag = 1, 15 | filter_flag = 1, 16 | neg_ent = 25, 17 | neg_rel = 0) 18 | 19 | # dataloader for test 20 | test_dataloader = TestDataLoader("./benchmarks/FB15K237/", "link") 21 | 22 | # define the model 23 | transe = TransE( 24 | ent_tot = train_dataloader.get_ent_tot(), 25 | rel_tot = train_dataloader.get_rel_tot(), 26 | dim = 200, 27 | p_norm = 1, 28 | norm_flag = True) 29 | 30 | 31 | # define the loss function 32 | model = NegativeSampling( 33 | model = transe, 34 | loss = MarginLoss(margin = 5.0), 35 | batch_size = train_dataloader.get_batch_size() 36 | ) 37 | 38 | # train the model 39 | trainer = Trainer(model = model, data_loader = train_dataloader, train_times = 1000, alpha = 1.0, use_gpu = True) 40 | trainer.run() 41 | transe.save_checkpoint('./checkpoint/transe.ckpt') 42 | 43 | # test the model 44 | transe.load_checkpoint('./checkpoint/transe.ckpt') 45 | tester = Tester(model = transe, data_loader = test_dataloader, use_gpu = True) 46 | tester.run_link_prediction(type_constrain = False) -------------------------------------------------------------------------------- /prepro/OpenKE/examples/train_transh_FB15K237.py: -------------------------------------------------------------------------------- 1 | import openke 2 | from openke.config import Trainer, Tester 3 | from openke.module.model import TransH 4 | from openke.module.loss import MarginLoss 5 | from openke.module.strategy import NegativeSampling 6 | from openke.data import TrainDataLoader, TestDataLoader 7 | 8 | # dataloader for training 9 | 
"""Train and evaluate TransH on FB15K-237 (link prediction)."""
import openke
from openke.config import Trainer, Tester
from openke.module.model import TransH
from openke.module.loss import MarginLoss
from openke.module.strategy import NegativeSampling
from openke.data import TrainDataLoader, TestDataLoader

# Training data: 100 mini-batches per epoch, bernoulli negative sampling,
# 25 corrupted entities per positive triple.
train_dataloader = TrainDataLoader(
    in_path="./benchmarks/FB15K237/",
    nbatches=100,
    threads=8,
    sampling_mode="normal",
    bern_flag=1,
    filter_flag=1,
    neg_ent=25,
    neg_rel=0,
)

# Evaluation data for the link-prediction protocol.
test_dataloader = TestDataLoader("./benchmarks/FB15K237/", "link")

# 200-dimensional TransH with L1 distance and embedding normalisation.
transh = TransH(
    ent_tot=train_dataloader.get_ent_tot(),
    rel_tot=train_dataloader.get_rel_tot(),
    dim=200,
    p_norm=1,
    norm_flag=True,
)

# Margin-based ranking loss (margin 4.0) with negative sampling.
model = NegativeSampling(
    model=transh,
    loss=MarginLoss(margin=4.0),
    batch_size=train_dataloader.get_batch_size(),
)

# Train for 1000 epochs and persist the weights.
trainer = Trainer(model=model, data_loader=train_dataloader, train_times=1000, alpha=0.5, use_gpu=True)
trainer.run()
transh.save_checkpoint('./checkpoint/transh.ckpt')

# Reload and report link-prediction metrics.
transh.load_checkpoint('./checkpoint/transh.ckpt')
tester = Tester(model=transh, data_loader=test_dataloader, use_gpu=True)
tester.run_link_prediction(type_constrain=False)
"""Train and evaluate an Analogy model on WN18RR (link prediction)."""
import openke
from openke.config import Trainer, Tester
from openke.module.model import Analogy
from openke.module.loss import SoftplusLoss
from openke.module.strategy import NegativeSampling
from openke.data import TrainDataLoader, TestDataLoader

# Training data: 100 mini-batches per epoch, bernoulli negative sampling,
# 25 corrupted entities per positive triple.
train_dataloader = TrainDataLoader(
    in_path="./benchmarks/WN18RR/",
    nbatches=100,
    threads=8,
    sampling_mode="normal",
    bern_flag=1,
    filter_flag=1,
    neg_ent=25,
    neg_rel=0,
)

# Evaluation data for the link-prediction protocol.
test_dataloader = TestDataLoader("./benchmarks/WN18RR/", "link")

# Analogy scoring model with 200-dimensional embeddings.
analogy = Analogy(
    ent_tot=train_dataloader.get_ent_tot(),
    rel_tot=train_dataloader.get_rel_tot(),
    dim=200,
)

# Softplus loss with L2 regularisation under negative sampling.
model = NegativeSampling(
    model=analogy,
    loss=SoftplusLoss(),
    batch_size=train_dataloader.get_batch_size(),
    regul_rate=1.0,
)

# Optimise with Adagrad for 2000 epochs, then persist the weights.
trainer = Trainer(model=model, data_loader=train_dataloader, train_times=2000, alpha=0.5, use_gpu=True, opt_method="adagrad")
trainer.run()
analogy.save_checkpoint('./checkpoint/analogy.ckpt')

# Reload and report link-prediction metrics.
analogy.load_checkpoint('./checkpoint/analogy.ckpt')
tester = Tester(model=analogy, data_loader=test_dataloader, use_gpu=True)
tester.run_link_prediction(type_constrain=False)
class RESCAL(Model):
    """RESCAL bilinear knowledge-graph embedding model.

    A triple (h, r, t) is scored as -h^T (M_r t), where M_r is a dense
    dim x dim matrix learned per relation (lower forward score = better fit;
    ``predict`` negates it back).
    """

    def __init__(self, ent_tot, rel_tot, dim = 100):
        super(RESCAL, self).__init__(ent_tot, rel_tot)

        self.dim = dim
        # One dim-vector per entity; one flattened dim*dim matrix per relation.
        self.ent_embeddings = nn.Embedding(self.ent_tot, self.dim)
        self.rel_matrices = nn.Embedding(self.rel_tot, self.dim * self.dim)
        nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
        nn.init.xavier_uniform_(self.rel_matrices.weight.data)

    def _calc(self, h, t, r):
        # Batched bilinear form: -sum(h * (M_r @ t)) over the embedding dim.
        tail = t.view(-1, self.dim, 1)
        rel_mat = r.view(-1, self.dim, self.dim)
        transformed = torch.matmul(rel_mat, tail).view(-1, self.dim)
        return -torch.sum(h * transformed, -1)

    def forward(self, data):
        """Return raw scores for the batch of triples in ``data``."""
        h = self.ent_embeddings(data['batch_h'])
        t = self.ent_embeddings(data['batch_t'])
        r = self.rel_matrices(data['batch_r'])
        return self._calc(h, t, r)

    def regularization(self, data):
        """Mean-square regulariser over the batch's embeddings and matrices."""
        h = self.ent_embeddings(data['batch_h'])
        t = self.ent_embeddings(data['batch_t'])
        r = self.rel_matrices(data['batch_r'])
        return (torch.mean(h ** 2) + torch.mean(t ** 2) + torch.mean(r ** 2)) / 3

    def predict(self, data):
        """Negated forward scores as a numpy array."""
        return (-self.forward(data)).cpu().data.numpy()
"""Train and evaluate TransD on FB15K-237 (link prediction)."""
import openke
from openke.config import Trainer, Tester
from openke.module.model import TransD
from openke.module.loss import MarginLoss
from openke.module.strategy import NegativeSampling
from openke.data import TrainDataLoader, TestDataLoader

# Training data: 100 mini-batches per epoch, bernoulli negative sampling,
# 25 corrupted entities per positive triple.
train_dataloader = TrainDataLoader(
    in_path="./benchmarks/FB15K237/",
    nbatches=100,
    threads=8,
    sampling_mode="normal",
    bern_flag=1,
    filter_flag=1,
    neg_ent=25,
    neg_rel=0,
)

# Evaluation data for the link-prediction protocol.
test_dataloader = TestDataLoader("./benchmarks/FB15K237/", "link")

# TransD with separate 200-dim entity and relation spaces, L1 distance.
transd = TransD(
    ent_tot=train_dataloader.get_ent_tot(),
    rel_tot=train_dataloader.get_rel_tot(),
    dim_e=200,
    dim_r=200,
    p_norm=1,
    norm_flag=True,
)

# Margin-based ranking loss (margin 4.0) with negative sampling.
model = NegativeSampling(
    model=transd,
    loss=MarginLoss(margin=4.0),
    batch_size=train_dataloader.get_batch_size(),
)

# Train for 1000 epochs and persist the weights.
trainer = Trainer(model=model, data_loader=train_dataloader, train_times=1000, alpha=1.0, use_gpu=True)
trainer.run()
transd.save_checkpoint('./checkpoint/transd.ckpt')

# Reload and report link-prediction metrics.
transd.load_checkpoint('./checkpoint/transd.ckpt')
tester = Tester(model=transd, data_loader=test_dataloader, use_gpu=True)
tester.run_link_prediction(type_constrain=False)
"""Train and evaluate DistMult on WN18RR (link prediction)."""
import openke
from openke.config import Trainer, Tester
from openke.module.model import DistMult
from openke.module.loss import SoftplusLoss
from openke.module.strategy import NegativeSampling
from openke.data import TrainDataLoader, TestDataLoader

# Training data: 100 mini-batches per epoch, bernoulli negative sampling,
# 25 corrupted entities per positive triple.
train_dataloader = TrainDataLoader(
    in_path="./benchmarks/WN18RR/",
    nbatches=100,
    threads=8,
    sampling_mode="normal",
    bern_flag=1,
    filter_flag=1,
    neg_ent=25,
    neg_rel=0,
)

# Evaluation data for the link-prediction protocol.
test_dataloader = TestDataLoader("./benchmarks/WN18RR/", "link")

# DistMult scoring model with 200-dimensional embeddings.
distmult = DistMult(
    ent_tot=train_dataloader.get_ent_tot(),
    rel_tot=train_dataloader.get_rel_tot(),
    dim=200,
)

# Softplus loss with L2 regularisation under negative sampling.
model = NegativeSampling(
    model=distmult,
    loss=SoftplusLoss(),
    batch_size=train_dataloader.get_batch_size(),
    regul_rate=1.0,
)

# Optimise with Adagrad for 2000 epochs, then persist the weights.
trainer = Trainer(model=model, data_loader=train_dataloader, train_times=2000, alpha=0.5, use_gpu=True, opt_method="adagrad")
trainer.run()
distmult.save_checkpoint('./checkpoint/distmult.ckpt')

# Reload and report link-prediction metrics.
distmult.load_checkpoint('./checkpoint/distmult.ckpt')
tester = Tester(model=distmult, data_loader=test_dataloader, use_gpu=True)
tester.run_link_prediction(type_constrain=False)
"""Train and evaluate RotatE on WN18RR with adversarial negative sampling."""
import openke
from openke.config import Trainer, Tester
from openke.module.model import RotatE
from openke.module.loss import SigmoidLoss
from openke.module.strategy import NegativeSampling
from openke.data import TrainDataLoader, TestDataLoader

# Training data: fixed batch size, cross sampling, 64 negative entities
# per positive triple (no bernoulli trick).
train_dataloader = TrainDataLoader(
    in_path="./benchmarks/WN18RR/",
    batch_size=2000,
    threads=8,
    sampling_mode="cross",
    bern_flag=0,
    filter_flag=1,
    neg_ent=64,
    neg_rel=0,
)

# Evaluation data for the link-prediction protocol.
test_dataloader = TestDataLoader("./benchmarks/WN18RR/", "link")

# RotatE with 1024-dimensional embeddings and fixed margin/epsilon.
rotate = RotatE(
    ent_tot=train_dataloader.get_ent_tot(),
    rel_tot=train_dataloader.get_rel_tot(),
    dim=1024,
    margin=6.0,
    epsilon=2.0,
)

# Self-adversarial sigmoid loss (temperature 2), no regularisation.
model = NegativeSampling(
    model=rotate,
    loss=SigmoidLoss(adv_temperature=2),
    batch_size=train_dataloader.get_batch_size(),
    regul_rate=0.0,
)

# Optimise with Adam for 6000 epochs at a small learning rate.
trainer = Trainer(model=model, data_loader=train_dataloader, train_times=6000, alpha=2e-5, use_gpu=True, opt_method="adam")
trainer.run()
rotate.save_checkpoint('./checkpoint/rotate.ckpt')

# Reload and report link-prediction metrics.
rotate.load_checkpoint('./checkpoint/rotate.ckpt')
tester = Tester(model=rotate, data_loader=test_dataloader, use_gpu=True)
tester.run_link_prediction(type_constrain=False)
from collections import defaultdict

from .base_datamodule import BaseDataModule
from ..datasets import VQAVQARADDataset


class VQAVQARADDataModule(BaseDataModule):
    """Data module for the VQA-RAD VQA task."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @property
    def dataset_cls(self):
        return VQAVQARADDataset

    @property
    def dataset_name(self):
        return "vqa_vqa_rad"

    def setup(self, stage):
        """Derive the answer vocabulary from the train+val splits."""
        super().setup(stage)

        def _flatten(rows):
            # Drop missing rows, then flatten two levels of nesting
            # (rows -> answer groups -> individual items).
            return [item
                    for row in rows if row is not None
                    for group in row
                    for item in group]

        answers = _flatten(
            self.train_dataset.table["answers"].to_pandas().tolist()
            + self.val_dataset.table["answers"].to_pandas().tolist())
        labels = _flatten(
            self.train_dataset.table["answer_labels"].to_pandas().tolist()
            + self.val_dataset.table["answer_labels"].to_pandas().tolist())

        # answer string -> class id; answers and labels are parallel lists.
        self.answer2id = dict(zip(answers, labels))
        self.num_class = max(self.answer2id.values()) + 1

        # class id -> answer string, defaulting to "unknown" for unseen ids.
        self.id2answer = defaultdict(lambda: "unknown")
        for answer, label in sorted(self.answer2id.items(), key=lambda kv: kv[1]):
            self.id2answer[label] = answer
"""Train and evaluate TransE on WN18RR with adversarial sigmoid loss."""
import openke
from openke.config import Trainer, Tester
from openke.module.model import TransE
from openke.module.loss import SigmoidLoss
from openke.module.strategy import NegativeSampling
from openke.data import TrainDataLoader, TestDataLoader

# Training data: fixed batch size, cross sampling, 64 negative entities
# per positive triple.
train_dataloader = TrainDataLoader(
    in_path="./benchmarks/WN18RR/",
    batch_size=2000,
    threads=8,
    sampling_mode="cross",
    bern_flag=0,
    filter_flag=1,
    neg_ent=64,
    neg_rel=0,
)

# Evaluation data for the link-prediction protocol.
test_dataloader = TestDataLoader("./benchmarks/WN18RR/", "link")

# 1024-dimensional TransE, L1 distance, no embedding normalisation,
# with a fixed margin of 6.0.
transe = TransE(
    ent_tot=train_dataloader.get_ent_tot(),
    rel_tot=train_dataloader.get_rel_tot(),
    dim=1024,
    p_norm=1,
    norm_flag=False,
    margin=6.0,
)

# Self-adversarial sigmoid loss (temperature 1), no regularisation.
model = NegativeSampling(
    model=transe,
    loss=SigmoidLoss(adv_temperature=1),
    batch_size=train_dataloader.get_batch_size(),
    regul_rate=0.0,
)

# Optimise with Adam for 3000 epochs at a small learning rate.
trainer = Trainer(model=model, data_loader=train_dataloader, train_times=3000, alpha=2e-5, use_gpu=True, opt_method="adam")
trainer.run()
transe.save_checkpoint('./checkpoint/transe_2.ckpt')

# Reload and report link-prediction metrics.
transe.load_checkpoint('./checkpoint/transe_2.ckpt')
tester = Tester(model=transe, data_loader=test_dataloader, use_gpu=True)
tester.run_link_prediction(type_constrain=False)
from collections import defaultdict

from .base_datamodule import BaseDataModule
from ..datasets import VQAMEDVQA2019Dataset


class VQAMEDVQA2019DataModule(BaseDataModule):
    """Data module for the MedVQA-2019 VQA task."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @property
    def dataset_cls(self):
        return VQAMEDVQA2019Dataset

    @property
    def dataset_name(self):
        return "vqa_medvqa_2019"

    def setup(self, stage):
        """Derive the answer vocabulary from the train+val splits."""
        super().setup(stage)

        def _column(dataset, name):
            # Pull one arrow column as a plain Python list of rows.
            return dataset.table[name].to_pandas().tolist()

        def _flatten(rows):
            # Drop missing rows, then flatten two levels of nesting
            # (rows -> answer groups -> individual items).
            return [item
                    for row in rows if row is not None
                    for group in row
                    for item in group]

        answers = _flatten(_column(self.train_dataset, "answers")
                           + _column(self.val_dataset, "answers"))
        labels = _flatten(_column(self.train_dataset, "answer_labels")
                          + _column(self.val_dataset, "answer_labels"))

        # answer string -> class id; answers and labels are parallel lists.
        self.answer2id = dict(zip(answers, labels))
        self.num_class = max(self.answer2id.values()) + 1

        # class id -> answer string, defaulting to "unknown" for unseen ids.
        self.id2answer = defaultdict(lambda: "unknown")
        for answer, label in sorted(self.answer2id.items(), key=lambda kv: kv[1]):
            self.id2answer[label] = answer
from collections import defaultdict

from .base_datamodule import BaseDataModule
from ..datasets import VQAMEDVQA2021Dataset


class VQAMEDVQA2021DataModule(BaseDataModule):
    """Data module for the MedVQA-2021 VQA task."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @property
    def dataset_cls(self):
        # Dataset class instantiated per split by BaseDataModule.
        return VQAMEDVQA2021Dataset

    @property
    def dataset_name(self):
        # BUG FIX: this previously returned "vqa_medvqa_2019" (copy-paste from
        # the 2019 module) even though dataset_cls is the 2021 dataset.
        # NOTE(review): if any experiment config keys on the old string,
        # update it to match.
        return "vqa_medvqa_2021"

    def setup(self, stage):
        """Build the answer vocabulary from the train+val answer columns."""
        super().setup(stage)

        train_answers = self.train_dataset.table["answers"].to_pandas().tolist()
        val_answers = self.val_dataset.table["answers"].to_pandas().tolist()
        train_labels = self.train_dataset.table["answer_labels"].to_pandas().tolist()
        val_labels = self.val_dataset.table["answer_labels"].to_pandas().tolist()

        # Drop missing rows, then flatten two nesting levels
        # (rows -> answer groups -> individual answers/labels).
        all_answers = [c for c in train_answers + val_answers if c is not None]
        all_answers = [l for lll in all_answers for ll in lll for l in ll]
        all_labels = [c for c in train_labels + val_labels if c is not None]
        all_labels = [l for lll in all_labels for ll in lll for l in ll]

        # answer string -> integer class id (answers/labels are parallel).
        self.answer2id = {k: v for k, v in zip(all_answers, all_labels)}
        sorted_a2i = sorted(self.answer2id.items(), key=lambda x: x[1])
        self.num_class = max(self.answer2id.values()) + 1

        # Reverse mapping with an "unknown" fallback for unseen class ids.
        self.id2answer = defaultdict(lambda: "unknown")
        for k, v in sorted_a2i:
            self.id2answer[v] = k
class BaseModule(nn.Module):
    """Common checkpoint / parameter (de)serialisation helpers for KE models."""

    def __init__(self):
        super(BaseModule, self).__init__()
        # Frozen constants shared by subclasses; registered as parameters so
        # they follow .to(device), but never trained.
        self.zero_const = nn.Parameter(torch.Tensor([0]))
        self.zero_const.requires_grad = False
        self.pi_const = nn.Parameter(torch.Tensor([3.14159265358979323846]))
        self.pi_const.requires_grad = False

    def load_checkpoint(self, path):
        """Load a state-dict checkpoint from ``path`` and switch to eval mode."""
        # FIX: dropped a pointless single-argument os.path.join(path) no-op.
        self.load_state_dict(torch.load(path))
        self.eval()

    def save_checkpoint(self, path):
        """Serialise the current state dict to ``path``."""
        torch.save(self.state_dict(), path)

    def load_parameters(self, path):
        """Load parameters from a JSON file written by ``save_parameters``."""
        # FIX: use a context manager so the file is closed even on error.
        with open(path, "r") as f:
            parameters = json.loads(f.read())
        for i in parameters:
            parameters[i] = torch.Tensor(parameters[i])
        self.load_state_dict(parameters, strict=False)
        self.eval()

    def save_parameters(self, path):
        """Write all parameters to ``path`` as JSON nested lists."""
        with open(path, "w") as f:
            f.write(json.dumps(self.get_parameters("list")))

    def get_parameters(self, mode="numpy", param_dict=None):
        """Return (a subset of) the state dict converted per ``mode``.

        mode: "numpy" -> ndarrays, "list" -> nested Python lists,
              anything else -> raw tensors.
        param_dict: iterable of parameter names; defaults to all parameters.
        """
        all_param_dict = self.state_dict()
        if param_dict is None:  # FIX: identity check instead of `== None`
            param_dict = all_param_dict.keys()
        res = {}
        for param in param_dict:
            if mode == "numpy":
                res[param] = all_param_dict[param].cpu().numpy()
            elif mode == "list":
                res[param] = all_param_dict[param].cpu().numpy().tolist()
            else:
                res[param] = all_param_dict[param]
        return res

    def set_parameters(self, parameters):
        """Load tensors from a name -> nested-list mapping, then eval()."""
        for i in parameters:
            parameters[i] = torch.Tensor(parameters[i])
        self.load_state_dict(parameters, strict=False)
        self.eval()
from .base_dataset import BaseDataset


class VQAVQARADDataset(BaseDataset):
    """VQA-RAD question-answering dataset (train/val/test splits)."""

    def __init__(self, *args, split="", **kwargs):
        assert split in ["train", "val", "test"]
        self.split = split

        # Each split maps to a single pre-built arrow table name.
        names = [f"vqa_vqa_rad_{split}"]

        super().__init__(
            *args,
            **kwargs,
            names=names,
            text_column_name="questions"
        )

    def __getitem__(self, index):
        image_tensor = self.get_image(index)["image"]
        txt = self.get_text(index)

        # Resolve the flat sample index onto (table row, question slot).
        row, slot = self.index_mapper[index]

        def _cell(column):
            return self.table[column][row][slot].as_py()

        return {
            "image": image_tensor,
            "text": txt["text"],
            "vqa_answer": _cell("answers"),
            "vqa_labels": _cell("answer_labels"),
            "vqa_scores": _cell("answer_scores"),
            "answer_types": _cell("answer_type"),
            "qid": _cell("question_id"),
            "img_label": txt["img_label"],
            "txt_label": txt["txt_label"],
        }
from .base_dataset import BaseDataset


class VQAMEDVQA2021Dataset(BaseDataset):
    """MedVQA-2021 question-answering dataset (train/val/test splits)."""

    def __init__(self, *args, split="", **kwargs):
        assert split in ["train", "val", "test"]
        self.split = split

        # Each split maps to a single pre-built arrow table name.
        names = [f"vqa_medvqa_2021_{split}"]

        super().__init__(
            *args,
            **kwargs,
            names=names,
            text_column_name="questions",
        )

    def __getitem__(self, index):
        image_tensor = self.get_image(index)["image"]
        txt = self.get_text(index)

        # Resolve the flat sample index onto (table row, question slot).
        row, slot = self.index_mapper[index]

        def _cell(column):
            return self.table[column][row][slot].as_py()

        return {
            "image": image_tensor,
            "text": txt["text"],
            "vqa_answer": _cell("answers"),
            "vqa_labels": _cell("answer_labels"),
            "vqa_scores": _cell("answer_scores"),
            "answer_types": _cell("answer_type"),
            "qid": _cell("question_id"),
            "img_label": txt["img_label"],
            "txt_label": txt["txt_label"],
        }
def create_pos_matrix(encoding, max_text_len, max_ent_len, mlm_labels=None):
    """Build an entity-to-token alignment matrix for one encoded text.

    Args:
        encoding: tokenizer encoding exposing ``char_to_token`` and a
            ``"txt_ents"`` entry of (begin_char, end_char, entity_id) triples.
        max_text_len: number of token columns in the alignment matrix.
        max_ent_len: maximum number of entity rows to fill.
        mlm_labels: optional per-token MLM labels; entities whose span
            overlaps any non -100 (i.e. masked) token are skipped.

    Returns:
        pos_matrix: (max_ent_len, max_text_len) float matrix, 1.0 at each
            token position covered by the corresponding kept entity.
        ent_ids: (max_ent_len,) long tensor of entity ids, -100 as padding.
        ent_masks: (max_ent_len,) bool tensor marking valid rows.
    """
    txt_ents = encoding["txt_ents"]
    pos_matrix = torch.zeros((max_ent_len, max_text_len), dtype=torch.float)
    ent_ids = torch.full((max_ent_len,), -100, dtype=torch.long)
    ent_masks = torch.zeros(max_ent_len, dtype=torch.bool)

    row = 0
    for ent in txt_ents:
        start_tok = encoding.char_to_token(ent[0])
        end_tok = encoding.char_to_token(ent[1] - 1)
        # Skip entities whose characters did not map onto tokens.
        if start_tok is None or end_tok is None:
            continue
        # Skip entities that intersect a masked-language-model target token.
        if mlm_labels is not None and (mlm_labels[start_tok:end_tok + 1] != -100).sum() != 0:
            continue
        pos_matrix[row, start_tok:end_tok + 1] = 1.0
        ent_ids[row] = ent[2]
        ent_masks[row] = True
        row += 1
        if row == max_ent_len:
            break
    return pos_matrix, ent_ids, ent_masks
TrainDataLoader, TestDataLoader 7 | 8 | # dataloader for training 9 | train_dataloader = TrainDataLoader( 10 | in_path = "./benchmarks/FB15K237/", 11 | nbatches = 100, 12 | threads = 8, 13 | sampling_mode = "normal", 14 | bern_flag = 1, 15 | filter_flag = 1, 16 | neg_ent = 25, 17 | neg_rel = 0) 18 | 19 | # dataloader for test 20 | test_dataloader = TestDataLoader( 21 | in_path = "./benchmarks/FB15K237/", 22 | sampling_mode = 'link') 23 | 24 | # define the model 25 | transe = TransE( 26 | ent_tot = train_dataloader.get_ent_tot(), 27 | rel_tot = train_dataloader.get_rel_tot(), 28 | dim = 200, 29 | p_norm = 1, 30 | norm_flag = True) 31 | 32 | model_e = NegativeSampling( 33 | model = transe, 34 | loss = MarginLoss(margin = 5.0), 35 | batch_size = train_dataloader.get_batch_size()) 36 | 37 | transr = TransR( 38 | ent_tot = train_dataloader.get_ent_tot(), 39 | rel_tot = train_dataloader.get_rel_tot(), 40 | dim_e = 200, 41 | dim_r = 200, 42 | p_norm = 1, 43 | norm_flag = True, 44 | rand_init = False) 45 | 46 | model_r = NegativeSampling( 47 | model = transr, 48 | loss = MarginLoss(margin = 4.0), 49 | batch_size = train_dataloader.get_batch_size() 50 | ) 51 | 52 | # pretrain transe 53 | trainer = Trainer(model = model_e, data_loader = train_dataloader, train_times = 1, alpha = 0.5, use_gpu = True) 54 | trainer.run() 55 | parameters = transe.get_parameters() 56 | transe.save_parameters("./result/transr_transe.json") 57 | 58 | # train transr 59 | transr.set_parameters(parameters) 60 | trainer = Trainer(model = model_r, data_loader = train_dataloader, train_times = 1000, alpha = 1.0, use_gpu = True) 61 | trainer.run() 62 | transr.save_checkpoint('./checkpoint/transr.ckpt') 63 | 64 | # test the model 65 | transr.load_checkpoint('./checkpoint/transr.ckpt') 66 | tester = Tester(model = transr, data_loader = test_dataloader, use_gpu = True) 67 | tester.run_link_prediction(type_constrain = False) -------------------------------------------------------------------------------- 
/arl/transforms/utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from torchvision import transforms 3 | 4 | 5 | class MinMaxResize: 6 | def __init__(self, shorter=800, longer=1333): 7 | self.min = shorter 8 | self.max = longer 9 | 10 | def __call__(self, x): 11 | w, h = x.size 12 | scale = self.min / min(w, h) 13 | if h < w: 14 | newh, neww = self.min, scale * w 15 | else: 16 | newh, neww = scale * h, self.min 17 | 18 | if max(newh, neww) > self.max: 19 | scale = self.max / max(newh, neww) 20 | newh = newh * scale 21 | neww = neww * scale 22 | 23 | newh, neww = int(newh + 0.5), int(neww + 0.5) 24 | newh, neww = newh // 32 * 32, neww // 32 * 32 25 | 26 | return x.resize((neww, newh), resample=Image.BICUBIC) 27 | 28 | 29 | class UnNormalize(object): 30 | def __init__(self, mean, std): 31 | self.mean = mean 32 | self.std = std 33 | 34 | def __call__(self, tensor): 35 | """ 36 | Args: 37 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 38 | Returns: 39 | Tensor: Normalized image. 
40 | """ 41 | for t, m, s in zip(tensor, self.mean, self.std): 42 | t.mul_(s).add_(m) 43 | # The normalize code -> t.sub_(m).div_(s) 44 | return tensor 45 | 46 | 47 | # This is simple maximum entropy normalization performed in Inception paper 48 | inception_normalize = transforms.Compose( 49 | [transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])] 50 | ) 51 | 52 | # ViT uses simple non-biased inception normalization 53 | # https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py#L132 54 | inception_unnormalize = transforms.Compose( 55 | [UnNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])] 56 | ) 57 | 58 | # ImageNet normalize 59 | imagenet_normalize = transforms.Compose( 60 | [transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])] 61 | ) 62 | 63 | imagenet_unnormalize = transforms.Compose( 64 | [UnNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])] 65 | ) 66 | -------------------------------------------------------------------------------- /prepro/OpenKE/openke/module/model/SimplE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .Model import Model 4 | 5 | class SimplE(Model): 6 | 7 | def __init__(self, ent_tot, rel_tot, dim = 100): 8 | super(SimplE, self).__init__(ent_tot, rel_tot) 9 | 10 | self.dim = dim 11 | self.ent_embeddings = nn.Embedding(self.ent_tot, self.dim) 12 | self.rel_embeddings = nn.Embedding(self.rel_tot, self.dim) 13 | self.rel_inv_embeddings = nn.Embedding(self.rel_tot, self.dim) 14 | 15 | nn.init.xavier_uniform_(self.ent_embeddings.weight.data) 16 | nn.init.xavier_uniform_(self.rel_embeddings.weight.data) 17 | nn.init.xavier_uniform_(self.rel_inv_embeddings.weight.data) 18 | 19 | def _calc_avg(self, h, t, r, r_inv): 20 | return (torch.sum(h * r * t, -1) + torch.sum(h * r_inv * t, -1))/2 21 | 22 | def _calc_ingr(self, h, r, t): 23 | return torch.sum(h * r * t, -1) 24 | 25 | 
def forward(self, data): 26 | batch_h = data['batch_h'] 27 | batch_t = data['batch_t'] 28 | batch_r = data['batch_r'] 29 | h = self.ent_embeddings(batch_h) 30 | t = self.ent_embeddings(batch_t) 31 | r = self.rel_embeddings(batch_r) 32 | r_inv = self.rel_inv_embeddings(batch_r) 33 | score = self._calc_avg(h, t, r, r_inv) 34 | return score 35 | 36 | def regularization(self, data): 37 | batch_h = data['batch_h'] 38 | batch_t = data['batch_t'] 39 | batch_r = data['batch_r'] 40 | h = self.ent_embeddings(batch_h) 41 | t = self.ent_embeddings(batch_t) 42 | r = self.rel_embeddings(batch_r) 43 | r_inv = self.rel_inv_embeddings(batch_r) 44 | regul = (torch.mean(h ** 2) + torch.mean(t ** 2) + torch.mean(r ** 2) + torch.mean(r_inv ** 2)) / 4 45 | return regul 46 | 47 | def predict(self, data): 48 | batch_h = data['batch_h'] 49 | batch_t = data['batch_t'] 50 | batch_r = data['batch_r'] 51 | h = self.ent_embeddings(batch_h) 52 | t = self.ent_embeddings(batch_t) 53 | r = self.rel_embeddings(batch_r) 54 | score = -self._calc_ingr(h, r, t) 55 | return score.cpu().data.numpy() -------------------------------------------------------------------------------- /prepro/OpenKE/openke/module/model/DistMult.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .Model import Model 4 | 5 | class DistMult(Model): 6 | 7 | def __init__(self, ent_tot, rel_tot, dim = 100, margin = None, epsilon = None): 8 | super(DistMult, self).__init__(ent_tot, rel_tot) 9 | 10 | self.dim = dim 11 | self.margin = margin 12 | self.epsilon = epsilon 13 | self.ent_embeddings = nn.Embedding(self.ent_tot, self.dim) 14 | self.rel_embeddings = nn.Embedding(self.rel_tot, self.dim) 15 | 16 | if margin == None or epsilon == None: 17 | nn.init.xavier_uniform_(self.ent_embeddings.weight.data) 18 | nn.init.xavier_uniform_(self.rel_embeddings.weight.data) 19 | else: 20 | self.embedding_range = nn.Parameter( 21 | torch.Tensor([(self.margin + 
self.epsilon) / self.dim]), requires_grad=False 22 | ) 23 | nn.init.uniform_( 24 | tensor = self.ent_embeddings.weight.data, 25 | a = -self.embedding_range.item(), 26 | b = self.embedding_range.item() 27 | ) 28 | nn.init.uniform_( 29 | tensor = self.rel_embeddings.weight.data, 30 | a= -self.embedding_range.item(), 31 | b= self.embedding_range.item() 32 | ) 33 | 34 | def _calc(self, h, t, r, mode): 35 | if mode != 'normal': 36 | h = h.view(-1, r.shape[0], h.shape[-1]) 37 | t = t.view(-1, r.shape[0], t.shape[-1]) 38 | r = r.view(-1, r.shape[0], r.shape[-1]) 39 | if mode == 'head_batch': 40 | score = h * (r * t) 41 | else: 42 | score = (h * r) * t 43 | score = torch.sum(score, -1).flatten() 44 | return score 45 | 46 | def forward(self, data): 47 | batch_h = data['batch_h'] 48 | batch_t = data['batch_t'] 49 | batch_r = data['batch_r'] 50 | mode = data['mode'] 51 | h = self.ent_embeddings(batch_h) 52 | t = self.ent_embeddings(batch_t) 53 | r = self.rel_embeddings(batch_r) 54 | score = self._calc(h ,t, r, mode) 55 | return score 56 | 57 | def regularization(self, data): 58 | batch_h = data['batch_h'] 59 | batch_t = data['batch_t'] 60 | batch_r = data['batch_r'] 61 | h = self.ent_embeddings(batch_h) 62 | t = self.ent_embeddings(batch_t) 63 | r = self.rel_embeddings(batch_r) 64 | regul = (torch.mean(h ** 2) + torch.mean(t ** 2) + torch.mean(r ** 2)) / 3 65 | return regul 66 | 67 | def l3_regularization(self): 68 | return (self.ent_embeddings.weight.norm(p = 3)**3 + self.rel_embeddings.weight.norm(p = 3)**3) 69 | 70 | def predict(self, data): 71 | score = -self.forward(data) 72 | return score.cpu().data.numpy() 73 | -------------------------------------------------------------------------------- /run_scripts/finetune_arl.sh: -------------------------------------------------------------------------------- 1 | seed=42 2 | num_gpus=1 3 | per_gpu_batchsize=16 4 | finetune_embeddings=True 5 | load_path= 6 | 7 | # === VQA === 8 | python main.py with seed=${seed} 
data_root=data/finetune_arrows_umls/ \
    num_gpus=${num_gpus} num_nodes=1 \
    task_finetune_vqa_vqa_rad \
    per_gpu_batchsize=${per_gpu_batchsize} \
    clip16 text_roberta \
    image_size=384 \
    tokenizer=downloaded/roberta-base \
    finetune_embeddings=${finetune_embeddings} \
    load_path=${load_path}

python main.py with seed=${seed} data_root=data/finetune_arrows_umls/ \
    num_gpus=${num_gpus} num_nodes=1 \
    task_finetune_vqa_slack \
    per_gpu_batchsize=${per_gpu_batchsize} \
    clip16 text_roberta \
    image_size=384 \
    tokenizer=downloaded/roberta-base \
    finetune_embeddings=${finetune_embeddings} \
    load_path=${load_path} \
    clip_resizedcrop

python main.py with seed=${seed} data_root=data/finetune_arrows_umls/ \
    num_gpus=${num_gpus} num_nodes=1 \
    task_finetune_vqa_medvqa_2019 \
    per_gpu_batchsize=${per_gpu_batchsize} \
    clip16 text_roberta \
    image_size=384 \
    tokenizer=downloaded/roberta-base \
    finetune_embeddings=${finetune_embeddings} \
    load_path=${load_path} \
    clip_resizedcrop

# === CLS ===
python main.py with seed=${seed} data_root=data/finetune_arrows_umls/ \
    num_gpus=${num_gpus} num_nodes=1 \
    task_finetune_cls_melinda_p_meth_label \
    per_gpu_batchsize=${per_gpu_batchsize} \
    clip16 text_roberta \
    image_size=384 \
    tokenizer=downloaded/roberta-base \
    finetune_embeddings=${finetune_embeddings} \
    load_path=${load_path} \
    clip_resizedcrop

# === IRTR ===
# FIX: argument was corrupted as "pwdper_gpu_batchsize=..." (a stray "pwd"
# fused into the option name), so the batch size was silently ignored.
python main.py with data_root=data/finetune_arrows_umls/ \
    num_gpus=${num_gpus} num_nodes=1 \
    task_finetune_irtr_roco get_recall_metric=True \
    per_gpu_batchsize=${per_gpu_batchsize} \
    clip16 text_roberta \
    image_size=288 \
    tokenizer=downloaded/roberta-base \
    finetune_embeddings=${finetune_embeddings} \
    test_only=True \
    load_path=${load_path}

python main.py with
data_root=data/finetune_arrows_umls/ \ 65 | num_gpus=${num_gpus} num_nodes=1 \ 66 | task_finetune_irtr_roco get_recall_metric=False \ 67 | per_gpu_batchsize=${per_gpu_batchsize} \ 68 | clip16 text_roberta \ 69 | image_size=384 \ 70 | tokenizer=downloaded/roberta-base \ 71 | finetune_embeddings=${finetune_embeddings} \ 72 | load_path=${load_path} \ 73 | clip_resizedcrop 74 | -------------------------------------------------------------------------------- /prepro/OpenKE/openke/module/model/ComplEx.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .Model import Model 4 | 5 | class ComplEx(Model): 6 | def __init__(self, ent_tot, rel_tot, dim = 100): 7 | super(ComplEx, self).__init__(ent_tot, rel_tot) 8 | 9 | self.dim = dim 10 | self.ent_re_embeddings = nn.Embedding(self.ent_tot, self.dim) 11 | self.ent_im_embeddings = nn.Embedding(self.ent_tot, self.dim) 12 | self.rel_re_embeddings = nn.Embedding(self.rel_tot, self.dim) 13 | self.rel_im_embeddings = nn.Embedding(self.rel_tot, self.dim) 14 | 15 | nn.init.xavier_uniform_(self.ent_re_embeddings.weight.data) 16 | nn.init.xavier_uniform_(self.ent_im_embeddings.weight.data) 17 | nn.init.xavier_uniform_(self.rel_re_embeddings.weight.data) 18 | nn.init.xavier_uniform_(self.rel_im_embeddings.weight.data) 19 | 20 | def _calc(self, h_re, h_im, t_re, t_im, r_re, r_im): 21 | return torch.sum( 22 | h_re * t_re * r_re 23 | + h_im * t_im * r_re 24 | + h_re * t_im * r_im 25 | - h_im * t_re * r_im, 26 | -1 27 | ) 28 | 29 | def forward(self, data): 30 | batch_h = data['batch_h'] 31 | batch_t = data['batch_t'] 32 | batch_r = data['batch_r'] 33 | h_re = self.ent_re_embeddings(batch_h) 34 | h_im = self.ent_im_embeddings(batch_h) 35 | t_re = self.ent_re_embeddings(batch_t) 36 | t_im = self.ent_im_embeddings(batch_t) 37 | r_re = self.rel_re_embeddings(batch_r) 38 | r_im = self.rel_im_embeddings(batch_r) 39 | score = self._calc(h_re, h_im, t_re, t_im, r_re, 
r_im) 40 | return score 41 | 42 | def regularization(self, data): 43 | batch_h = data['batch_h'] 44 | batch_t = data['batch_t'] 45 | batch_r = data['batch_r'] 46 | h_re = self.ent_re_embeddings(batch_h) 47 | h_im = self.ent_im_embeddings(batch_h) 48 | t_re = self.ent_re_embeddings(batch_t) 49 | t_im = self.ent_im_embeddings(batch_t) 50 | r_re = self.rel_re_embeddings(batch_r) 51 | r_im = self.rel_im_embeddings(batch_r) 52 | regul = (torch.mean(h_re ** 2) + 53 | torch.mean(h_im ** 2) + 54 | torch.mean(t_re ** 2) + 55 | torch.mean(t_im ** 2) + 56 | torch.mean(r_re ** 2) + 57 | torch.mean(r_im ** 2)) / 6 58 | return regul 59 | 60 | def predict(self, data): 61 | score = -self.forward(data) 62 | return score.cpu().data.numpy() -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | simcse/__pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | .DS_Store 132 | .vscode 133 | data 134 | result 135 | 136 | # Pycharm 137 | .idea/ 138 | downloaded/ 139 | -------------------------------------------------------------------------------- /arl/datasets/pretraining_mimic_cxr_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections import defaultdict 3 | 4 | from .base_dataset import BaseDataset 5 | from .utils import record_ent_ref 6 | 7 | 8 | class MIMICCXRDataset(BaseDataset): 9 | def __init__(self, *args, split="", **kwargs): 10 | assert split in ["train", "val", "test"] 11 | self.split = split 12 | 13 | if split == "train": 14 | names = ["mimic_cxr_train"] 15 | elif split == "val": 16 | names = ["mimic_cxr_val"] 17 | elif split == "test": 18 | names = ["mimic_cxr_test"] 19 | else: 20 | raise ValueError 21 | 22 | super().__init__(*args, **kwargs, names=names, text_column_name="caption") 23 | self.chexpert_labels = self.table["chexpert"].to_pandas().tolist() 24 | dup_indices = [self.index_mapper[i][0] for i in self.index_mapper] 25 | self.chexpert_labels = [self.chexpert_labels[idx] for idx in dup_indices] 26 | self.group_mappings = defaultdict(set) 27 | for idx, label in enumerate(self.chexpert_labels): 28 | self.group_mappings[str(label)].add(idx) 29 | full_index_set = set(list(range(len(self.chexpert_labels)))) 30 | for k in self.group_mappings: 31 | 
self.group_mappings[k] = full_index_set - self.group_mappings[k] 32 | 33 | def get_false_image(self, rep, image_key="image", selected_index=None): 34 | chexpert_label = str(self.chexpert_labels[selected_index]) 35 | candidate_index = self.group_mappings[chexpert_label] 36 | 37 | random_index = random.sample(candidate_index, 1)[0] 38 | image = self.get_raw_image(random_index, image_key=image_key) 39 | image_tensor = [tr(image) for tr in self.transforms] 40 | return {f"false_image_{rep}": image_tensor} 41 | 42 | def get_false_text(self, rep, selected_index=None): 43 | chexpert_label = str(self.chexpert_labels[selected_index]) 44 | candidate_index = self.group_mappings[chexpert_label] 45 | 46 | random_index = random.sample(candidate_index, 1)[0] 47 | index, caption_index = self.index_mapper[random_index] 48 | text = self.all_texts[index][caption_index] 49 | encoding = self.tokenizer( 50 | text, 51 | padding="max_length", 52 | truncation=True, 53 | max_length=self.max_text_len, 54 | return_special_tokens_mask=True, 55 | return_offsets_mapping=True, 56 | ) 57 | encoding = record_ent_ref(encoding, self.all_txt_ents[index][caption_index]) 58 | return {f"false_text_{rep}": (text, encoding)} 59 | 60 | def __getitem__(self, index): 61 | return self.get_suite(index) 62 | -------------------------------------------------------------------------------- /arl/transforms/transform.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from torchvision import transforms 3 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, RandomResizedCrop 4 | 5 | from .randaug import RandAugment 6 | from .utils import ( 7 | inception_normalize, 8 | imagenet_normalize, 9 | ) 10 | 11 | 12 | def imagenet_transform(size=800): 13 | return transforms.Compose( 14 | [ 15 | Resize(size, interpolation=Image.BICUBIC), 16 | CenterCrop(size), 17 | transforms.ToTensor(), 18 | imagenet_normalize, 19 | ] 20 | ) 21 | 22 
| 23 | def imagenet_transform_randaug(size=800): 24 | trs = transforms.Compose( 25 | [ 26 | Resize(size, interpolation=Image.BICUBIC), 27 | CenterCrop(size), 28 | transforms.ToTensor(), 29 | imagenet_normalize, 30 | ] 31 | ) 32 | trs.transforms.insert(0, RandAugment(2, 9)) 33 | return trs 34 | 35 | 36 | def vit_transform(size=800): 37 | return transforms.Compose( 38 | [ 39 | Resize(size, interpolation=Image.BICUBIC), 40 | CenterCrop(size), 41 | transforms.ToTensor(), 42 | inception_normalize, 43 | ] 44 | ) 45 | 46 | 47 | def vit_transform_randaug(size=800): 48 | trs = transforms.Compose( 49 | [ 50 | Resize(size, interpolation=Image.BICUBIC), 51 | CenterCrop(size), 52 | transforms.ToTensor(), 53 | inception_normalize, 54 | ] 55 | ) 56 | trs.transforms.insert(0, RandAugment(2, 9)) 57 | return trs 58 | 59 | 60 | def clip_transform(size): 61 | return Compose([ 62 | Resize(size, interpolation=Image.BICUBIC), 63 | CenterCrop(size), 64 | lambda image: image.convert("RGB"), 65 | ToTensor(), 66 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 67 | ]) 68 | 69 | 70 | def clip_transform_resizedcrop(size): 71 | return Compose([ 72 | RandomResizedCrop(size, scale=(0.9, 1.0), interpolation=Image.BICUBIC), 73 | CenterCrop(size), 74 | lambda image: image.convert("RGB"), 75 | ToTensor(), 76 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 77 | ]) 78 | 79 | 80 | def clip_transform_randaug(size): 81 | trs = Compose([ 82 | Resize(size, interpolation=Image.BICUBIC), 83 | CenterCrop(size), 84 | lambda image: image.convert("RGB"), 85 | ToTensor(), 86 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 87 | ]) 88 | trs.transforms.insert(0, lambda image: image.convert('RGBA')) 89 | trs.transforms.insert(0, RandAugment(2, 9)) 90 | trs.transforms.insert(0, lambda image: image.convert('RGB')) 91 | return trs 92 | 
-------------------------------------------------------------------------------- /prepro/OpenKE/openke/module/model/Analogy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .Model import Model 5 | 6 | class Analogy(Model): 7 | 8 | def __init__(self, ent_tot, rel_tot, dim = 100): 9 | super(Analogy, self).__init__(ent_tot, rel_tot) 10 | 11 | self.dim = dim 12 | self.ent_re_embeddings = nn.Embedding(self.ent_tot, self.dim) 13 | self.ent_im_embeddings = nn.Embedding(self.ent_tot, self.dim) 14 | self.rel_re_embeddings = nn.Embedding(self.rel_tot, self.dim) 15 | self.rel_im_embeddings = nn.Embedding(self.rel_tot, self.dim) 16 | self.ent_embeddings = nn.Embedding(self.ent_tot, self.dim * 2) 17 | self.rel_embeddings = nn.Embedding(self.rel_tot, self.dim * 2) 18 | 19 | nn.init.xavier_uniform_(self.ent_re_embeddings.weight.data) 20 | nn.init.xavier_uniform_(self.ent_im_embeddings.weight.data) 21 | nn.init.xavier_uniform_(self.rel_re_embeddings.weight.data) 22 | nn.init.xavier_uniform_(self.rel_im_embeddings.weight.data) 23 | nn.init.xavier_uniform_(self.ent_embeddings.weight.data) 24 | nn.init.xavier_uniform_(self.rel_embeddings.weight.data) 25 | 26 | def _calc(self, h_re, h_im, h, t_re, t_im, t, r_re, r_im, r): 27 | return (-torch.sum(r_re * h_re * t_re + 28 | r_re * h_im * t_im + 29 | r_im * h_re * t_im - 30 | r_im * h_im * t_re, -1) 31 | -torch.sum(h * t * r, -1)) 32 | 33 | def forward(self, data): 34 | batch_h = data['batch_h'] 35 | batch_t = data['batch_t'] 36 | batch_r = data['batch_r'] 37 | h_re = self.ent_re_embeddings(batch_h) 38 | h_im = self.ent_im_embeddings(batch_h) 39 | h = self.ent_embeddings(batch_h) 40 | t_re = self.ent_re_embeddings(batch_t) 41 | t_im = self.ent_im_embeddings(batch_t) 42 | t = self.ent_embeddings(batch_t) 43 | r_re = self.rel_re_embeddings(batch_r) 44 | r_im = self.rel_im_embeddings(batch_r) 45 | r = 
self.rel_embeddings(batch_r) 46 | score = self._calc(h_re, h_im, h, t_re, t_im, t, r_re, r_im, r) 47 | return score 48 | 49 | def regularization(self, data): 50 | batch_h = data['batch_h'] 51 | batch_t = data['batch_t'] 52 | batch_r = data['batch_r'] 53 | h_re = self.ent_re_embeddings(batch_h) 54 | h_im = self.ent_im_embeddings(batch_h) 55 | h = self.ent_embeddings(batch_h) 56 | t_re = self.ent_re_embeddings(batch_t) 57 | t_im = self.ent_im_embeddings(batch_t) 58 | t = self.ent_embeddings(batch_t) 59 | r_re = self.rel_re_embeddings(batch_r) 60 | r_im = self.rel_im_embeddings(batch_r) 61 | r = self.rel_embeddings(batch_r) 62 | regul = (torch.mean(h_re ** 2) + 63 | torch.mean(h_im ** 2) + 64 | torch.mean(h ** 2) + 65 | torch.mean(t_re ** 2) + 66 | torch.mean(t_im ** 2) + 67 | torch.mean(t ** 2) + 68 | torch.mean(r_re ** 2) + 69 | torch.mean(r_im ** 2) + 70 | torch.mean(r ** 2)) / 9 71 | return regul 72 | 73 | def predict(self, data): 74 | score = -self.forward(data) 75 | return score.cpu().data.numpy() -------------------------------------------------------------------------------- /prepro/OpenKE/openke/module/model/TransE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .Model import Model 5 | 6 | class TransE(Model): 7 | 8 | def __init__(self, ent_tot, rel_tot, dim = 100, p_norm = 1, norm_flag = True, margin = None, epsilon = None): 9 | super(TransE, self).__init__(ent_tot, rel_tot) 10 | 11 | self.dim = dim 12 | self.margin = margin 13 | self.epsilon = epsilon 14 | self.norm_flag = norm_flag 15 | self.p_norm = p_norm 16 | 17 | self.ent_embeddings = nn.Embedding(self.ent_tot, self.dim) 18 | self.rel_embeddings = nn.Embedding(self.rel_tot, self.dim) 19 | 20 | if margin == None or epsilon == None: 21 | nn.init.xavier_uniform_(self.ent_embeddings.weight.data) 22 | nn.init.xavier_uniform_(self.rel_embeddings.weight.data) 23 | else: 24 | 
self.embedding_range = nn.Parameter( 25 | torch.Tensor([(self.margin + self.epsilon) / self.dim]), requires_grad=False 26 | ) 27 | nn.init.uniform_( 28 | tensor = self.ent_embeddings.weight.data, 29 | a = -self.embedding_range.item(), 30 | b = self.embedding_range.item() 31 | ) 32 | nn.init.uniform_( 33 | tensor = self.rel_embeddings.weight.data, 34 | a= -self.embedding_range.item(), 35 | b= self.embedding_range.item() 36 | ) 37 | 38 | if margin != None: 39 | self.margin = nn.Parameter(torch.Tensor([margin])) 40 | self.margin.requires_grad = False 41 | self.margin_flag = True 42 | else: 43 | self.margin_flag = False 44 | 45 | 46 | def _calc(self, h, t, r, mode): 47 | if self.norm_flag: 48 | h = F.normalize(h, 2, -1) 49 | r = F.normalize(r, 2, -1) 50 | t = F.normalize(t, 2, -1) 51 | if mode != 'normal': 52 | h = h.view(-1, r.shape[0], h.shape[-1]) 53 | t = t.view(-1, r.shape[0], t.shape[-1]) 54 | r = r.view(-1, r.shape[0], r.shape[-1]) 55 | if mode == 'head_batch': 56 | score = h + (r - t) 57 | else: 58 | score = (h + r) - t 59 | score = torch.norm(score, self.p_norm, -1).flatten() 60 | return score 61 | 62 | def forward(self, data): 63 | batch_h = data['batch_h'] 64 | batch_t = data['batch_t'] 65 | batch_r = data['batch_r'] 66 | mode = data['mode'] 67 | h = self.ent_embeddings(batch_h) 68 | t = self.ent_embeddings(batch_t) 69 | r = self.rel_embeddings(batch_r) 70 | score = self._calc(h ,t, r, mode) 71 | if self.margin_flag: 72 | return self.margin - score 73 | else: 74 | return score 75 | 76 | def regularization(self, data): 77 | batch_h = data['batch_h'] 78 | batch_t = data['batch_t'] 79 | batch_r = data['batch_r'] 80 | h = self.ent_embeddings(batch_h) 81 | t = self.ent_embeddings(batch_t) 82 | r = self.rel_embeddings(batch_r) 83 | regul = (torch.mean(h ** 2) + 84 | torch.mean(t ** 2) + 85 | torch.mean(r ** 2)) / 3 86 | return regul 87 | 88 | def predict(self, data): 89 | score = self.forward(data) 90 | if self.margin_flag: 91 | score = self.margin - score 92 | 
return score.cpu().data.numpy() 93 | else: 94 | return score.cpu().data.numpy() -------------------------------------------------------------------------------- /arl/datamodules/multitask_datamodule.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | from pytorch_lightning import LightningDataModule 4 | from torch.utils.data import DataLoader 5 | from torch.utils.data.dataset import ConcatDataset 6 | from torch.utils.data.distributed import DistributedSampler 7 | 8 | from . import _datamodules 9 | 10 | 11 | class MTDataModule(LightningDataModule): 12 | def __init__(self, _config, dist=False): 13 | datamodule_keys = _config["datasets"] 14 | assert len(datamodule_keys) > 0 15 | 16 | super().__init__() 17 | 18 | self.dm_keys = datamodule_keys 19 | self.dm_dicts = {key: _datamodules[key](_config) for key in datamodule_keys} 20 | self.dms = [v for k, v in self.dm_dicts.items()] 21 | 22 | self.batch_size = self.dms[0].batch_size 23 | self.vocab_size = self.dms[0].vocab_size 24 | self.num_workers = self.dms[0].num_workers 25 | 26 | self.dist = dist 27 | 28 | def prepare_data(self): 29 | for dm in self.dms: 30 | dm.prepare_data() 31 | 32 | def setup(self, stage): 33 | for dm in self.dms: 34 | dm.setup(stage) 35 | 36 | self.train_dataset = ConcatDataset([dm.train_dataset for dm in self.dms]) 37 | self.val_dataset = ConcatDataset([dm.val_dataset for dm in self.dms]) 38 | self.test_dataset = ConcatDataset([dm.test_dataset for dm in self.dms]) 39 | self.tokenizer = self.dms[0].tokenizer 40 | self.collate = functools.partial(self.dms[0].train_dataset.collate, mlm_collator=self.dms[0].mlm_collator) 41 | 42 | if self.dist: 43 | self.train_sampler = DistributedSampler(self.train_dataset, shuffle=True) 44 | self.val_sampler = DistributedSampler(self.val_dataset, shuffle=True) 45 | self.test_sampler = DistributedSampler(self.test_dataset, shuffle=False) 46 | else: 47 | self.train_sampler = None 48 | self.val_sampler = 
None 49 | self.test_sampler = None 50 | 51 | def train_dataloader(self): 52 | loader = DataLoader( 53 | self.train_dataset, 54 | batch_size=self.batch_size, 55 | sampler=self.train_sampler, 56 | num_workers=self.num_workers, 57 | collate_fn=self.collate, 58 | ) 59 | return loader 60 | 61 | def val_dataloader(self, batch_size=None): 62 | loader = DataLoader( 63 | self.val_dataset, 64 | batch_size=batch_size if batch_size is not None else self.batch_size, 65 | sampler=self.val_sampler, 66 | num_workers=self.num_workers, 67 | collate_fn=self.collate, 68 | ) 69 | return loader 70 | 71 | def test_dataloader(self): 72 | loader = DataLoader( 73 | self.test_dataset, 74 | batch_size=self.batch_size, 75 | sampler=self.test_sampler, 76 | num_workers=self.num_workers, 77 | collate_fn=self.collate, 78 | ) 79 | return loader 80 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import resource 4 | 5 | import pytorch_lightning as pl 6 | 7 | from arl.config import ex 8 | from arl.datamodules.multitask_datamodule import MTDataModule 9 | from arl.modules import ARLTransformerSS 10 | 11 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 12 | resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1])) 13 | 14 | 15 | @ex.automain 16 | def main(_config): 17 | _config = copy.deepcopy(_config) 18 | pl.seed_everything(_config["seed"]) 19 | 20 | # Data modules 21 | dm = MTDataModule(_config, dist=True) 22 | 23 | # Module 24 | model = ARLTransformerSS(_config) 25 | 26 | # Loggers 27 | os.makedirs(_config["log_dir"], exist_ok=True) 28 | exp_name = f'{_config["exp_name"]}' 29 | run_name = f'{exp_name}-seed{_config["seed"]}-from_{_config["load_path"].replace("/", "_")}' 30 | tb_logger = pl.loggers.TensorBoardLogger(_config["log_dir"], name=run_name) 31 | wb_logger = pl.loggers.WandbLogger( 32 | project="ACMMM-2022-ARL-finetune" if 
"finetune" in exp_name else "ACMMM-2022-ARL-pretrain", 33 | name=run_name 34 | ) 35 | loggers = [tb_logger, wb_logger] 36 | 37 | # Callback 38 | checkpoint_callback = pl.callbacks.ModelCheckpoint( 39 | save_top_k=1, 40 | verbose=True, 41 | monitor="val/the_metric", 42 | mode="max", 43 | save_last=True, 44 | save_weights_only=True if "finetune" in exp_name else False 45 | ) 46 | lr_callback = pl.callbacks.LearningRateMonitor(logging_interval="step") 47 | callbacks = [checkpoint_callback, lr_callback] 48 | 49 | # Training Hyper-Parameters 50 | num_gpus = (_config["num_gpus"] if isinstance(_config["num_gpus"], int) else len(_config["num_gpus"])) 51 | grad_steps = max(_config["batch_size"] // (_config["per_gpu_batchsize"] * num_gpus * _config["num_nodes"]), 1) 52 | max_steps = _config["max_steps"] if _config["max_steps"] is not None else None 53 | max_epochs = _config["max_epoch"] if max_steps is None else 1000 54 | 55 | # Trainer 56 | trainer = pl.Trainer( 57 | gpus=num_gpus, 58 | num_nodes=_config["num_nodes"], 59 | precision=_config["precision"], 60 | accelerator="ddp", 61 | benchmark=True, 62 | deterministic=True, 63 | max_epochs=max_epochs, 64 | max_steps=max_steps, 65 | callbacks=callbacks, 66 | logger=loggers, 67 | prepare_data_per_node=False, 68 | replace_sampler_ddp=False, 69 | accumulate_grad_batches=grad_steps, 70 | log_every_n_steps=10, 71 | flush_logs_every_n_steps=10, 72 | resume_from_checkpoint=_config["resume_from"], 73 | weights_summary="top", 74 | fast_dev_run=_config["fast_dev_run"], 75 | val_check_interval=_config["val_check_interval"], 76 | default_root_dir=_config["default_root_dir"] 77 | ) 78 | 79 | if not _config["test_only"]: 80 | trainer.fit(model, datamodule=dm) 81 | if "pretrain" not in _config["exp_name"]: 82 | trainer.test(ckpt_path="best" if "irtr" not in _config["exp_name"] else None, datamodule=dm) 83 | else: 84 | trainer.test(model, datamodule=dm) 85 | -------------------------------------------------------------------------------- 
/prepro/OpenKE/openke/base/Setting.h: -------------------------------------------------------------------------------- 1 | #ifndef SETTING_H 2 | #define SETTING_H 3 | #define INT long 4 | #define REAL float 5 | #include 6 | #include 7 | #include 8 | 9 | std::string inPath = "../data/FB15K/"; 10 | std::string outPath = "../data/FB15K/"; 11 | std::string ent_file = ""; 12 | std::string rel_file = ""; 13 | std::string train_file = ""; 14 | std::string valid_file = ""; 15 | std::string test_file = ""; 16 | 17 | extern "C" 18 | void setInPath(char *path) { 19 | INT len = strlen(path); 20 | inPath = ""; 21 | for (INT i = 0; i < len; i++) 22 | inPath = inPath + path[i]; 23 | printf("Input Files Path : %s\n", inPath.c_str()); 24 | } 25 | 26 | extern "C" 27 | void setOutPath(char *path) { 28 | INT len = strlen(path); 29 | outPath = ""; 30 | for (INT i = 0; i < len; i++) 31 | outPath = outPath + path[i]; 32 | printf("Output Files Path : %s\n", outPath.c_str()); 33 | } 34 | 35 | extern "C" 36 | void setTrainPath(char *path) { 37 | INT len = strlen(path); 38 | train_file = ""; 39 | for (INT i = 0; i < len; i++) 40 | train_file = train_file + path[i]; 41 | printf("Training Files Path : %s\n", train_file.c_str()); 42 | } 43 | 44 | extern "C" 45 | void setValidPath(char *path) { 46 | INT len = strlen(path); 47 | valid_file = ""; 48 | for (INT i = 0; i < len; i++) 49 | valid_file = valid_file + path[i]; 50 | printf("Valid Files Path : %s\n", valid_file.c_str()); 51 | } 52 | 53 | extern "C" 54 | void setTestPath(char *path) { 55 | INT len = strlen(path); 56 | test_file = ""; 57 | for (INT i = 0; i < len; i++) 58 | test_file = test_file + path[i]; 59 | printf("Test Files Path : %s\n", test_file.c_str()); 60 | } 61 | 62 | extern "C" 63 | void setEntPath(char *path) { 64 | INT len = strlen(path); 65 | ent_file = ""; 66 | for (INT i = 0; i < len; i++) 67 | ent_file = ent_file + path[i]; 68 | printf("Entity Files Path : %s\n", ent_file.c_str()); 69 | } 70 | 71 | extern "C" 72 | void 
setRelPath(char *path) { 73 | INT len = strlen(path); 74 | rel_file = ""; 75 | for (INT i = 0; i < len; i++) 76 | rel_file = rel_file + path[i]; 77 | printf("Relation Files Path : %s\n", rel_file.c_str()); 78 | } 79 | 80 | /* 81 | ============================================================ 82 | */ 83 | 84 | INT workThreads = 1; 85 | 86 | extern "C" 87 | void setWorkThreads(INT threads) { 88 | workThreads = threads; 89 | } 90 | 91 | extern "C" 92 | INT getWorkThreads() { 93 | return workThreads; 94 | } 95 | 96 | /* 97 | ============================================================ 98 | */ 99 | 100 | INT relationTotal = 0; 101 | INT entityTotal = 0; 102 | INT tripleTotal = 0; 103 | INT testTotal = 0; 104 | INT trainTotal = 0; 105 | INT validTotal = 0; 106 | 107 | extern "C" 108 | INT getEntityTotal() { 109 | return entityTotal; 110 | } 111 | 112 | extern "C" 113 | INT getRelationTotal() { 114 | return relationTotal; 115 | } 116 | 117 | extern "C" 118 | INT getTripleTotal() { 119 | return tripleTotal; 120 | } 121 | 122 | extern "C" 123 | INT getTrainTotal() { 124 | return trainTotal; 125 | } 126 | 127 | extern "C" 128 | INT getTestTotal() { 129 | return testTotal; 130 | } 131 | 132 | extern "C" 133 | INT getValidTotal() { 134 | return validTotal; 135 | } 136 | /* 137 | ============================================================ 138 | */ 139 | 140 | INT bernFlag = 0; 141 | 142 | extern "C" 143 | void setBern(INT con) { 144 | bernFlag = con; 145 | } 146 | 147 | #endif 148 | -------------------------------------------------------------------------------- /prepro/OpenKE/openke/module/model/TransR.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .Model import Model 5 | 6 | class TransR(Model): 7 | 8 | def __init__(self, ent_tot, rel_tot, dim_e = 100, dim_r = 100, p_norm = 1, norm_flag = True, rand_init = False, margin = None): 9 | super(TransR, 
self).__init__(ent_tot, rel_tot) 10 | 11 | self.dim_e = dim_e 12 | self.dim_r = dim_r 13 | self.norm_flag = norm_flag 14 | self.p_norm = p_norm 15 | self.rand_init = rand_init 16 | 17 | self.ent_embeddings = nn.Embedding(self.ent_tot, self.dim_e) 18 | self.rel_embeddings = nn.Embedding(self.rel_tot, self.dim_r) 19 | nn.init.xavier_uniform_(self.ent_embeddings.weight.data) 20 | nn.init.xavier_uniform_(self.rel_embeddings.weight.data) 21 | 22 | self.transfer_matrix = nn.Embedding(self.rel_tot, self.dim_e * self.dim_r) 23 | if not self.rand_init: 24 | identity = torch.zeros(self.dim_e, self.dim_r) 25 | for i in range(min(self.dim_e, self.dim_r)): 26 | identity[i][i] = 1 27 | identity = identity.view(self.dim_r * self.dim_e) 28 | for i in range(self.rel_tot): 29 | self.transfer_matrix.weight.data[i] = identity 30 | else: 31 | nn.init.xavier_uniform_(self.transfer_matrix.weight.data) 32 | 33 | if margin != None: 34 | self.margin = nn.Parameter(torch.Tensor([margin])) 35 | self.margin.requires_grad = False 36 | self.margin_flag = True 37 | else: 38 | self.margin_flag = False 39 | 40 | def _calc(self, h, t, r, mode): 41 | if self.norm_flag: 42 | h = F.normalize(h, 2, -1) 43 | r = F.normalize(r, 2, -1) 44 | t = F.normalize(t, 2, -1) 45 | if mode != 'normal': 46 | h = h.view(-1, r.shape[0], h.shape[-1]) 47 | t = t.view(-1, r.shape[0], t.shape[-1]) 48 | r = r.view(-1, r.shape[0], r.shape[-1]) 49 | if mode == 'head_batch': 50 | score = h + (r - t) 51 | else: 52 | score = (h + r) - t 53 | score = torch.norm(score, self.p_norm, -1).flatten() 54 | return score 55 | 56 | def _transfer(self, e, r_transfer): 57 | r_transfer = r_transfer.view(-1, self.dim_e, self.dim_r) 58 | if e.shape[0] != r_transfer.shape[0]: 59 | e = e.view(-1, r_transfer.shape[0], self.dim_e).permute(1, 0, 2) 60 | e = torch.matmul(e, r_transfer).permute(1, 0, 2) 61 | else: 62 | e = e.view(-1, 1, self.dim_e) 63 | e = torch.matmul(e, r_transfer) 64 | return e.view(-1, self.dim_r) 65 | 66 | def forward(self, 
data): 67 | batch_h = data['batch_h'] 68 | batch_t = data['batch_t'] 69 | batch_r = data['batch_r'] 70 | mode = data['mode'] 71 | h = self.ent_embeddings(batch_h) 72 | t = self.ent_embeddings(batch_t) 73 | r = self.rel_embeddings(batch_r) 74 | r_transfer = self.transfer_matrix(batch_r) 75 | h = self._transfer(h, r_transfer) 76 | t = self._transfer(t, r_transfer) 77 | score = self._calc(h ,t, r, mode) 78 | if self.margin_flag: 79 | return self.margin - score 80 | else: 81 | return score 82 | 83 | def regularization(self, data): 84 | batch_h = data['batch_h'] 85 | batch_t = data['batch_t'] 86 | batch_r = data['batch_r'] 87 | h = self.ent_embeddings(batch_h) 88 | t = self.ent_embeddings(batch_t) 89 | r = self.rel_embeddings(batch_r) 90 | r_transfer = self.transfer_matrix(batch_r) 91 | regul = (torch.mean(h ** 2) + 92 | torch.mean(t ** 2) + 93 | torch.mean(r ** 2) + 94 | torch.mean(r_transfer ** 2)) / 4 95 | return regul * regul 96 | 97 | def predict(self, data): 98 | score = self.forward(data) 99 | if self.margin_flag: 100 | score = self.margin - score 101 | return score.cpu().data.numpy() 102 | else: 103 | return score.cpu().data.numpy() -------------------------------------------------------------------------------- /prepro/OpenKE/openke/module/model/HolE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .Model import Model 4 | import numpy 5 | from numpy import fft 6 | 7 | class HolE(Model): 8 | 9 | def __init__(self, ent_tot, rel_tot, dim = 100, margin = None, epsilon = None): 10 | super(HolE, self).__init__(ent_tot, rel_tot) 11 | 12 | self.dim = dim 13 | self.margin = margin 14 | self.epsilon = epsilon 15 | self.ent_embeddings = nn.Embedding(self.ent_tot, self.dim) 16 | self.rel_embeddings = nn.Embedding(self.rel_tot, self.dim) 17 | 18 | if margin == None or epsilon == None: 19 | nn.init.xavier_uniform_(self.ent_embeddings.weight.data) 20 | 
nn.init.xavier_uniform_(self.rel_embeddings.weight.data) 21 | else: 22 | self.embedding_range = nn.Parameter( 23 | torch.Tensor([(self.margin + self.epsilon) / self.dim]), requires_grad=False 24 | ) 25 | nn.init.uniform_( 26 | tensor = self.ent_embeddings.weight.data, 27 | a = -self.embedding_range.item(), 28 | b = self.embedding_range.item() 29 | ) 30 | nn.init.uniform_( 31 | tensor = self.rel_embeddings.weight.data, 32 | a= -self.embedding_range.item(), 33 | b= self.embedding_range.item() 34 | ) 35 | 36 | def _conj(self, tensor): 37 | zero_shape = (list)(tensor.shape) 38 | one_shape = (list)(tensor.shape) 39 | zero_shape[-1] = 1 40 | one_shape[-1] -= 1 41 | ze = torch.zeros(size = zero_shape, device = tensor.device) 42 | on = torch.ones(size = one_shape, device = tensor.device) 43 | matrix = torch.cat([ze, on], -1) 44 | matrix = 2 * matrix 45 | return tensor - matrix * tensor 46 | 47 | def _real(self, tensor): 48 | dimensions = len(tensor.shape) 49 | return tensor.narrow(dimensions - 1, 0, 1) 50 | 51 | def _imag(self, tensor): 52 | dimensions = len(tensor.shape) 53 | return tensor.narrow(dimensions - 1, 1, 1) 54 | 55 | def _mul(self, real_1, imag_1, real_2, imag_2): 56 | real = real_1 * real_2 - imag_1 * imag_2 57 | imag = real_1 * imag_2 + imag_1 * real_2 58 | return torch.cat([real, imag], -1) 59 | 60 | def _ccorr(self, a, b): 61 | a = self._conj(torch.rfft(a, signal_ndim = 1, onesided = False)) 62 | b = torch.rfft(b, signal_ndim = 1, onesided = False) 63 | res = self._mul(self._real(a), self._imag(a), self._real(b), self._imag(b)) 64 | res = torch.ifft(res, signal_ndim = 1) 65 | return self._real(res).flatten(start_dim = -2) 66 | 67 | def _calc(self, h, t, r, mode): 68 | if mode != 'normal': 69 | h = h.view(-1, r.shape[0], h.shape[-1]) 70 | t = t.view(-1, r.shape[0], t.shape[-1]) 71 | r = r.view(-1, r.shape[0], r.shape[-1]) 72 | score = self._ccorr(h, t) * r 73 | score = torch.sum(score, -1).flatten() 74 | return score 75 | 76 | def forward(self, data): 77 | 
batch_h = data['batch_h'] 78 | batch_t = data['batch_t'] 79 | batch_r = data['batch_r'] 80 | mode = data['mode'] 81 | h = self.ent_embeddings(batch_h) 82 | t = self.ent_embeddings(batch_t) 83 | r = self.rel_embeddings(batch_r) 84 | score = self._calc(h ,t, r, mode) 85 | return score 86 | 87 | def regularization(self, data): 88 | batch_h = data['batch_h'] 89 | batch_t = data['batch_t'] 90 | batch_r = data['batch_r'] 91 | h = self.ent_embeddings(batch_h) 92 | t = self.ent_embeddings(batch_t) 93 | r = self.rel_embeddings(batch_r) 94 | regul = (torch.mean(h ** 2) + torch.mean(t ** 2) + torch.mean(r ** 2)) / 3 95 | return regul 96 | 97 | def l3_regularization(self): 98 | return (self.ent_embeddings.weight.norm(p = 3)**3 + self.rel_embeddings.weight.norm(p = 3)**3) 99 | 100 | def predict(self, data): 101 | score = -self.forward(data) 102 | return score.cpu().data.numpy() 103 | -------------------------------------------------------------------------------- /prepro/OpenKE/openke/module/model/RotatE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.autograd as autograd 3 | import torch.nn as nn 4 | from .Model import Model 5 | 6 | class RotatE(Model): 7 | 8 | def __init__(self, ent_tot, rel_tot, dim = 100, margin = 6.0, epsilon = 2.0): 9 | super(RotatE, self).__init__(ent_tot, rel_tot) 10 | 11 | self.margin = margin 12 | self.epsilon = epsilon 13 | 14 | self.dim_e = dim * 2 15 | self.dim_r = dim 16 | 17 | self.ent_embeddings = nn.Embedding(self.ent_tot, self.dim_e) 18 | self.rel_embeddings = nn.Embedding(self.rel_tot, self.dim_r) 19 | 20 | self.ent_embedding_range = nn.Parameter( 21 | torch.Tensor([(self.margin + self.epsilon) / self.dim_e]), 22 | requires_grad=False 23 | ) 24 | 25 | nn.init.uniform_( 26 | tensor = self.ent_embeddings.weight.data, 27 | a=-self.ent_embedding_range.item(), 28 | b=self.ent_embedding_range.item() 29 | ) 30 | 31 | self.rel_embedding_range = nn.Parameter( 32 | 
torch.Tensor([(self.margin + self.epsilon) / self.dim_r]), 33 | requires_grad=False 34 | ) 35 | 36 | nn.init.uniform_( 37 | tensor = self.rel_embeddings.weight.data, 38 | a=-self.rel_embedding_range.item(), 39 | b=self.rel_embedding_range.item() 40 | ) 41 | 42 | self.margin = nn.Parameter(torch.Tensor([margin])) 43 | self.margin.requires_grad = False 44 | 45 | def _calc(self, h, t, r, mode): 46 | pi = self.pi_const 47 | 48 | re_head, im_head = torch.chunk(h, 2, dim=-1) 49 | re_tail, im_tail = torch.chunk(t, 2, dim=-1) 50 | 51 | phase_relation = r / (self.rel_embedding_range.item() / pi) 52 | 53 | re_relation = torch.cos(phase_relation) 54 | im_relation = torch.sin(phase_relation) 55 | 56 | re_head = re_head.view(-1, re_relation.shape[0], re_head.shape[-1]).permute(1, 0, 2) 57 | re_tail = re_tail.view(-1, re_relation.shape[0], re_tail.shape[-1]).permute(1, 0, 2) 58 | im_head = im_head.view(-1, re_relation.shape[0], im_head.shape[-1]).permute(1, 0, 2) 59 | im_tail = im_tail.view(-1, re_relation.shape[0], im_tail.shape[-1]).permute(1, 0, 2) 60 | im_relation = im_relation.view(-1, re_relation.shape[0], im_relation.shape[-1]).permute(1, 0, 2) 61 | re_relation = re_relation.view(-1, re_relation.shape[0], re_relation.shape[-1]).permute(1, 0, 2) 62 | 63 | if mode == "head_batch": 64 | re_score = re_relation * re_tail + im_relation * im_tail 65 | im_score = re_relation * im_tail - im_relation * re_tail 66 | re_score = re_score - re_head 67 | im_score = im_score - im_head 68 | else: 69 | re_score = re_head * re_relation - im_head * im_relation 70 | im_score = re_head * im_relation + im_head * re_relation 71 | re_score = re_score - re_tail 72 | im_score = im_score - im_tail 73 | 74 | score = torch.stack([re_score, im_score], dim = 0) 75 | score = score.norm(dim = 0).sum(dim = -1) 76 | return score.permute(1, 0).flatten() 77 | 78 | def forward(self, data): 79 | batch_h = data['batch_h'] 80 | batch_t = data['batch_t'] 81 | batch_r = data['batch_r'] 82 | mode = data['mode'] 83 | 
h = self.ent_embeddings(batch_h) 84 | t = self.ent_embeddings(batch_t) 85 | r = self.rel_embeddings(batch_r) 86 | score = self.margin - self._calc(h ,t, r, mode) 87 | return score 88 | 89 | def predict(self, data): 90 | score = -self.forward(data) 91 | return score.cpu().data.numpy() 92 | 93 | def regularization(self, data): 94 | batch_h = data['batch_h'] 95 | batch_t = data['batch_t'] 96 | batch_r = data['batch_r'] 97 | h = self.ent_embeddings(batch_h) 98 | t = self.ent_embeddings(batch_t) 99 | r = self.rel_embeddings(batch_r) 100 | regul = (torch.mean(h ** 2) + 101 | torch.mean(t ** 2) + 102 | torch.mean(r ** 2)) / 3 103 | return regul -------------------------------------------------------------------------------- /prepro/OpenKE/openke/module/model/TransH.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .Model import Model 5 | 6 | class TransH(Model): 7 | 8 | def __init__(self, ent_tot, rel_tot, dim = 100, p_norm = 1, norm_flag = True, margin = None, epsilon = None): 9 | super(TransH, self).__init__(ent_tot, rel_tot) 10 | 11 | self.dim = dim 12 | self.margin = margin 13 | self.epsilon = epsilon 14 | self.norm_flag = norm_flag 15 | self.p_norm = p_norm 16 | 17 | self.ent_embeddings = nn.Embedding(self.ent_tot, self.dim) 18 | self.rel_embeddings = nn.Embedding(self.rel_tot, self.dim) 19 | self.norm_vector = nn.Embedding(self.rel_tot, self.dim) 20 | 21 | if margin == None or epsilon == None: 22 | nn.init.xavier_uniform_(self.ent_embeddings.weight.data) 23 | nn.init.xavier_uniform_(self.rel_embeddings.weight.data) 24 | nn.init.xavier_uniform_(self.norm_vector.weight.data) 25 | else: 26 | self.embedding_range = nn.Parameter( 27 | torch.Tensor([(self.margin + self.epsilon) / self.dim]), requires_grad=False 28 | ) 29 | nn.init.uniform_( 30 | tensor = self.ent_embeddings.weight.data, 31 | a = -self.embedding_range.item(), 32 | b = 
self.embedding_range.item() 33 | ) 34 | nn.init.uniform_( 35 | tensor = self.rel_embeddings.weight.data, 36 | a= -self.embedding_range.item(), 37 | b= self.embedding_range.item() 38 | ) 39 | nn.init.uniform_( 40 | tensor = self.norm_vector.weight.data, 41 | a= -self.embedding_range.item(), 42 | b= self.embedding_range.item() 43 | ) 44 | 45 | if margin != None: 46 | self.margin = nn.Parameter(torch.Tensor([margin])) 47 | self.margin.requires_grad = False 48 | self.margin_flag = True 49 | else: 50 | self.margin_flag = False 51 | 52 | def _calc(self, h, t, r, mode): 53 | if self.norm_flag: 54 | h = F.normalize(h, 2, -1) 55 | r = F.normalize(r, 2, -1) 56 | t = F.normalize(t, 2, -1) 57 | if mode != 'normal': 58 | h = h.view(-1, r.shape[0], h.shape[-1]) 59 | t = t.view(-1, r.shape[0], t.shape[-1]) 60 | r = r.view(-1, r.shape[0], r.shape[-1]) 61 | if mode == 'head_batch': 62 | score = h + (r - t) 63 | else: 64 | score = (h + r) - t 65 | score = torch.norm(score, self.p_norm, -1).flatten() 66 | return score 67 | 68 | def _transfer(self, e, norm): 69 | norm = F.normalize(norm, p = 2, dim = -1) 70 | if e.shape[0] != norm.shape[0]: 71 | e = e.view(-1, norm.shape[0], e.shape[-1]) 72 | norm = norm.view(-1, norm.shape[0], norm.shape[-1]) 73 | e = e - torch.sum(e * norm, -1, True) * norm 74 | return e.view(-1, e.shape[-1]) 75 | else: 76 | return e - torch.sum(e * norm, -1, True) * norm 77 | 78 | def forward(self, data): 79 | batch_h = data['batch_h'] 80 | batch_t = data['batch_t'] 81 | batch_r = data['batch_r'] 82 | mode = data['mode'] 83 | h = self.ent_embeddings(batch_h) 84 | t = self.ent_embeddings(batch_t) 85 | r = self.rel_embeddings(batch_r) 86 | r_norm = self.norm_vector(batch_r) 87 | h = self._transfer(h, r_norm) 88 | t = self._transfer(t, r_norm) 89 | score = self._calc(h ,t, r, mode) 90 | if self.margin_flag: 91 | return self.margin - score 92 | else: 93 | return score 94 | 95 | def regularization(self, data): 96 | batch_h = data['batch_h'] 97 | batch_t = 
data['batch_t'] 98 | batch_r = data['batch_r'] 99 | h = self.ent_embeddings(batch_h) 100 | t = self.ent_embeddings(batch_t) 101 | r = self.rel_embeddings(batch_r) 102 | r_norm = self.norm_vector(batch_r) 103 | regul = (torch.mean(h ** 2) + 104 | torch.mean(t ** 2) + 105 | torch.mean(r ** 2) + 106 | torch.mean(r_norm ** 2)) / 4 107 | return regul 108 | 109 | def predict(self, data): 110 | score = self.forward(data) 111 | if self.margin_flag: 112 | score = self.margin - score 113 | return score.cpu().data.numpy() 114 | else: 115 | return score.cpu().data.numpy() -------------------------------------------------------------------------------- /prepro/OpenKE/openke/config/Trainer.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import torch.optim as optim 6 | import os 7 | import time 8 | import sys 9 | import datetime 10 | import ctypes 11 | import json 12 | import numpy as np 13 | import copy 14 | from tqdm import tqdm 15 | 16 | class Trainer(object): 17 | 18 | def __init__(self, 19 | model = None, 20 | data_loader = None, 21 | train_times = 1000, 22 | alpha = 0.5, 23 | use_gpu = True, 24 | opt_method = "sgd", 25 | save_steps = None, 26 | checkpoint_dir = None): 27 | 28 | self.work_threads = 8 29 | self.train_times = train_times 30 | 31 | self.opt_method = opt_method 32 | self.optimizer = None 33 | self.lr_decay = 0 34 | self.weight_decay = 0 35 | self.alpha = alpha 36 | 37 | self.model = model 38 | self.data_loader = data_loader 39 | self.use_gpu = use_gpu 40 | self.save_steps = save_steps 41 | self.checkpoint_dir = checkpoint_dir 42 | 43 | def train_one_step(self, data): 44 | self.optimizer.zero_grad() 45 | loss = self.model({ 46 | 'batch_h': self.to_var(data['batch_h'], self.use_gpu), 47 | 'batch_t': self.to_var(data['batch_t'], self.use_gpu), 48 | 'batch_r': self.to_var(data['batch_r'], self.use_gpu), 49 | 'batch_y': 
self.to_var(data['batch_y'], self.use_gpu), 50 | 'mode': data['mode'] 51 | }) 52 | loss.backward() 53 | self.optimizer.step() 54 | return loss.item() 55 | 56 | def run(self): 57 | if self.use_gpu: 58 | self.model.cuda() 59 | 60 | if self.optimizer != None: 61 | pass 62 | elif self.opt_method == "Adagrad" or self.opt_method == "adagrad": 63 | self.optimizer = optim.Adagrad( 64 | self.model.parameters(), 65 | lr=self.alpha, 66 | lr_decay=self.lr_decay, 67 | weight_decay=self.weight_decay, 68 | ) 69 | elif self.opt_method == "Adadelta" or self.opt_method == "adadelta": 70 | self.optimizer = optim.Adadelta( 71 | self.model.parameters(), 72 | lr=self.alpha, 73 | weight_decay=self.weight_decay, 74 | ) 75 | elif self.opt_method == "Adam" or self.opt_method == "adam": 76 | self.optimizer = optim.Adam( 77 | self.model.parameters(), 78 | lr=self.alpha, 79 | weight_decay=self.weight_decay, 80 | ) 81 | else: 82 | self.optimizer = optim.SGD( 83 | self.model.parameters(), 84 | lr = self.alpha, 85 | weight_decay=self.weight_decay, 86 | ) 87 | print("Finish initializing...") 88 | 89 | training_range = tqdm(range(self.train_times)) 90 | for epoch in training_range: 91 | res = 0.0 92 | for data in self.data_loader: 93 | loss = self.train_one_step(data) 94 | res += loss 95 | training_range.set_description("Epoch %d | loss: %f" % (epoch, res)) 96 | 97 | if self.save_steps and self.checkpoint_dir and (epoch + 1) % self.save_steps == 0: 98 | print("Epoch %d has finished, saving..." 
% (epoch)) 99 | self.model.save_checkpoint(os.path.join(self.checkpoint_dir + "-" + str(epoch) + ".ckpt")) 100 | 101 | def set_model(self, model): 102 | self.model = model 103 | 104 | def to_var(self, x, use_gpu): 105 | if use_gpu: 106 | return Variable(torch.from_numpy(x).cuda()) 107 | else: 108 | return Variable(torch.from_numpy(x)) 109 | 110 | def set_use_gpu(self, use_gpu): 111 | self.use_gpu = use_gpu 112 | 113 | def set_alpha(self, alpha): 114 | self.alpha = alpha 115 | 116 | def set_lr_decay(self, lr_decay): 117 | self.lr_decay = lr_decay 118 | 119 | def set_weight_decay(self, weight_decay): 120 | self.weight_decay = weight_decay 121 | 122 | def set_opt_method(self, opt_method): 123 | self.opt_method = opt_method 124 | 125 | def set_train_times(self, train_times): 126 | self.train_times = train_times 127 | 128 | def set_save_steps(self, save_steps, checkpoint_dir = None): 129 | self.save_steps = save_steps 130 | if not self.checkpoint_dir: 131 | self.set_checkpoint_dir(checkpoint_dir) 132 | 133 | def set_checkpoint_dir(self, checkpoint_dir): 134 | self.checkpoint_dir = checkpoint_dir -------------------------------------------------------------------------------- /arl/modules/position_embeddings.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
# --------------------------------------------------------
# Position embedding utils
# --------------------------------------------------------

import numpy as np

import torch


# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        # Prepend an all-zero embedding for the [CLS] token.
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    # `np.float` was removed in NumPy 1.24; the builtin float is the same
    # float64 dtype and works on all NumPy versions.
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.
    omega = 1. / 10000 ** omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
    """Resize a checkpoint's 2D pos_embed grid to match `model`'s patch grid."""
    if 'pos_embed' in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model['pos_embed'] = new_pos_embed
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ARL 2 | 3 | This is the implementation 4 | of [Align, Reason and Learn: Enhancing Medical Vision-and-Language Pre-training with Knowledge](https://arxiv.org/abs/2209.07118) 5 | at ACMMM-2022. 6 | 7 | ## Table of Contents 8 | 9 | - [Requirements](#requirements) 10 | - [Pre-training](#pre-training) 11 | - [Downstream Evaluation](#downstream-evaluation) 12 | 13 | ## Requirements 14 | 15 | Run the following command to install the required packages: 16 | 17 | ```bash 18 | pip install -r requirements.txt 19 | ``` 20 | 21 | Note: ARL involves the knowledge extraction and knowledge integration, which need more packages. Therefore, please be 22 | patient to install the environment. :-) 23 | 24 | ## Pre-training 25 | 26 | ### 1. Dataset Preparation 27 | 28 | Please organize the pre-training datasets as the following structure: 29 | 30 | ```angular2 31 | root:[data] 32 | +--pretrain_data 33 | | +--roco 34 | | | +--train 35 | | | +--val 36 | | | +--test 37 | | +--medicat 38 | | | +--net 39 | | | +--release 40 | | +--mimic_cxr 41 | | | +--files 42 | | | +--mimic_cxr_sectioned.csv 43 | | | +--mimic-cxr-2.0.0-split.csv 44 | | | +--mimic-cxr-2.0.0-metadata.csv 45 | | | +--mimic-cxr-2.0.0-chexpert.csv 46 | ``` 47 | 48 | ### 2. Pre-processing 49 | 50 | Run the following command to pre-process the data: 51 | 52 | ```angular2 53 | python prepro/prepro_pretraining_data.py 54 | ``` 55 | 56 | to get the following arrow files: 57 | 58 | ```angular2 59 | root:[data] 60 | +--pretrain_arrows 61 | | +--medicat_train.arrow 62 | | +--medicat_val.arrow 63 | | +--medicat_test.arrow 64 | | +--roco_train.arrow 65 | | +--roco_val.arrow 66 | | +--roco_test.arrow 67 | | +--mimic_cxr_train.arrow 68 | | +--mimic_cxr_val.arrow 69 | | +--mimic_cxr_test.arrow 70 | ``` 71 | 72 | ### 3. 
Download the initialized weights for pre-training 73 | 74 | Download the initialized meter 75 | weights [here](https://drive.google.com/drive/folders/1PiXnT65WR8qb6VAqE1lwidLrOLThuqY7?usp=share_link). 76 | 77 | ### 4. Pre-training 78 | 79 | Now we can start to pre-train the arl model: 80 | 81 | ```angular2 82 | bash run_scripts/pretrain_arl.sh 83 | ``` 84 | 85 | ## Downstream Evaluation 86 | 87 | ### 1. Dataset Preparation 88 | 89 | Please organize the fine-tuning datasets as the following structure: 90 | 91 | ```angular2 92 | root:[data] 93 | +--finetune_data 94 | | +--melinda 95 | | | +--train.csv 96 | | | +--dev.csv 97 | | | +--test.csv 98 | | | +--melinda_images 99 | | +--slack 100 | | | +--train.json 101 | | | +--validate.json 102 | | | +--test.json 103 | | | +--imgs 104 | | +--vqa_rad 105 | | | +--trainset.json 106 | | | +--valset.json 107 | | | +--testset.json 108 | | | +--images 109 | | +--medvqa_2019 110 | | | +--val 111 | | | +--test 112 | | | +--train 113 | ``` 114 | 115 | ### 2. Pre-processing 116 | 117 | Run the following command to pre-process the data: 118 | 119 | ```angular2 120 | python prepro/prepro_finetuning_data.py 121 | ``` 122 | 123 | to get the following arrow files: 124 | 125 | ```angular2 126 | root:[data] 127 | +--finetune_arrows 128 | | +--vqa_vqa_rad_train.arrow 129 | | +--vqa_vqa_rad_val.arrow 130 | | +--vqa_vqa_rad_test.arrow 131 | | +--vqa_slack_train.arrow 132 | | +--vqa_slack_test.arrow 133 | | +--vqa_slack_val.arrow 134 | | +--vqa_medvqa_2019_train.arrow 135 | | +--vqa_medvqa_2019_val.arrow 136 | | +--vqa_medvqa_2019_test.arrow 137 | | +--cls_melinda_train.arrow 138 | | +--cls_melinda_val.arrow 139 | | +--cls_melinda_test.arrow 140 | | +--irtr_roco_train.arrow 141 | | +--irtr_roco_val.arrow 142 | | +--irtr_roco_test.arrow 143 | ``` 144 | 145 | ### 3. 
import re

# Canonicalisation tables for VQA-style answer normalisation (derived from the
# official VQA evaluation script).  Keys are apostrophe-less / mis-apostrophised
# spellings; values are the canonical contractions substituted by
# normalize_word() below.
contractions = {
    "aint": "ain't",
    "arent": "aren't",
    "cant": "can't",
    "couldve": "could've",
    "couldnt": "couldn't",
    "couldn'tve": "couldn't've",
    "couldnt've": "couldn't've",
    "didnt": "didn't",
    "doesnt": "doesn't",
    "dont": "don't",
    "hadnt": "hadn't",
    "hadnt've": "hadn't've",
    "hadn'tve": "hadn't've",
    "hasnt": "hasn't",
    "havent": "haven't",
    "hed": "he'd",
    "hed've": "he'd've",
    "he'dve": "he'd've",
    "hes": "he's",
    "howd": "how'd",
    "howll": "how'll",
    "hows": "how's",
    "Id've": "I'd've",
    "I'dve": "I'd've",
    "Im": "I'm",
    "Ive": "I've",
    "isnt": "isn't",
    "itd": "it'd",
    "itd've": "it'd've",
    "it'dve": "it'd've",
    "itll": "it'll",
    "let's": "let's",
    "maam": "ma'am",
    "mightnt": "mightn't",
    "mightnt've": "mightn't've",
    "mightn'tve": "mightn't've",
    "mightve": "might've",
    "mustnt": "mustn't",
    "mustve": "must've",
    "neednt": "needn't",
    "notve": "not've",
    "oclock": "o'clock",
    "oughtnt": "oughtn't",
    "ow's'at": "'ow's'at",
    "'ows'at": "'ow's'at",
    "'ow'sat": "'ow's'at",
    "shant": "shan't",
    "shed've": "she'd've",
    "she'dve": "she'd've",
    "she's": "she's",
    "shouldve": "should've",
    "shouldnt": "shouldn't",
    "shouldnt've": "shouldn't've",
    "shouldn'tve": "shouldn't've",
    # BUG FIX: the entry was inverted ("somebody'd": "somebodyd"), mapping the
    # already-correct spelling to the broken one.  Restored the canonical
    # direction used by every sibling entry.
    "somebodyd": "somebody'd",
    "somebodyd've": "somebody'd've",
    "somebody'dve": "somebody'd've",
    "somebodyll": "somebody'll",
    "somebodys": "somebody's",
    "someoned": "someone'd",
    "someoned've": "someone'd've",
    "someone'dve": "someone'd've",
    "someonell": "someone'll",
    "someones": "someone's",
    "somethingd": "something'd",
    "somethingd've": "something'd've",
    "something'dve": "something'd've",
    "somethingll": "something'll",
    "thats": "that's",
    "thered": "there'd",
    "thered've": "there'd've",
    "there'dve": "there'd've",
    "therere": "there're",
    "theres": "there's",
    "theyd": "they'd",
    "theyd've": "they'd've",
    "they'dve": "they'd've",
    "theyll": "they'll",
    "theyre": "they're",
    "theyve": "they've",
    "twas": "'twas",
    "wasnt": "wasn't",
    "wed've": "we'd've",
    "we'dve": "we'd've",
    "weve": "we've",
    "werent": "weren't",
    "whatll": "what'll",
    "whatre": "what're",
    "whats": "what's",
    "whatve": "what've",
    "whens": "when's",
    "whered": "where'd",
    "wheres": "where's",
    "whereve": "where've",
    "whod": "who'd",
    "whod've": "who'd've",
    "who'dve": "who'd've",
    "wholl": "who'll",
    "whos": "who's",
    "whove": "who've",
    "whyll": "why'll",
    "whyre": "why're",
    "whys": "why's",
    "wont": "won't",
    "wouldve": "would've",
    "wouldnt": "wouldn't",
    "wouldnt've": "wouldn't've",
    "wouldn'tve": "wouldn't've",
    "yall": "y'all",
    "yall'll": "y'all'll",
    "y'allll": "y'all'll",
    "yall'd've": "y'all'd've",
    "y'alld've": "y'all'd've",
    "y'all'dve": "y'all'd've",
    "youd": "you'd",
    "youd've": "you'd've",
    "you'dve": "you'd've",
    "youll": "you'll",
    "youre": "you're",
    "youve": "you've",
}

# Spelled-out numbers are normalised to digits before comparison.
manual_map = {
    "none": "0",
    "zero": "0",
    "one": "1",
    "two": "2",
    "three": "3",
    "four": "4",
    "five": "5",
    "six": "6",
    "seven": "7",
    "eight": "8",
    "nine": "9",
    "ten": "10",
}
articles = ["a", "an", "the"]
# Strip a period unless it sits between two digits (keep "1.5" intact).
# BUG FIX: the lookbehind was mistyped as "(?!<=\d)" (a lookahead for the
# literal characters "<=" plus a digit, which matches almost anywhere), so a
# period *after* a digit was also stripped.  NOTE(review): the same typo exists
# in the upstream VQA evaluation script; keep this in mind when comparing
# accuracy numbers against results produced by the unpatched script.
period_strip = re.compile(r"(?<!\d)(\.)(?!\d)")
comma_strip = re.compile(r"(\d)(\,)(\d)")
punct = [
    ";",
    r"/",
    "[",
    "]",
    '"',
    "{",
    "}",
    "(",
    ")",
    "=",
    "+",
    "\\",
    "_",
    "-",
    ">",
    "<",
    "@",
    "`",
    ",",
    "?",
    "!",
]


def normalize_word(token):
    """Normalise one answer string for VQA-style comparison.

    Removes/spaces out punctuation, strips non-decimal periods, lower-cases,
    maps spelled-out numbers to digits, drops articles, and restores
    canonical contractions.
    """
    _token = token
    for p in punct:
        # A punctuation mark adjacent to a space (or any digit,comma,digit
        # pattern) is deleted outright; otherwise it is replaced by a space so
        # the surrounding words stay separated.
        if (p + " " in token or " " + p in token) or (
            re.search(comma_strip, token) is not None
        ):
            _token = _token.replace(p, "")
        else:
            _token = _token.replace(p, " ")
    # BUG FIX: the old call was period_strip.sub("", _token, re.UNICODE) —
    # compiled-pattern .sub() has no flags parameter, so re.UNICODE (32) was
    # silently interpreted as count=32, capping the number of substitutions.
    token = period_strip.sub("", _token)

    _token = []
    temp = token.lower().split()
    for word in temp:
        # BUG FIX: was manual_map.setdefault(word, word), which *inserted*
        # every unseen word into the shared module-level dict, growing it
        # without bound across calls.  .get() returns the same value without
        # mutating the table.
        word = manual_map.get(word, word)
        if word not in articles:
            _token.append(word)
    for i, word in enumerate(_token):
        if word in contractions:
            _token[i] = contractions[word]
    token = " ".join(_token)
    token = token.replace(",", "")
    return token
#ifndef CORRUPT_H
#define CORRUPT_H
#include "Random.h"
#include "Triple.h"
#include "Reader.h"

// Negative-sampling helpers.  All three corrupt_* functions draw a uniformly
// random replacement id that (when filter_flag is true) is guaranteed NOT to
// form an existing training triple.  They rely on the globals built in
// Reader.h: trainHead grouped by head and sorted by (r, t), trainTail grouped
// by tail and sorted by (r, h), trainRel grouped by head and sorted by (t, r),
// with lefHead/rigHead etc. delimiting each group — TODO confirm against
// Reader.h, which is not visible here.
//
// NOTE(review): despite its name, corrupt_head(id, h, r) returns a corrupted
// *tail* entity t' such that (h, r, t') is unseen (the head is kept); see the
// call sites in Base.cpp, where its result is stored into batch_t.  Likewise
// corrupt_tail returns a corrupted head.
INT corrupt_head(INT id, INT h, INT r, bool filter_flag = true) {
	INT lef, rig, mid, ll, rr;
	if (not filter_flag) {
		// Unfiltered: any entity except h itself (shift to skip h).
		INT tmp = rand_max(id, entityTotal - 1);
		if (tmp < h)
			return tmp;
		else
			return tmp + 1;
	}
	// Binary search for ll = first index in trainHead with head h, relation r.
	lef = lefHead[h] - 1;
	rig = rigHead[h];
	while (lef + 1 < rig) {
		mid = (lef + rig) >> 1;
		if (trainHead[mid].r >= r) rig = mid; else
		lef = mid;
	}
	ll = rig;
	// Binary search for rr = last index with head h, relation r.
	lef = lefHead[h];
	rig = rigHead[h] + 1;
	while (lef + 1 < rig) {
		mid = (lef + rig) >> 1;
		if (trainHead[mid].r <= r) lef = mid; else
		rig = mid;
	}
	rr = lef;
	// Sample uniformly from the entities NOT appearing as tails in
	// trainHead[ll..rr] (there are rr-ll+1 excluded tails), then map the
	// sample back into entity-id space, skipping the excluded tails via an
	// order-statistics binary search over the sorted tail ids.
	INT tmp = rand_max(id, entityTotal - (rr - ll + 1));
	if (tmp < trainHead[ll].t) return tmp;
	if (tmp > trainHead[rr].t - rr + ll - 1) return tmp + rr - ll + 1;
	lef = ll, rig = rr + 1;
	while (lef + 1 < rig) {
		mid = (lef + rig) >> 1;
		if (trainHead[mid].t - mid + ll - 1 < tmp)
			lef = mid;
		else
			rig = mid;
	}
	return tmp + lef - ll + 1;
}

// Mirror of corrupt_head: returns a corrupted *head* h' such that (h', r, t)
// is not a training triple (tail t is kept).
INT corrupt_tail(INT id, INT t, INT r, bool filter_flag = true) {
	INT lef, rig, mid, ll, rr;
	if (not filter_flag) {
		INT tmp = rand_max(id, entityTotal - 1);
		if (tmp < t)
			return tmp;
		else
			return tmp + 1;
	}
	// [ll, rr] = range of trainTail entries with tail t and relation r.
	lef = lefTail[t] - 1;
	rig = rigTail[t];
	while (lef + 1 < rig) {
		mid = (lef + rig) >> 1;
		if (trainTail[mid].r >= r) rig = mid; else
		lef = mid;
	}
	ll = rig;
	lef = lefTail[t];
	rig = rigTail[t] + 1;
	while (lef + 1 < rig) {
		mid = (lef + rig) >> 1;
		if (trainTail[mid].r <= r) lef = mid; else
		rig = mid;
	}
	rr = lef;
	// Uniform sample over heads not already paired with (r, t), remapped past
	// the excluded ids exactly as in corrupt_head.
	INT tmp = rand_max(id, entityTotal - (rr - ll + 1));
	if (tmp < trainTail[ll].h) return tmp;
	if (tmp > trainTail[rr].h - rr + ll - 1) return tmp + rr - ll + 1;
	lef = ll, rig = rr + 1;
	while (lef + 1 < rig) {
		mid = (lef + rig) >> 1;
		if (trainTail[mid].h - mid + ll - 1 < tmp)
			lef = mid;
		else
			rig = mid;
	}
	return tmp + lef - ll + 1;
}


// Returns a corrupted relation r' such that (h, r', t) is not a training
// triple.  With p == true the replacement is drawn from the precomputed
// distribution `prob` (row r, one weight per other relation) instead of
// uniformly — presumably a proposal distribution set up elsewhere; TODO
// confirm where `prob` is initialised.
INT corrupt_rel(INT id, INT h, INT t, INT r, bool p = false, bool filter_flag = true) {
	INT lef, rig, mid, ll, rr;
	if (not filter_flag) {
		INT tmp = rand_max(id, relationTotal - 1);
		if (tmp < r)
			return tmp;
		else
			return tmp + 1;
	}
	// [ll, rr] = range of trainRel entries with head h and tail t.
	lef = lefRel[h] - 1;
	rig = rigRel[h];
	while (lef + 1 < rig) {
		mid = (lef + rig) >> 1;
		if (trainRel[mid].t >= t) rig = mid; else
		lef = mid;
	}
	ll = rig;
	lef = lefRel[h];
	rig = rigRel[h] + 1;
	while (lef + 1 < rig) {
		mid = (lef + rig) >> 1;
		if (trainRel[mid].t <= t) lef = mid; else
		rig = mid;
	}
	rr = lef;
	INT tmp;
	if(p == false) {
		// Uniform draw over the relations not already linking (h, t).
		tmp = rand_max(id, relationTotal - (rr - ll + 1));
	}
	else {
		// Weighted draw: renormalise row r of `prob` after zeroing out the
		// relations already observed for (h, t), then invert the CDF with a
		// binary search.
		INT start = r * (relationTotal - 1);
		REAL sum = 1;
		bool *record = (bool *)calloc(relationTotal - 1, sizeof(bool));
		for (INT i = ll; i <= rr; ++i){
			if (trainRel[i].r > r){
				sum -= prob[start + trainRel[i].r-1];
				record[trainRel[i].r-1] = true;
			}
			else if (trainRel[i].r < r){
				sum -= prob[start + trainRel[i].r];
				record[trainRel[i].r] = true;
			}
		}
		REAL *prob_tmp = (REAL *)calloc(relationTotal-(rr-ll+1), sizeof(REAL));
		INT cnt = 0;
		REAL rec = 0;
		for (INT i = start; i < start + relationTotal - 1; ++i) {
			if (record[i-start])
				continue;
			rec += prob[i] / sum;
			prob_tmp[cnt++] = rec;
		}
		REAL m = rand_max(id, 10000) / 10000.0;
		lef = 0;
		rig = cnt - 1;
		while (lef < rig) {
			mid = (lef + rig) >> 1;
			if (prob_tmp[mid] < m)
				lef = mid + 1;
			else
				rig = mid;
		}
		tmp = rig;
		free(prob_tmp);
		free(record);
	}
	// Remap tmp into relation-id space, skipping the excluded relations.
	if (tmp < trainRel[ll].r) return tmp;
	if (tmp > trainRel[rr].r - rr + ll - 1) return tmp + rr - ll + 1;
	lef = ll, rig = rr + 1;
	while (lef + 1 < rig) {
		mid = (lef + rig) >> 1;
		if (trainRel[mid].r - mid + ll - 1 < tmp)
			lef = mid;
		else
			rig = mid;
	}
	return tmp + lef - ll + 1;
}


// Binary search for (h, r, t) in tripleList, which is sorted by (h, r, t);
// returns true iff the triple exists.
bool _find(INT h, INT t, INT r) {
	INT lef = 0;
	INT rig = tripleTotal - 1;
	INT mid;
	while (lef + 1 < rig) {
		INT mid = (lef + rig) >> 1;  // NOTE(review): shadows the outer `mid` (harmless).
		if ((tripleList[mid]. h < h) || (tripleList[mid]. h == h && tripleList[mid]. r < r) || (tripleList[mid]. h == h && tripleList[mid]. r == r && tripleList[mid]. t < t)) lef = mid; else rig = mid;
	}
	if (tripleList[lef].h == h && tripleList[lef].r == r && tripleList[lef].t == t) return true;
	if (tripleList[rig].h == h && tripleList[rig].r == r && tripleList[rig].t == t) return true;
	return false;
}

// Type-constrained tail corruption: rejection-sample a tail from the type set
// allowed for relation r (tail_type[tail_lef[r]..tail_rig[r]]); after 1000
// failed attempts fall back to unconstrained corrupt_head (thread id 0).
INT corrupt(INT h, INT r){
	INT ll = tail_lef[r];
	INT rr = tail_rig[r];
	INT loop = 0;
	INT t;
	while(true) {
		t = tail_type[rand(ll, rr)];
		if (not _find(h, t, r)) {
			return t;
		} else {
			loop ++;
			if (loop >= 1000) {
				return corrupt_head(0, h, r);
			}
		}
	}
}
#endif
| self.margin = margin 14 | self.epsilon = epsilon 15 | self.norm_flag = norm_flag 16 | self.p_norm = p_norm 17 | 18 | self.ent_embeddings = nn.Embedding(self.ent_tot, self.dim_e) 19 | self.rel_embeddings = nn.Embedding(self.rel_tot, self.dim_r) 20 | self.ent_transfer = nn.Embedding(self.ent_tot, self.dim_e) 21 | self.rel_transfer = nn.Embedding(self.rel_tot, self.dim_r) 22 | 23 | if margin == None or epsilon == None: 24 | nn.init.xavier_uniform_(self.ent_embeddings.weight.data) 25 | nn.init.xavier_uniform_(self.rel_embeddings.weight.data) 26 | nn.init.xavier_uniform_(self.ent_transfer.weight.data) 27 | nn.init.xavier_uniform_(self.rel_transfer.weight.data) 28 | else: 29 | self.ent_embedding_range = nn.Parameter( 30 | torch.Tensor([(self.margin + self.epsilon) / self.dim_e]), requires_grad=False 31 | ) 32 | self.rel_embedding_range = nn.Parameter( 33 | torch.Tensor([(self.margin + self.epsilon) / self.dim_r]), requires_grad=False 34 | ) 35 | nn.init.uniform_( 36 | tensor = self.ent_embeddings.weight.data, 37 | a = -self.ent_embedding_range.item(), 38 | b = self.ent_embedding_range.item() 39 | ) 40 | nn.init.uniform_( 41 | tensor = self.rel_embeddings.weight.data, 42 | a= -self.rel_embedding_range.item(), 43 | b= self.rel_embedding_range.item() 44 | ) 45 | nn.init.uniform_( 46 | tensor = self.ent_transfer.weight.data, 47 | a= -self.ent_embedding_range.item(), 48 | b= self.ent_embedding_range.item() 49 | ) 50 | nn.init.uniform_( 51 | tensor = self.rel_transfer.weight.data, 52 | a= -self.rel_embedding_range.item(), 53 | b= self.rel_embedding_range.item() 54 | ) 55 | if margin != None: 56 | self.margin = nn.Parameter(torch.Tensor([margin])) 57 | self.margin.requires_grad = False 58 | self.margin_flag = True 59 | else: 60 | self.margin_flag = False 61 | 62 | def _resize(self, tensor, axis, size): 63 | shape = tensor.size() 64 | osize = shape[axis] 65 | if osize == size: 66 | return tensor 67 | if (osize > size): 68 | return torch.narrow(tensor, axis, 0, size) 69 | 
paddings = [] 70 | for i in range(len(shape)): 71 | if i == axis: 72 | paddings = [0, size - osize] + paddings 73 | else: 74 | paddings = [0, 0] + paddings 75 | print (paddings) 76 | return F.pad(tensor, paddings = paddings, mode = "constant", value = 0) 77 | 78 | def _calc(self, h, t, r, mode): 79 | if self.norm_flag: 80 | h = F.normalize(h, 2, -1) 81 | r = F.normalize(r, 2, -1) 82 | t = F.normalize(t, 2, -1) 83 | if mode != 'normal': 84 | h = h.view(-1, r.shape[0], h.shape[-1]) 85 | t = t.view(-1, r.shape[0], t.shape[-1]) 86 | r = r.view(-1, r.shape[0], r.shape[-1]) 87 | if mode == 'head_batch': 88 | score = h + (r - t) 89 | else: 90 | score = (h + r) - t 91 | score = torch.norm(score, self.p_norm, -1).flatten() 92 | return score 93 | 94 | def _transfer(self, e, e_transfer, r_transfer): 95 | if e.shape[0] != r_transfer.shape[0]: 96 | e = e.view(-1, r_transfer.shape[0], e.shape[-1]) 97 | e_transfer = e_transfer.view(-1, r_transfer.shape[0], e_transfer.shape[-1]) 98 | r_transfer = r_transfer.view(-1, r_transfer.shape[0], r_transfer.shape[-1]) 99 | e = F.normalize( 100 | self._resize(e, -1, r_transfer.size()[-1]) + torch.sum(e * e_transfer, -1, True) * r_transfer, 101 | p = 2, 102 | dim = -1 103 | ) 104 | return e.view(-1, e.shape[-1]) 105 | else: 106 | return F.normalize( 107 | self._resize(e, -1, r_transfer.size()[-1]) + torch.sum(e * e_transfer, -1, True) * r_transfer, 108 | p = 2, 109 | dim = -1 110 | ) 111 | 112 | def forward(self, data): 113 | batch_h = data['batch_h'] 114 | batch_t = data['batch_t'] 115 | batch_r = data['batch_r'] 116 | mode = data['mode'] 117 | h = self.ent_embeddings(batch_h) 118 | t = self.ent_embeddings(batch_t) 119 | r = self.rel_embeddings(batch_r) 120 | h_transfer = self.ent_transfer(batch_h) 121 | t_transfer = self.ent_transfer(batch_t) 122 | r_transfer = self.rel_transfer(batch_r) 123 | h = self._transfer(h, h_transfer, r_transfer) 124 | t = self._transfer(t, t_transfer, r_transfer) 125 | score = self._calc(h ,t, r, mode) 126 | if 
self.margin_flag: 127 | return self.margin - score 128 | else: 129 | return score 130 | 131 | def regularization(self, data): 132 | batch_h = data['batch_h'] 133 | batch_t = data['batch_t'] 134 | batch_r = data['batch_r'] 135 | h = self.ent_embeddings(batch_h) 136 | t = self.ent_embeddings(batch_t) 137 | r = self.rel_embeddings(batch_r) 138 | h_transfer = self.ent_transfer(batch_h) 139 | t_transfer = self.ent_transfer(batch_t) 140 | r_transfer = self.rel_transfer(batch_r) 141 | regul = (torch.mean(h ** 2) + 142 | torch.mean(t ** 2) + 143 | torch.mean(r ** 2) + 144 | torch.mean(h_transfer ** 2) + 145 | torch.mean(t_transfer ** 2) + 146 | torch.mean(r_transfer ** 2)) / 6 147 | return regul 148 | 149 | def predict(self, data): 150 | score = self.forward(data) 151 | if self.margin_flag: 152 | score = self.margin - score 153 | return score.cpu().data.numpy() 154 | else: 155 | return score.cpu().data.numpy() -------------------------------------------------------------------------------- /prepro/OpenKE/openke/base/Base.cpp: -------------------------------------------------------------------------------- 1 | #include "Setting.h" 2 | #include "Random.h" 3 | #include "Reader.h" 4 | #include "Corrupt.h" 5 | #include "Test.h" 6 | #include 7 | #include 8 | 9 | extern "C" 10 | void setInPath(char *path); 11 | 12 | extern "C" 13 | void setTrainPath(char *path); 14 | 15 | extern "C" 16 | void setValidPath(char *path); 17 | 18 | extern "C" 19 | void setTestPath(char *path); 20 | 21 | extern "C" 22 | void setEntPath(char *path); 23 | 24 | extern "C" 25 | void setRelPath(char *path); 26 | 27 | extern "C" 28 | void setOutPath(char *path); 29 | 30 | extern "C" 31 | void setWorkThreads(INT threads); 32 | 33 | extern "C" 34 | void setBern(INT con); 35 | 36 | extern "C" 37 | INT getWorkThreads(); 38 | 39 | extern "C" 40 | INT getEntityTotal(); 41 | 42 | extern "C" 43 | INT getRelationTotal(); 44 | 45 | extern "C" 46 | INT getTripleTotal(); 47 | 48 | extern "C" 49 | INT getTrainTotal(); 
50 | 51 | extern "C" 52 | INT getTestTotal(); 53 | 54 | extern "C" 55 | INT getValidTotal(); 56 | 57 | extern "C" 58 | void randReset(); 59 | 60 | extern "C" 61 | void importTrainFiles(); 62 | 63 | struct Parameter { 64 | INT id; 65 | INT *batch_h; 66 | INT *batch_t; 67 | INT *batch_r; 68 | REAL *batch_y; 69 | INT batchSize; 70 | INT negRate; 71 | INT negRelRate; 72 | bool p; 73 | bool val_loss; 74 | INT mode; 75 | bool filter_flag; 76 | }; 77 | 78 | void* getBatch(void* con) { 79 | Parameter *para = (Parameter *)(con); 80 | INT id = para -> id; 81 | INT *batch_h = para -> batch_h; 82 | INT *batch_t = para -> batch_t; 83 | INT *batch_r = para -> batch_r; 84 | REAL *batch_y = para -> batch_y; 85 | INT batchSize = para -> batchSize; 86 | INT negRate = para -> negRate; 87 | INT negRelRate = para -> negRelRate; 88 | bool p = para -> p; 89 | bool val_loss = para -> val_loss; 90 | INT mode = para -> mode; 91 | bool filter_flag = para -> filter_flag; 92 | INT lef, rig; 93 | if (batchSize % workThreads == 0) { 94 | lef = id * (batchSize / workThreads); 95 | rig = (id + 1) * (batchSize / workThreads); 96 | } else { 97 | lef = id * (batchSize / workThreads + 1); 98 | rig = (id + 1) * (batchSize / workThreads + 1); 99 | if (rig > batchSize) rig = batchSize; 100 | } 101 | REAL prob = 500; 102 | if (val_loss == false) { 103 | for (INT batch = lef; batch < rig; batch++) { 104 | INT i = rand_max(id, trainTotal); 105 | batch_h[batch] = trainList[i].h; 106 | batch_t[batch] = trainList[i].t; 107 | batch_r[batch] = trainList[i].r; 108 | batch_y[batch] = 1; 109 | INT last = batchSize; 110 | for (INT times = 0; times < negRate; times ++) { 111 | if (mode == 0){ 112 | if (bernFlag) 113 | prob = 1000 * right_mean[trainList[i].r] / (right_mean[trainList[i].r] + left_mean[trainList[i].r]); 114 | if (randd(id) % 1000 < prob) { 115 | batch_h[batch + last] = trainList[i].h; 116 | batch_t[batch + last] = corrupt_head(id, trainList[i].h, trainList[i].r); 117 | batch_r[batch + last] = 
trainList[i].r; 118 | } else { 119 | batch_h[batch + last] = corrupt_tail(id, trainList[i].t, trainList[i].r); 120 | batch_t[batch + last] = trainList[i].t; 121 | batch_r[batch + last] = trainList[i].r; 122 | } 123 | batch_y[batch + last] = -1; 124 | last += batchSize; 125 | } else { 126 | if(mode == -1){ 127 | batch_h[batch + last] = corrupt_tail(id, trainList[i].t, trainList[i].r); 128 | batch_t[batch + last] = trainList[i].t; 129 | batch_r[batch + last] = trainList[i].r; 130 | } else { 131 | batch_h[batch + last] = trainList[i].h; 132 | batch_t[batch + last] = corrupt_head(id, trainList[i].h, trainList[i].r); 133 | batch_r[batch + last] = trainList[i].r; 134 | } 135 | batch_y[batch + last] = -1; 136 | last += batchSize; 137 | } 138 | } 139 | for (INT times = 0; times < negRelRate; times++) { 140 | batch_h[batch + last] = trainList[i].h; 141 | batch_t[batch + last] = trainList[i].t; 142 | batch_r[batch + last] = corrupt_rel(id, trainList[i].h, trainList[i].t, trainList[i].r, p); 143 | batch_y[batch + last] = -1; 144 | last += batchSize; 145 | } 146 | } 147 | } 148 | else 149 | { 150 | for (INT batch = lef; batch < rig; batch++) 151 | { 152 | batch_h[batch] = validList[batch].h; 153 | batch_t[batch] = validList[batch].t; 154 | batch_r[batch] = validList[batch].r; 155 | batch_y[batch] = 1; 156 | } 157 | } 158 | pthread_exit(NULL); 159 | } 160 | 161 | extern "C" 162 | void sampling( 163 | INT *batch_h, 164 | INT *batch_t, 165 | INT *batch_r, 166 | REAL *batch_y, 167 | INT batchSize, 168 | INT negRate = 1, 169 | INT negRelRate = 0, 170 | INT mode = 0, 171 | bool filter_flag = true, 172 | bool p = false, 173 | bool val_loss = false 174 | ) { 175 | pthread_t *pt = (pthread_t *)malloc(workThreads * sizeof(pthread_t)); 176 | Parameter *para = (Parameter *)malloc(workThreads * sizeof(Parameter)); 177 | for (INT threads = 0; threads < workThreads; threads++) { 178 | para[threads].id = threads; 179 | para[threads].batch_h = batch_h; 180 | para[threads].batch_t = batch_t; 
181 | para[threads].batch_r = batch_r; 182 | para[threads].batch_y = batch_y; 183 | para[threads].batchSize = batchSize; 184 | para[threads].negRate = negRate; 185 | para[threads].negRelRate = negRelRate; 186 | para[threads].p = p; 187 | para[threads].val_loss = val_loss; 188 | para[threads].mode = mode; 189 | para[threads].filter_flag = filter_flag; 190 | pthread_create(&pt[threads], NULL, getBatch, (void*)(para+threads)); 191 | } 192 | for (INT threads = 0; threads < workThreads; threads++) 193 | pthread_join(pt[threads], NULL); 194 | 195 | free(pt); 196 | free(para); 197 | } 198 | 199 | int main() { 200 | importTrainFiles(); 201 | return 0; 202 | } -------------------------------------------------------------------------------- /prepro/OpenKE/openke/config/Tester.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import torch.optim as optim 6 | import os 7 | import time 8 | import sys 9 | import datetime 10 | import ctypes 11 | import json 12 | import numpy as np 13 | from sklearn.metrics import roc_auc_score 14 | import copy 15 | from tqdm import tqdm 16 | 17 | class Tester(object): 18 | 19 | def __init__(self, model = None, data_loader = None, use_gpu = True): 20 | base_file = os.path.abspath(os.path.join(os.path.dirname(__file__), "../release/Base.so")) 21 | self.lib = ctypes.cdll.LoadLibrary(base_file) 22 | self.lib.testHead.argtypes = [ctypes.c_void_p, ctypes.c_int64, ctypes.c_int64] 23 | self.lib.testTail.argtypes = [ctypes.c_void_p, ctypes.c_int64, ctypes.c_int64] 24 | self.lib.test_link_prediction.argtypes = [ctypes.c_int64] 25 | 26 | self.lib.getTestLinkMRR.argtypes = [ctypes.c_int64] 27 | self.lib.getTestLinkMR.argtypes = [ctypes.c_int64] 28 | self.lib.getTestLinkHit10.argtypes = [ctypes.c_int64] 29 | self.lib.getTestLinkHit3.argtypes = [ctypes.c_int64] 30 | self.lib.getTestLinkHit1.argtypes = [ctypes.c_int64] 31 | 
32 | self.lib.getTestLinkMRR.restype = ctypes.c_float 33 | self.lib.getTestLinkMR.restype = ctypes.c_float 34 | self.lib.getTestLinkHit10.restype = ctypes.c_float 35 | self.lib.getTestLinkHit3.restype = ctypes.c_float 36 | self.lib.getTestLinkHit1.restype = ctypes.c_float 37 | 38 | self.model = model 39 | self.data_loader = data_loader 40 | self.use_gpu = use_gpu 41 | 42 | if self.use_gpu: 43 | self.model.cuda() 44 | 45 | def set_model(self, model): 46 | self.model = model 47 | 48 | def set_data_loader(self, data_loader): 49 | self.data_loader = data_loader 50 | 51 | def set_use_gpu(self, use_gpu): 52 | self.use_gpu = use_gpu 53 | if self.use_gpu and self.model != None: 54 | self.model.cuda() 55 | 56 | def to_var(self, x, use_gpu): 57 | if use_gpu: 58 | return Variable(torch.from_numpy(x).cuda()) 59 | else: 60 | return Variable(torch.from_numpy(x)) 61 | 62 | def test_one_step(self, data): 63 | return self.model.predict({ 64 | 'batch_h': self.to_var(data['batch_h'], self.use_gpu), 65 | 'batch_t': self.to_var(data['batch_t'], self.use_gpu), 66 | 'batch_r': self.to_var(data['batch_r'], self.use_gpu), 67 | 'mode': data['mode'] 68 | }) 69 | 70 | def run_link_prediction(self, type_constrain = False): 71 | self.lib.initTest() 72 | self.data_loader.set_sampling_mode('link') 73 | if type_constrain: 74 | type_constrain = 1 75 | else: 76 | type_constrain = 0 77 | training_range = tqdm(self.data_loader) 78 | for index, [data_head, data_tail] in enumerate(training_range): 79 | score = self.test_one_step(data_head) 80 | self.lib.testHead(score.__array_interface__["data"][0], index, type_constrain) 81 | score = self.test_one_step(data_tail) 82 | self.lib.testTail(score.__array_interface__["data"][0], index, type_constrain) 83 | self.lib.test_link_prediction(type_constrain) 84 | 85 | mrr = self.lib.getTestLinkMRR(type_constrain) 86 | mr = self.lib.getTestLinkMR(type_constrain) 87 | hit10 = self.lib.getTestLinkHit10(type_constrain) 88 | hit3 = 
self.lib.getTestLinkHit3(type_constrain) 89 | hit1 = self.lib.getTestLinkHit1(type_constrain) 90 | print (hit10) 91 | return mrr, mr, hit10, hit3, hit1 92 | 93 | def get_best_threshlod(self, score, ans): 94 | res = np.concatenate([ans.reshape(-1,1), score.reshape(-1,1)], axis = -1) 95 | order = np.argsort(score) 96 | res = res[order] 97 | 98 | total_all = (float)(len(score)) 99 | total_current = 0.0 100 | total_true = np.sum(ans) 101 | total_false = total_all - total_true 102 | 103 | res_mx = 0.0 104 | threshlod = None 105 | for index, [ans, score] in enumerate(res): 106 | if ans == 1: 107 | total_current += 1.0 108 | res_current = (2 * total_current + total_false - index - 1) / total_all 109 | if res_current > res_mx: 110 | res_mx = res_current 111 | threshlod = score 112 | return threshlod, res_mx 113 | 114 | def run_triple_classification(self, threshlod = None): 115 | self.lib.initTest() 116 | self.data_loader.set_sampling_mode('classification') 117 | score = [] 118 | ans = [] 119 | training_range = tqdm(self.data_loader) 120 | for index, [pos_ins, neg_ins] in enumerate(training_range): 121 | res_pos = self.test_one_step(pos_ins) 122 | ans = ans + [1 for i in range(len(res_pos))] 123 | score.append(res_pos) 124 | 125 | res_neg = self.test_one_step(neg_ins) 126 | ans = ans + [0 for i in range(len(res_pos))] 127 | score.append(res_neg) 128 | 129 | score = np.concatenate(score, axis = -1) 130 | ans = np.array(ans) 131 | 132 | if threshlod == None: 133 | threshlod, _ = self.get_best_threshlod(score, ans) 134 | 135 | res = np.concatenate([ans.reshape(-1,1), score.reshape(-1,1)], axis = -1) 136 | order = np.argsort(score) 137 | res = res[order] 138 | 139 | total_all = (float)(len(score)) 140 | total_current = 0.0 141 | total_true = np.sum(ans) 142 | total_false = total_all - total_true 143 | 144 | for index, [ans, score] in enumerate(res): 145 | if score > threshlod: 146 | acc = (2 * total_current + total_false - index) / total_all 147 | break 148 | elif ans == 1: 149 
| total_current += 1.0 150 | 151 | return acc, threshlod -------------------------------------------------------------------------------- /arl/modules/prediction_heads.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch_geometric.nn import GATConv 5 | from transformers.models.bert.modeling_bert import BertPredictionHeadTransform 6 | 7 | from arl.modules.position_embeddings import get_2d_sincos_pos_embed 8 | from arl.modules.vision_encoders.clip_model import Transformer, LayerNorm 9 | 10 | 11 | class Pooler(nn.Module): 12 | def __init__(self, hidden_size): 13 | super().__init__() 14 | self.dense = nn.Linear(hidden_size, hidden_size) 15 | self.activation = nn.Tanh() 16 | 17 | def forward(self, hidden_states): 18 | first_token_tensor = hidden_states[:, 0] 19 | pooled_output = self.dense(first_token_tensor) 20 | pooled_output = self.activation(pooled_output) 21 | return pooled_output 22 | 23 | 24 | class MLMHead(nn.Module): 25 | def __init__(self, config, weight=None): 26 | super().__init__() 27 | self.transform = BertPredictionHeadTransform(config) 28 | self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 29 | self.bias = nn.Parameter(torch.zeros(config.vocab_size)) 30 | if weight is not None: 31 | self.decoder.weight = weight 32 | 33 | def forward(self, x): 34 | x = self.transform(x) 35 | x = self.decoder(x) + self.bias 36 | return x 37 | 38 | 39 | class MIMHead(nn.Module): 40 | def __init__(self, config): 41 | super().__init__() 42 | self.hidden_size = config["hidden_size"] 43 | self.patch_size = config["patch_size"] 44 | self.num_patches = (config["image_size"] // config["patch_size"]) ** 2 45 | self.decoder_hidden_size = config["mim_decoder_hidden_size"] 46 | self.decoder_num_layers = config["mim_decoder_num_layers"] 47 | self.decoder_num_heads = config["mim_decoder_num_heads"] 48 | self.decoder_num_channels = 3 * 
class ITMHead(nn.Module):
    """Image-text matching head: maps a pooled feature to match/mismatch logits."""

    def __init__(self, hidden_size):
        super().__init__()
        self.fc = nn.Linear(hidden_size, 2)

    def forward(self, x):
        x = self.fc(x)
        return x


class AlignHead(torch.nn.Module):
    """Knowledge-alignment head.

    Loads a pretrained entity-embedding table and the knowledge-graph edges,
    refines the embeddings with a two-layer GAT, and classifies projected
    image/text features against every entity.
    """

    def __init__(self, config):
        super().__init__()
        # BUG FIX: the triple file was opened inline and never closed (handle
        # leak); a context manager guarantees release.
        with open("data/knowledge/train2id.txt") as f:
            triples = [[int(_) for _ in line.strip().split("\t")]
                       for line in f.read().strip().split("\n")[1:]]

        # First two columns of each triple are (head, tail) — the GAT edges.
        self.register_buffer("edge_index", torch.tensor(triples, dtype=torch.long).T[:2].contiguous())
        # Pretrained entity embeddings, fine-tuned further (requires_grad=True).
        self.x = nn.Parameter(torch.load("data/knowledge/ent_embeddings.ckpt",
                                         map_location="cpu")["ent_embeddings.weight"], requires_grad=True)
        self.ent_hidden_size = self.x.shape[1]

        self.embedding_ln = nn.LayerNorm(self.ent_hidden_size, eps=1e-05)

        self.conv1 = GATConv(self.ent_hidden_size, 64, heads=8, dropout=0.2)
        self.conv2 = GATConv(64 * 8, self.ent_hidden_size, heads=1, concat=False, dropout=0.2)

        # Two-layer projections from the fusion hidden size into entity space.
        self.img_linear1 = nn.Linear(config["hidden_size"], config["hidden_size"])
        self.img_linear2 = nn.Linear(config["hidden_size"], self.ent_hidden_size)
        self.txt_linear1 = nn.Linear(config["hidden_size"], config["hidden_size"])
        self.txt_linear2 = nn.Linear(config["hidden_size"], self.ent_hidden_size)

        # When set, embeddings() skips the GAT and uses raw entity embeddings.
        self.ent_embedding_only = config["ent_embedding_only"]

    def forward(self):
        """Run the GAT over the whole graph; returns refined entity embeddings."""
        x, edge_index = self.x, self.edge_index
        x = self.embedding_ln(x)
        x = F.dropout(x, p=0.2, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv2(x, edge_index)
        return x

    def img_classify(self, input):
        """Score pooled image features against every entity; returns (bs, n_ent)."""
        x = self.forward()
        input = F.relu(self.img_linear2(F.relu(self.img_linear1(input))))
        # (bs, d) @ (d, n_ent) — equivalent to the old bmm over a
        # repeat(bs, 1, 1) copy of the entity matrix, without materialising
        # bs copies of it.
        return input.matmul(x.t())

    def txt_classify(self, input):
        """Score pooled text features against every entity; returns (bs, n_ent)."""
        x = self.forward()
        input = F.relu(self.txt_linear2(F.relu(self.txt_linear1(input))))
        return input.matmul(x.t())

    def embeddings(self, ent_ids, ignore_index=-100):
        """Look up (optionally GAT-refined) embeddings for a (bs, ne) id grid.

        Positions equal to `ignore_index` come back as zero vectors.
        """
        bs, ne = ent_ids.shape
        if not self.ent_embedding_only:
            x = self.forward()
        else:
            x = self.x
        embeddings = x.new_zeros((bs * ne, x.shape[1]))
        ent_ids = ent_ids.reshape(-1)
        # Clamp the ignored slots to a valid id for the gather, then zero them
        # out by only copying rows whose original id was not ignore_index.
        ent_ids_ = ent_ids.clone()
        ent_ids_[ent_ids_ == ignore_index] = 0
        embeddings_ = x.index_select(0, ent_ids_)
        embeddings[ent_ids != ignore_index] = embeddings_[ent_ids != ignore_index]
        embeddings = embeddings.reshape(bs, ne, -1)
        return embeddings
self.eval_batch_size = self.batch_size 34 | self.image_size = _config["image_size"] 35 | self.max_text_len = _config["max_text_len"] 36 | self.draw_false_image = _config["draw_false_image"] 37 | self.draw_false_text = _config["draw_false_text"] 38 | self.image_only = _config["image_only"] 39 | self.label_column_name = _config["label_column_name"] 40 | self.max_num_ents = _config["max_num_ents"] 41 | 42 | # Transformations 43 | self.train_transform_keys = ( 44 | ["default_train"] 45 | if len(_config["train_transform_keys"]) == 0 46 | else _config["train_transform_keys"] 47 | ) 48 | self.val_transform_keys = ( 49 | ["default_val"] 50 | if len(_config["val_transform_keys"]) == 0 51 | else _config["val_transform_keys"] 52 | ) 53 | 54 | # Tokenizer 55 | tokenizer = _config["tokenizer"] 56 | self.tokenizer = get_pretrained_tokenizer(tokenizer) 57 | self.vocab_size = self.tokenizer.vocab_size 58 | 59 | # Collator for Dataloaders 60 | collator = ( 61 | DataCollatorForWholeEntityMask 62 | if _config["whole_word_masking"] 63 | else DataCollatorForLanguageModeling 64 | ) 65 | 66 | self.mlm_collator = collator(tokenizer=self.tokenizer, mlm=True, mlm_probability=_config["mlm_prob"]) 67 | self.setup_flag = False 68 | 69 | @property 70 | def dataset_cls(self): 71 | raise NotImplementedError("return tuple of dataset class") 72 | 73 | @property 74 | def dataset_name(self): 75 | raise NotImplementedError("return name of dataset") 76 | 77 | def set_train_dataset(self): 78 | self.train_dataset = self.dataset_cls( 79 | self.data_dir, 80 | self.train_transform_keys, 81 | split="train", 82 | image_size=self.image_size, 83 | max_text_len=self.max_text_len, 84 | draw_false_image=self.draw_false_image, 85 | draw_false_text=self.draw_false_text, 86 | image_only=self.image_only, 87 | label_column_name=self.label_column_name, 88 | max_num_ents=self.max_num_ents, 89 | ) 90 | 91 | def set_val_dataset(self): 92 | self.val_dataset = self.dataset_cls( 93 | self.data_dir, 94 | 
self.val_transform_keys, 95 | split="val", 96 | image_size=self.image_size, 97 | max_text_len=self.max_text_len, 98 | draw_false_image=self.draw_false_image, 99 | draw_false_text=self.draw_false_text, 100 | image_only=self.image_only, 101 | label_column_name=self.label_column_name, 102 | max_num_ents=self.max_num_ents, 103 | ) 104 | 105 | if hasattr(self, "dataset_cls_no_false"): 106 | self.val_dataset_no_false = self.dataset_cls_no_false( 107 | self.data_dir, 108 | self.val_transform_keys, 109 | split="val", 110 | image_size=self.image_size, 111 | max_text_len=self.max_text_len, 112 | draw_false_image=0, 113 | draw_false_text=0, 114 | image_only=self.image_only, 115 | label_column_name=self.label_column_name, 116 | max_num_ents=self.max_num_ents, 117 | ) 118 | 119 | def make_no_false_val_dset(self, image_only=False): 120 | return self.dataset_cls_no_false( 121 | self.data_dir, 122 | self.val_transform_keys, 123 | split="val", 124 | image_size=self.image_size, 125 | max_text_len=self.max_text_len, 126 | draw_false_image=0, 127 | draw_false_text=0, 128 | image_only=image_only, 129 | label_column_name=self.label_column_name, 130 | max_num_ents=self.max_num_ents, 131 | ) 132 | 133 | def set_test_dataset(self): 134 | self.test_dataset = self.dataset_cls( 135 | self.data_dir, 136 | self.val_transform_keys, 137 | split="test", 138 | image_size=self.image_size, 139 | max_text_len=self.max_text_len, 140 | draw_false_image=self.draw_false_image, 141 | draw_false_text=self.draw_false_text, 142 | image_only=self.image_only, 143 | label_column_name=self.label_column_name, 144 | max_num_ents=self.max_num_ents, 145 | ) 146 | 147 | def setup(self, stage): 148 | if not self.setup_flag: 149 | self.set_train_dataset() 150 | self.set_val_dataset() 151 | self.set_test_dataset() 152 | 153 | self.train_dataset.tokenizer = self.tokenizer 154 | self.val_dataset.tokenizer = self.tokenizer 155 | self.test_dataset.tokenizer = self.tokenizer 156 | 157 | self.setup_flag = True 158 | 159 | def 
train_dataloader(self): 160 | loader = DataLoader( 161 | self.train_dataset, 162 | batch_size=self.batch_size, 163 | shuffle=True, 164 | num_workers=self.num_workers, 165 | pin_memory=True, 166 | collate_fn=self.train_dataset.collate, 167 | ) 168 | return loader 169 | 170 | def val_dataloader(self): 171 | loader = DataLoader( 172 | self.val_dataset, 173 | batch_size=self.eval_batch_size, 174 | shuffle=False, 175 | num_workers=self.num_workers, 176 | pin_memory=True, 177 | collate_fn=self.val_dataset.collate, 178 | ) 179 | return loader 180 | 181 | def test_dataloader(self): 182 | loader = DataLoader( 183 | self.test_dataset, 184 | batch_size=self.eval_batch_size, 185 | shuffle=False, 186 | num_workers=self.num_workers, 187 | pin_memory=True, 188 | collate_fn=self.test_dataset.collate, 189 | ) 190 | return loader 191 | -------------------------------------------------------------------------------- /arl/gadgets/my_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sklearn.metrics as sklm 3 | import torch 4 | from pytorch_lightning.metrics import Metric 5 | 6 | 7 | class Accuracy(Metric): 8 | def __init__(self, dist_sync_on_step=False): 9 | super().__init__(dist_sync_on_step=dist_sync_on_step) 10 | self.add_state("correct", default=torch.tensor(0.0), dist_reduce_fx="sum") 11 | self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum") 12 | 13 | def update(self, logits, target): 14 | logits, target = ( 15 | logits.detach().to(self.correct.device), 16 | target.detach().to(self.correct.device), 17 | ) 18 | preds = logits.argmax(dim=-1) 19 | preds = preds[target != -100] 20 | target = target[target != -100] 21 | if target.numel() == 0: 22 | return 1 23 | 24 | assert preds.shape == target.shape 25 | 26 | self.correct += torch.sum(preds == target) 27 | self.total += target.numel() 28 | 29 | def compute(self): 30 | return self.correct / self.total 31 | 32 | 33 | class Scalar(Metric): 
    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.add_state("scalar", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, scalar):
        # Accepts either a tensor or a plain python number.
        if isinstance(scalar, torch.Tensor):
            scalar = scalar.detach().to(self.scalar.device)
        else:
            scalar = torch.tensor(scalar).float().to(self.scalar.device)
        self.scalar += scalar
        self.total += 1

    def compute(self):
        return self.scalar / self.total


class VQAScore(Metric):
    """Standard VQA score: credit from the soft target at the argmax prediction."""

    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.add_state("score", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, logits, target):
        logits, target = (
            logits.detach().float().to(self.score.device),
            target.detach().float().to(self.score.device),
        )
        logits = torch.max(logits, 1)[1]
        # One-hot the predicted answer and read off its soft-target credit.
        one_hots = torch.zeros(*target.size()).to(target)
        one_hots.scatter_(1, logits.view(-1, 1), 1)
        scores = one_hots * target

        self.score += scores.sum()
        self.total += len(logits)

    def compute(self):
        return self.score / self.total


class VQARADScore(Metric):
    """VQA-RAD score with separate tracking of closed (types==0) and open (types==1) questions."""

    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.add_state("score", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

        self.add_state("close_score", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("close_total", default=torch.tensor(0.0), dist_reduce_fx="sum")

        self.add_state("open_score", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("open_total", default=torch.tensor(0.0), dist_reduce_fx="sum")

        # Best-so-far scores are plain attributes, not metric states (not synced/reset).
        self.best_score = 0
        self.best_close_score = 0
        self.best_open_score = 0

    def update(self, logits, target, types=None):
        # NOTE(review): `types=None` default looks unusable — indexing below assumes a
        # tensor of 0/1 question types; verify callers always pass `types`.
        logits, target = (
            logits.detach().float().to(self.score.device),
            target.detach().float().to(self.score.device),
        )
        logits = torch.max(logits, 1)[1]
        one_hots = torch.zeros(*target.size()).to(target)
        one_hots.scatter_(1, logits.view(-1, 1), 1)
        scores = one_hots * target

        close_scores = scores[types == 0]
        open_scores = scores[types == 1]

        self.close_score += close_scores.sum()
        self.close_total += len(close_scores)
        self.open_score += open_scores.sum()
        self.open_total += len(open_scores)

        self.score += scores.sum()
        self.total += len(scores)

    def compute(self):
        score = self.score / self.total
        return score

    def get_best_score(self):
        # Sync states across processes before comparing against the running best.
        self.sync()
        score = self.score / self.total
        if score > self.best_score:
            self.best_score = score
            self.best_close_score = self.close_score / self.close_total if self.close_total != 0 else 0
            self.best_open_score = self.open_score / self.open_total if self.open_total != 0 else 0
        self.unsync()
        return self.best_score

    def get_best_close_score(self):
        return self.best_close_score

    def get_best_open_score(self):
        return self.best_open_score


class ROCScore(Metric):
    """ROC-AUC computed from accumulated sigmoid scores via scikit-learn."""

    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.add_state("y_trues", default=[], dist_reduce_fx="cat")
        self.add_state("y_scores", default=[], dist_reduce_fx="cat")
        self.add_state("score", default=torch.tensor(0.0), dist_reduce_fx="mean")

    def update(self, logits, target):
        logits, target = (
            logits.detach().float(),
            target.detach().float(),
        )

        y_true = target
        # Manual sigmoid of the logits.
        y_score = 1 / (1 + torch.exp(-logits))
        self.y_trues.append(y_true)
        self.y_scores.append(y_score)

    def compute(self):
        try:
            score = sklm.roc_auc_score(np.concatenate([y_true.cpu().numpy() for y_true in self.y_trues], axis=0),
                                       np.concatenate([y_score.cpu().numpy() for y_score in self.y_scores], axis=0))
            self.score = torch.tensor(score).to(self.score)
        except ValueError:
            # e.g. only one class present in y_true — report 0 rather than crash.
            self.score = torch.tensor(0).to(self.score)
        return self.score


class F1Score(Metric):
    """Binary F1 from thresholded (0.5) sigmoid predictions via scikit-learn."""

    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.add_state("y_trues", default=[], dist_reduce_fx="cat")
        self.add_state("y_preds", default=[], dist_reduce_fx="cat")
        self.add_state("score", default=torch.tensor(0.0), dist_reduce_fx="mean")

    def update(self, logits, target):
        logits, target = (
            logits.detach().float(),
            target.detach().float(),
        )

        y_true = target
        y_score = 1 / (1 + torch.exp(-logits)) > 0.5
        self.y_trues.append(y_true)
        self.y_preds.append(y_score)

    def compute(self):
        try:
            score = sklm.f1_score(np.concatenate([y_true.cpu().numpy() for y_true in self.y_trues], axis=0),
                                  np.concatenate([y_pred.cpu().numpy() for y_pred in self.y_preds], axis=0))
            self.score = torch.tensor(score).to(self.score)
        except ValueError:
            self.score = torch.tensor(0).to(self.score)
        return self.score
--------------------------------------------------------------------------------
/arl/transforms/randaug.py:
--------------------------------------------------------------------------------
# code in this file is adapted from rpmcruz/autoaugment
# https://github.com/rpmcruz/autoaugment/blob/master/transformations.py
import random

import PIL
import PIL.ImageDraw
import PIL.ImageEnhance
import PIL.ImageOps
import numpy as np
import torch
from PIL import Image


def ShearX(img, v):  # [-0.3, 0.3]
    assert -0.3 <= v <= 0.3
    if random.random() > 0.5:
        v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))


def ShearY(img, v):  # [-0.3, 0.3]
    assert -0.3 <= v <= 0.3
    if random.random() > 0.5:
        v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))


def TranslateX(img, v):  # [-150, 150] => percentage: [-0.45, 0.45]
    assert -0.45 <= v <= 0.45
    if random.random() > 0.5:
        v = -v
    v = v * img.size[0]
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))


def TranslateXabs(img, v):  # absolute pixel offset, sign randomized
    assert 0 <= v
    if random.random() > 0.5:
        v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))


def TranslateY(img, v):  # [-150, 150] => percentage: [-0.45, 0.45]
    assert -0.45 <= v <= 0.45
    if random.random() > 0.5:
        v = -v
    v = v * img.size[1]
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))


def TranslateYabs(img, v):  # absolute pixel offset, sign randomized
    assert 0 <= v
    if random.random() > 0.5:
        v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))


def Rotate(img, v):  # [-30, 30]
    assert -30 <= v <= 30
    if random.random() > 0.5:
        v = -v
    return img.rotate(v)


def AutoContrast(img, _):
    return PIL.ImageOps.autocontrast(img)


def Invert(img, _):
    return PIL.ImageOps.invert(img)


def Equalize(img, _):
    return PIL.ImageOps.equalize(img)


def Flip(img, _):  # not from the paper
    return PIL.ImageOps.mirror(img)


def Solarize(img, v):  # [0, 256]
    assert 0 <= v <= 256
    return PIL.ImageOps.solarize(img, v)
86 | def SolarizeAdd(img, addition=0, threshold=128): 87 | img_np = np.array(img).astype(np.int) 88 | img_np = img_np + addition 89 | img_np = np.clip(img_np, 0, 255) 90 | img_np = img_np.astype(np.uint8) 91 | img = Image.fromarray(img_np) 92 | return PIL.ImageOps.solarize(img, threshold) 93 | 94 | 95 | def Posterize(img, v): # [4, 8] 96 | v = int(v) 97 | v = max(1, v) 98 | return PIL.ImageOps.posterize(img, v) 99 | 100 | 101 | def Contrast(img, v): # [0.1,1.9] 102 | assert 0.1 <= v <= 1.9 103 | return PIL.ImageEnhance.Contrast(img).enhance(v) 104 | 105 | 106 | def Color(img, v): # [0.1,1.9] 107 | assert 0.1 <= v <= 1.9 108 | return PIL.ImageEnhance.Color(img).enhance(v) 109 | 110 | 111 | def Brightness(img, v): # [0.1,1.9] 112 | assert 0.1 <= v <= 1.9 113 | return PIL.ImageEnhance.Brightness(img).enhance(v) 114 | 115 | 116 | def Sharpness(img, v): # [0.1,1.9] 117 | assert 0.1 <= v <= 1.9 118 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 119 | 120 | 121 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2] 122 | assert 0.0 <= v <= 0.2 123 | if v <= 0.0: 124 | return img 125 | 126 | v = v * img.size[0] 127 | return CutoutAbs(img, v) 128 | 129 | 130 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] 131 | # assert 0 <= v <= 20 132 | if v < 0: 133 | return img 134 | w, h = img.size 135 | x0 = np.random.uniform(w) 136 | y0 = np.random.uniform(h) 137 | 138 | x0 = int(max(0, x0 - v / 2.0)) 139 | y0 = int(max(0, y0 - v / 2.0)) 140 | x1 = min(w, x0 + v) 141 | y1 = min(h, y0 + v) 142 | 143 | xy = (x0, y0, x1, y1) 144 | color = (125, 123, 114) 145 | # color = (0, 0, 0) 146 | img = img.copy() 147 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 148 | return img 149 | 150 | 151 | def SamplePairing(imgs): # [0, 0.4] 152 | def f(img1, v): 153 | i = np.random.choice(len(imgs)) 154 | img2 = PIL.Image.fromarray(imgs[i]) 155 | return PIL.Image.blend(img1, img2, v) 156 | 157 | return f 158 | 159 | 160 | def Identity(img, v): 161 | return img 162 | 163 | 164 | def 
augment_list(): # 16 oeprations and their ranges 165 | # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57 166 | # l = [ 167 | # (Identity, 0., 1.0), 168 | # (ShearX, 0., 0.3), # 0 169 | # (ShearY, 0., 0.3), # 1 170 | # (TranslateX, 0., 0.33), # 2 171 | # (TranslateY, 0., 0.33), # 3 172 | # (Rotate, 0, 30), # 4 173 | # (AutoContrast, 0, 1), # 5 174 | # (Invert, 0, 1), # 6 175 | # (Equalize, 0, 1), # 7 176 | # (Solarize, 0, 110), # 8 177 | # (Posterize, 4, 8), # 9 178 | # # (Contrast, 0.1, 1.9), # 10 179 | # (Color, 0.1, 1.9), # 11 180 | # (Brightness, 0.1, 1.9), # 12 181 | # (Sharpness, 0.1, 1.9), # 13 182 | # # (Cutout, 0, 0.2), # 14 183 | # # (SamplePairing(imgs), 0, 0.4), # 15 184 | # ] 185 | 186 | # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505 187 | l = [ 188 | (AutoContrast, 0, 1), 189 | (Equalize, 0, 1), 190 | # (Invert, 0, 1), 191 | (Rotate, 0, 30), 192 | (Posterize, 0, 4), 193 | (Solarize, 0, 256), 194 | (SolarizeAdd, 0, 110), 195 | (Color, 0.1, 1.9), 196 | (Contrast, 0.1, 1.9), 197 | (Brightness, 0.1, 1.9), 198 | (Sharpness, 0.1, 1.9), 199 | (ShearX, 0.0, 0.3), 200 | (ShearY, 0.0, 0.3), 201 | # (CutoutAbs, 0, 40), 202 | (TranslateXabs, 0.0, 100), 203 | (TranslateYabs, 0.0, 100), 204 | ] 205 | 206 | return l 207 | 208 | 209 | class Lighting(object): 210 | """Lighting noise(AlexNet - style PCA - based noise)""" 211 | 212 | def __init__(self, alphastd, eigval, eigvec): 213 | self.alphastd = alphastd 214 | self.eigval = torch.Tensor(eigval) 215 | self.eigvec = torch.Tensor(eigvec) 216 | 217 | def __call__(self, img): 218 | if self.alphastd == 0: 219 | return img 220 | 221 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 222 | rgb = ( 223 | self.eigvec.type_as(img) 224 | .clone() 225 | .mul(alpha.view(1, 3).expand(3, 3)) 226 | .mul(self.eigval.view(1, 3).expand(3, 3)) 227 | .sum(1) 228 | .squeeze() 229 | ) 230 | 231 | return 
img.add(rgb.view(3, 1, 1).expand_as(img)) 232 | 233 | 234 | class CutoutDefault(object): 235 | """ 236 | Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py 237 | """ 238 | 239 | def __init__(self, length): 240 | self.length = length 241 | 242 | def __call__(self, img): 243 | h, w = img.size(1), img.size(2) 244 | mask = np.ones((h, w), np.float32) 245 | y = np.random.randint(h) 246 | x = np.random.randint(w) 247 | 248 | y1 = np.clip(y - self.length // 2, 0, h) 249 | y2 = np.clip(y + self.length // 2, 0, h) 250 | x1 = np.clip(x - self.length // 2, 0, w) 251 | x2 = np.clip(x + self.length // 2, 0, w) 252 | 253 | mask[y1:y2, x1:x2] = 0.0 254 | mask = torch.from_numpy(mask) 255 | mask = mask.expand_as(img) 256 | img *= mask 257 | return img 258 | 259 | 260 | class RandAugment: 261 | def __init__(self, n, m): 262 | self.n = n 263 | self.m = m # [0, 30] 264 | self.augment_list = augment_list() 265 | 266 | def __call__(self, img): 267 | ops = random.choices(self.augment_list, k=self.n) 268 | for op, minval, maxval in ops: 269 | val = (float(self.m) / 30) * float(maxval - minval) + minval 270 | img = op(img, val) 271 | 272 | return img 273 | -------------------------------------------------------------------------------- /prepro/OpenKE/README.md: -------------------------------------------------------------------------------- 1 | # OpenKE-PyTorch 2 | 3 | An Open-source Framework for Knowledge Embedding implemented with PyTorch. 
4 | 5 | More information is available on our website 6 | [http://openke.thunlp.org/](http://openke.thunlp.org/) 7 | 8 | If you use the code, please cite the following [paper](http://aclweb.org/anthology/D18-2024): 9 | 10 | ``` 11 | @inproceedings{han2018openke, 12 | title={OpenKE: An Open Toolkit for Knowledge Embedding}, 13 | author={Han, Xu and Cao, Shulin and Lv Xin and Lin, Yankai and Liu, Zhiyuan and Sun, Maosong and Li, Juanzi}, 14 | booktitle={Proceedings of EMNLP}, 15 | year={2018} 16 | } 17 | ``` 18 | 19 | This package is mainly contributed (in chronological order) by [Xu Han](https://github.com/THUCSTHanxu13), [Yankai Lin](https://github.com/Mrlyk423), [Ruobing Xie](http://nlp.csai.tsinghua.edu.cn/~xrb/), [Zhiyuan Liu](http://nlp.csai.tsinghua.edu.cn/~lzy/), [Xin Lv](https://github.com/davidlvxin), [Shulin Cao](https://github.com/ShulinCao), [Weize Chen](https://github.com/chenweize1998), [Jingqin Yang](https://github.com/yjqqqaq). 20 | 21 | ## Overview 22 | 23 | This is an Efficient implementation based on PyTorch for knowledge representation learning (KRL). We use C++ to implement some underlying operations such as data preprocessing and negative sampling. For each specific model, it is implemented by PyTorch with Python interfaces so that there is a convenient platform to run models on GPUs. OpenKE composes 4 repositories: 24 | 25 | OpenKE-PyTorch: the project based on PyTorch, which provides the optimized and stable framework for knowledge graph embedding models. 26 | 27 | OpenKE-Tensorflow1.0: OpenKE implemented with TensorFlow, also providing the optimized and stable framework for knowledge graph embedding models. 28 | 29 | TensorFlow-TransX: light and simple version of OpenKE based on TensorFlow, including TransE, TransH, TransR and TransD. 30 | 31 | Fast-TransX: efficient lightweight C++ inferences for TransE and its extended models utilizing the framework of OpenKE, including TransH, TransR, TransD, TranSparse and PTransE. 
32 | 33 | 34 | *** **UPDATE** *** 35 | 36 | We are now developing a new version of OpenKE-PyTorch. The project has been completely reconstructed and is faster, more extendable and the codes are easier to read and use now. If you need get to the old version, please refer to branch [OpenKE-PyTorch(old)](https://github.com/thunlp/OpenKE/tree/OpenKE-PyTorch(old)). 37 | 38 | *** **New Features** *** 39 | 40 | - RotatE 41 | - More enhancing strategies (e.g., adversarial training) 42 | - More scripts of the typical models for the benchmark datasets. 43 | - More extendable interfaces 44 | 45 | 46 | ## Models 47 | 48 | OpenKE (Tensorflow): 49 | 50 | * RESCAL, HolE 51 | * DistMult, ComplEx, Analogy 52 | * TransE, TransH, TransR, TransD 53 | 54 | OpenKE (PyTorch): 55 | 56 | * RESCAL 57 | * DistMult, ComplEx, Analogy 58 | * TransE, TransH, TransR, TransD 59 | * SimplE 60 | * RotatE 61 | 62 | We welcome any issues and requests for model implementation and bug fix. 63 | 64 | ## Experimental Settings 65 | 66 | For each test triplet, the head is removed and replaced by each of the entities from the entity set in turn. The scores of those corrupted triplets are first computed by the models and then sorted by the order. Then, we get the rank of the correct entity. This whole procedure is also repeated by removing those tail entities. We report the proportion of those correct entities ranked in the top 10/3/1 (Hits@10, Hits@3, Hits@1). The mean rank (MRR) and mean reciprocal rank (MRR) of the test triplets under this setting are also reported. 67 | 68 | Because some corrupted triplets may be in the training set and validation set. In this case, those corrupted triplets may be ranked above the test triplet, but this should not be counted as an error because both triplets are true. Hence, we remove those corrupted triplets appearing in the training, validation or test set, which ensures the corrupted triplets are not in the dataset. 
We report the proportion of those correct entities ranked in the top 10/3/1 (Hits@10 (filter), Hits@3(filter), Hits@1(filter)) under this setting. The mean rank (MR (filter)) and mean reciprocal rank (MRR (filter)) of the test triplets under this setting are also reported. 69 | 70 | More details of the above-mentioned settings can be found from the papers [TransE](http://papers.nips.cc/paper/5071-translating-embeddings-for-modeling-multi-relational-data.pdf), [ComplEx](http://proceedings.mlr.press/v48/trouillon16.pdf). 71 | 72 | For those large-scale entity sets, to corrupt all entities with the whole entity set is time-costing. Hence, we also provide the experimental setting named "[type constraint](https://www.dbs.ifi.lmu.de/~krompass/papers/TypeConstrainedRepresentationLearningInKnowledgeGraphs.pdf)" to corrupt entities with some limited entity sets determined by their relations. 73 | 74 | ## Experiments 75 | 76 | We have provided the hyper-parameters of some models to achieve the state-of-the-art performance (Hits@10 (filter)) on FB15K237 and WN18RR. These scripts can be found in the folder "./examples/". Up to now, these models include TransE, TransH, TransR, TransD, DistMult, ComplEx. The results of these models are as follows, 77 | 78 | |Model | WN18RR | FB15K237 | WN18RR (Paper\*)| FB15K237 (Paper\*)| 79 | |:-: |:-: |:-: |:-: |:-: | 80 | |TransE |0.512 |0.476|0.501|0.486| 81 | |TransH |0.507 |0.490|-|-| 82 | |TransR |0.519 |0.511|-|-| 83 | |TransD |0.508 |0.487|-|-| 84 | |DistMult |0.479 |0.419|0.49|0.419| 85 | |ComplEx |0.485 |0.426|0.51|0.428| 86 | |ConvE |0.506 |0.485|0.52|0.501| 87 | |RotatE |0.549 |0.479|-|0.480| 88 | |RotatE (+adv) |0.565 |0.522|0.571|0.533| 89 | 90 | 91 | We are still trying more hyper-parameters and more training strategies (e.g., adversarial training and label smoothing regularization) for these models. Hence, this table is still subject to change. We welcome everyone to help us update this table and hyper-parameters.
92 | 93 | 94 | ## Installation 95 | 96 | 1. Install [PyTorch](https://pytorch.org/get-started/locally/) 97 | 98 | 2. Clone the OpenKE-PyTorch branch: 99 | ```bash 100 | git clone -b OpenKE-PyTorch https://github.com/thunlp/OpenKE --depth 1 101 | cd OpenKE 102 | cd openke 103 | ``` 104 | 3. Compile C++ files 105 | ```bash 106 | bash make.sh 107 | ``` 108 | 4. Quick Start 109 | ```bash 110 | cd ../ 111 | cp examples/train_transe_FB15K237.py ./ 112 | python train_transe_FB15K237.py 113 | ``` 114 | ## Data 115 | 116 | * For training, datasets contain three files: 117 | 118 | train2id.txt: training file, the first line is the number of triples for training. Then the following lines are all in the format ***(e1, e2, rel)*** which indicates there is a relation ***rel*** between ***e1*** and ***e2*** . 119 | **Note that train2id.txt contains ids from entitiy2id.txt and relation2id.txt instead of the names of the entities and relations. If you use your own datasets, please check the format of your training file. Files in the wrong format may cause segmentation fault.** 120 | 121 | entity2id.txt: all entities and corresponding ids, one per line. The first line is the number of entities. 122 | 123 | relation2id.txt: all relations and corresponding ids, one per line. The first line is the number of relations. 124 | 125 | * For testing, datasets contain additional two files (totally five files): 126 | 127 | test2id.txt: testing file, the first line is the number of triples for testing. Then the following lines are all in the format ***(e1, e2, rel)*** . 128 | 129 | valid2id.txt: validating file, the first line is the number of triples for validating. Then the following lines are all in the format ***(e1, e2, rel)*** . 130 | 131 | type_constrain.txt: type constraining file, the first line is the number of relations. Then the following lines are type constraints for each relation. 
For example, the relation with id 1200 has 4 types of head entities, which are 3123, 1034, 58 and 5733. The relation with id 1200 has 4 types of tail entities, which are 12123, 4388, 11087 and 11088. You can get this file through **n-n.py** in folder benchmarks/FB15K 132 | 133 | ## To do 134 | 135 | The document of the new version of OpenKE-PyTorch will come soon. 136 | -------------------------------------------------------------------------------- /arl/modules/dist_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | This file contains primitives for multi-gpu communication. 4 | This is useful when doing distributed training. 5 | """ 6 | 7 | import functools 8 | import logging 9 | import pickle 10 | 11 | import numpy as np 12 | import torch 13 | import torch.distributed as dist 14 | 15 | _LOCAL_PROCESS_GROUP = None 16 | """ 17 | A torch process group which only includes processes that on the same machine as the current process. 18 | This variable is set when processes are spawned by `launch()` in "engine/launch.py". 19 | """ 20 | 21 | 22 | def get_world_size() -> int: 23 | if not dist.is_available(): 24 | return 1 25 | if not dist.is_initialized(): 26 | return 1 27 | return dist.get_world_size() 28 | 29 | 30 | def get_rank() -> int: 31 | if not dist.is_available(): 32 | return 0 33 | if not dist.is_initialized(): 34 | return 0 35 | return dist.get_rank() 36 | 37 | 38 | def get_local_rank() -> int: 39 | """ 40 | Returns: 41 | The rank of the current process within the local (per-machine) process group. 42 | """ 43 | if not dist.is_available(): 44 | return 0 45 | if not dist.is_initialized(): 46 | return 0 47 | assert _LOCAL_PROCESS_GROUP is not None 48 | return dist.get_rank(group=_LOCAL_PROCESS_GROUP) 49 | 50 | 51 | def get_local_size() -> int: 52 | """ 53 | Returns: 54 | The size of the per-machine process group, 55 | i.e. 
the number of processes per machine.
    """
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)


def is_main_process() -> bool:
    # Rank 0 is the main process by convention.
    return get_rank() == 0


def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier()


@functools.lru_cache()
def _get_global_gloo_group():
    """
    Return a process group based on gloo backend, containing all the ranks
    The result is cached.
    """
    # Pickled-object gathers go over gloo even when training uses nccl.
    if dist.get_backend() == "nccl":
        return dist.new_group(backend="gloo")
    else:
        return dist.group.WORLD


def _serialize_to_tensor(data, group):
    # Pickle an arbitrary object into a uint8 tensor on the backend's device.
    backend = dist.get_backend(group)
    assert backend in ["gloo", "nccl"]
    device = torch.device("cpu" if backend == "gloo" else "cuda")

    buffer = pickle.dumps(data)
    if len(buffer) > 1024 ** 3:
        # Warn on payloads over 1 GiB — gathers of this size are usually a bug.
        logger = logging.getLogger(__name__)
        logger.warning(
            "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
                get_rank(), len(buffer) / (1024 ** 3), device
            )
        )
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to(device=device)
    return tensor


def _pad_to_largest_tensor(tensor, group):
    """
    Returns:
        list[int]: size of the tensor, on each rank
        Tensor: padded tensor that has the max size
    """
    world_size = dist.get_world_size(group=group)
    assert (
        world_size >= 1
    ), "comm.gather/all_gather must be called from ranks within the given group!"
    local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
    size_list = [
        torch.zeros([1], dtype=torch.int64, device=tensor.device)
        for _ in range(world_size)
    ]
    # First exchange sizes so every rank knows how much to pad.
    dist.all_gather(size_list, local_size, group=group)
    size_list = [int(size.item()) for size in size_list]

    max_size = max(size_list)

    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    if local_size != max_size:
        padding = torch.zeros(
            (max_size - local_size,), dtype=torch.uint8, device=tensor.device
        )
        tensor = torch.cat((tensor, padding), dim=0)
    return size_list, tensor


def all_gather(data, group=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: list of data gathered from each rank
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group) == 1:
        return [data]

    tensor = _serialize_to_tensor(data, group)

    size_list, tensor = _pad_to_largest_tensor(tensor, group)
    max_size = max(size_list)

    # receiving Tensor from all ranks
    tensor_list = [
        torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
        for _ in size_list
    ]
    dist.all_gather(tensor_list, tensor, group=group)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        # Truncate each payload to its true (unpadded) size before unpickling.
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list


def gather(data, dst=0, group=None):
    """
    Run gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        dst (int): destination rank
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: on dst, a list of data gathered from each rank. Otherwise,
            an empty list.
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group=group) == 1:
        return [data]
    rank = dist.get_rank(group=group)

    tensor = _serialize_to_tensor(data, group)
    size_list, tensor = _pad_to_largest_tensor(tensor, group)

    # receiving Tensor from all ranks
    if rank == dst:
        max_size = max(size_list)
        tensor_list = [
            torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
            for _ in size_list
        ]
        dist.gather(tensor, tensor_list, dst=dst, group=group)

        data_list = []
        for size, tensor in zip(size_list, tensor_list):
            buffer = tensor.cpu().numpy().tobytes()[:size]
            data_list.append(pickle.loads(buffer))
        return data_list
    else:
        # Non-destination ranks pass an empty gather_list.
        dist.gather(tensor, [], dst=dst, group=group)
        return []


def shared_random_seed():
    """
    Returns:
        int: a random number that is the same across all workers.
        If workers need a shared RNG, they can use this shared seed to
        create one.

    All workers must call this function, otherwise it will deadlock.
    """
    ints = np.random.randint(2 ** 31)
    # Every rank draws, but all adopt rank 0's draw.
    all_ints = all_gather(ints)
    return all_ints[0]


def reduce_dict(input_dict, average=True):
    """
    Reduce the values in the dictionary from all processes so that process with rank
    0 has the reduced results.

    Args:
        input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
        average (bool): whether to do average or sum

    Returns:
        a dict with the same keys as input_dict, after reduction.
251 | """ 252 | world_size = get_world_size() 253 | if world_size < 2: 254 | return input_dict 255 | with torch.no_grad(): 256 | names = [] 257 | values = [] 258 | # sort the keys so that they are consistent across processes 259 | for k in sorted(input_dict.keys()): 260 | names.append(k) 261 | values.append(input_dict[k]) 262 | values = torch.stack(values, dim=0) 263 | dist.reduce(values, dst=0) 264 | if dist.get_rank() == 0 and average: 265 | # only main process gets accumulated, so only divide by 266 | # world_size in this case 267 | values /= world_size 268 | reduced_dict = {k: v for k, v in zip(names, values)} 269 | return reduced_dict 270 | -------------------------------------------------------------------------------- /prepro/OpenKE/openke/base/Reader.h: -------------------------------------------------------------------------------- 1 | #ifndef READER_H 2 | #define READER_H 3 | #include "Setting.h" 4 | #include "Triple.h" 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | INT *freqRel, *freqEnt; 11 | INT *lefHead, *rigHead; 12 | INT *lefTail, *rigTail; 13 | INT *lefRel, *rigRel; 14 | REAL *left_mean, *right_mean; 15 | REAL *prob; 16 | 17 | Triple *trainList; 18 | Triple *trainHead; 19 | Triple *trainTail; 20 | Triple *trainRel; 21 | 22 | INT *testLef, *testRig; 23 | INT *validLef, *validRig; 24 | 25 | extern "C" 26 | void importProb(REAL temp){ 27 | if (prob != NULL) 28 | free(prob); 29 | FILE *fin; 30 | fin = fopen((inPath + "kl_prob.txt").c_str(), "r"); 31 | printf("Current temperature:%f\n", temp); 32 | prob = (REAL *)calloc(relationTotal * (relationTotal - 1), sizeof(REAL)); 33 | INT tmp; 34 | for (INT i = 0; i < relationTotal * (relationTotal - 1); ++i){ 35 | tmp = fscanf(fin, "%f", &prob[i]); 36 | } 37 | REAL sum = 0.0; 38 | for (INT i = 0; i < relationTotal; ++i) { 39 | for (INT j = 0; j < relationTotal-1; ++j){ 40 | REAL tmp = exp(-prob[i * (relationTotal - 1) + j] / temp); 41 | sum += tmp; 42 | prob[i * (relationTotal - 1) + j] = tmp; 43 
| } 44 | for (INT j = 0; j < relationTotal-1; ++j){ 45 | prob[i*(relationTotal-1)+j] /= sum; 46 | } 47 | sum = 0; 48 | } 49 | fclose(fin); 50 | } 51 | 52 | extern "C" 53 | void importTrainFiles() { 54 | 55 | printf("The toolkit is importing datasets.\n"); 56 | FILE *fin; 57 | int tmp; 58 | 59 | if (rel_file == "") 60 | fin = fopen((inPath + "relation2id.txt").c_str(), "r"); 61 | else 62 | fin = fopen(rel_file.c_str(), "r"); 63 | tmp = fscanf(fin, "%ld", &relationTotal); 64 | printf("The total of relations is %ld.\n", relationTotal); 65 | fclose(fin); 66 | 67 | if (ent_file == "") 68 | fin = fopen((inPath + "entity2id.txt").c_str(), "r"); 69 | else 70 | fin = fopen(ent_file.c_str(), "r"); 71 | tmp = fscanf(fin, "%ld", &entityTotal); 72 | printf("The total of entities is %ld.\n", entityTotal); 73 | fclose(fin); 74 | 75 | if (train_file == "") 76 | fin = fopen((inPath + "train2id.txt").c_str(), "r"); 77 | else 78 | fin = fopen(train_file.c_str(), "r"); 79 | tmp = fscanf(fin, "%ld", &trainTotal); 80 | trainList = (Triple *)calloc(trainTotal, sizeof(Triple)); 81 | trainHead = (Triple *)calloc(trainTotal, sizeof(Triple)); 82 | trainTail = (Triple *)calloc(trainTotal, sizeof(Triple)); 83 | trainRel = (Triple *)calloc(trainTotal, sizeof(Triple)); 84 | freqRel = (INT *)calloc(relationTotal, sizeof(INT)); 85 | freqEnt = (INT *)calloc(entityTotal, sizeof(INT)); 86 | for (INT i = 0; i < trainTotal; i++) { 87 | tmp = fscanf(fin, "%ld", &trainList[i].h); 88 | tmp = fscanf(fin, "%ld", &trainList[i].t); 89 | tmp = fscanf(fin, "%ld", &trainList[i].r); 90 | } 91 | fclose(fin); 92 | std::sort(trainList, trainList + trainTotal, Triple::cmp_head); 93 | tmp = trainTotal; trainTotal = 1; 94 | trainHead[0] = trainTail[0] = trainRel[0] = trainList[0]; 95 | freqEnt[trainList[0].t] += 1; 96 | freqEnt[trainList[0].h] += 1; 97 | freqRel[trainList[0].r] += 1; 98 | for (INT i = 1; i < tmp; i++) 99 | if (trainList[i].h != trainList[i - 1].h || trainList[i].r != trainList[i - 1].r || 
trainList[i].t != trainList[i - 1].t) { 100 | trainHead[trainTotal] = trainTail[trainTotal] = trainRel[trainTotal] = trainList[trainTotal] = trainList[i]; 101 | trainTotal++; 102 | freqEnt[trainList[i].t]++; 103 | freqEnt[trainList[i].h]++; 104 | freqRel[trainList[i].r]++; 105 | } 106 | 107 | std::sort(trainHead, trainHead + trainTotal, Triple::cmp_head); 108 | std::sort(trainTail, trainTail + trainTotal, Triple::cmp_tail); 109 | std::sort(trainRel, trainRel + trainTotal, Triple::cmp_rel); 110 | printf("The total of train triples is %ld.\n", trainTotal); 111 | 112 | lefHead = (INT *)calloc(entityTotal, sizeof(INT)); 113 | rigHead = (INT *)calloc(entityTotal, sizeof(INT)); 114 | lefTail = (INT *)calloc(entityTotal, sizeof(INT)); 115 | rigTail = (INT *)calloc(entityTotal, sizeof(INT)); 116 | lefRel = (INT *)calloc(entityTotal, sizeof(INT)); 117 | rigRel = (INT *)calloc(entityTotal, sizeof(INT)); 118 | memset(rigHead, -1, sizeof(INT)*entityTotal); 119 | memset(rigTail, -1, sizeof(INT)*entityTotal); 120 | memset(rigRel, -1, sizeof(INT)*entityTotal); 121 | for (INT i = 1; i < trainTotal; i++) { 122 | if (trainTail[i].t != trainTail[i - 1].t) { 123 | rigTail[trainTail[i - 1].t] = i - 1; 124 | lefTail[trainTail[i].t] = i; 125 | } 126 | if (trainHead[i].h != trainHead[i - 1].h) { 127 | rigHead[trainHead[i - 1].h] = i - 1; 128 | lefHead[trainHead[i].h] = i; 129 | } 130 | if (trainRel[i].h != trainRel[i - 1].h) { 131 | rigRel[trainRel[i - 1].h] = i - 1; 132 | lefRel[trainRel[i].h] = i; 133 | } 134 | } 135 | lefHead[trainHead[0].h] = 0; 136 | rigHead[trainHead[trainTotal - 1].h] = trainTotal - 1; 137 | lefTail[trainTail[0].t] = 0; 138 | rigTail[trainTail[trainTotal - 1].t] = trainTotal - 1; 139 | lefRel[trainRel[0].h] = 0; 140 | rigRel[trainRel[trainTotal - 1].h] = trainTotal - 1; 141 | 142 | left_mean = (REAL *)calloc(relationTotal,sizeof(REAL)); 143 | right_mean = (REAL *)calloc(relationTotal,sizeof(REAL)); 144 | for (INT i = 0; i < entityTotal; i++) { 145 | for (INT j = 
lefHead[i] + 1; j <= rigHead[i]; j++) 146 | if (trainHead[j].r != trainHead[j - 1].r) 147 | left_mean[trainHead[j].r] += 1.0; 148 | if (lefHead[i] <= rigHead[i]) 149 | left_mean[trainHead[lefHead[i]].r] += 1.0; 150 | for (INT j = lefTail[i] + 1; j <= rigTail[i]; j++) 151 | if (trainTail[j].r != trainTail[j - 1].r) 152 | right_mean[trainTail[j].r] += 1.0; 153 | if (lefTail[i] <= rigTail[i]) 154 | right_mean[trainTail[lefTail[i]].r] += 1.0; 155 | } 156 | for (INT i = 0; i < relationTotal; i++) { 157 | left_mean[i] = freqRel[i] / left_mean[i]; 158 | right_mean[i] = freqRel[i] / right_mean[i]; 159 | } 160 | } 161 | 162 | Triple *testList; 163 | Triple *validList; 164 | Triple *tripleList; 165 | 166 | extern "C" 167 | void importTestFiles() { 168 | FILE *fin; 169 | INT tmp; 170 | 171 | if (rel_file == "") 172 | fin = fopen((inPath + "relation2id.txt").c_str(), "r"); 173 | else 174 | fin = fopen(rel_file.c_str(), "r"); 175 | tmp = fscanf(fin, "%ld", &relationTotal); 176 | fclose(fin); 177 | 178 | if (ent_file == "") 179 | fin = fopen((inPath + "entity2id.txt").c_str(), "r"); 180 | else 181 | fin = fopen(ent_file.c_str(), "r"); 182 | tmp = fscanf(fin, "%ld", &entityTotal); 183 | fclose(fin); 184 | 185 | FILE* f_kb1, * f_kb2, * f_kb3; 186 | if (train_file == "") 187 | f_kb2 = fopen((inPath + "train2id.txt").c_str(), "r"); 188 | else 189 | f_kb2 = fopen(train_file.c_str(), "r"); 190 | if (test_file == "") 191 | f_kb1 = fopen((inPath + "test2id.txt").c_str(), "r"); 192 | else 193 | f_kb1 = fopen(test_file.c_str(), "r"); 194 | if (valid_file == "") 195 | f_kb3 = fopen((inPath + "valid2id.txt").c_str(), "r"); 196 | else 197 | f_kb3 = fopen(valid_file.c_str(), "r"); 198 | tmp = fscanf(f_kb1, "%ld", &testTotal); 199 | tmp = fscanf(f_kb2, "%ld", &trainTotal); 200 | tmp = fscanf(f_kb3, "%ld", &validTotal); 201 | tripleTotal = testTotal + trainTotal + validTotal; 202 | testList = (Triple *)calloc(testTotal, sizeof(Triple)); 203 | validList = (Triple *)calloc(validTotal, 
sizeof(Triple)); 204 | tripleList = (Triple *)calloc(tripleTotal, sizeof(Triple)); 205 | for (INT i = 0; i < testTotal; i++) { 206 | tmp = fscanf(f_kb1, "%ld", &testList[i].h); 207 | tmp = fscanf(f_kb1, "%ld", &testList[i].t); 208 | tmp = fscanf(f_kb1, "%ld", &testList[i].r); 209 | tripleList[i] = testList[i]; 210 | } 211 | for (INT i = 0; i < trainTotal; i++) { 212 | tmp = fscanf(f_kb2, "%ld", &tripleList[i + testTotal].h); 213 | tmp = fscanf(f_kb2, "%ld", &tripleList[i + testTotal].t); 214 | tmp = fscanf(f_kb2, "%ld", &tripleList[i + testTotal].r); 215 | } 216 | for (INT i = 0; i < validTotal; i++) { 217 | tmp = fscanf(f_kb3, "%ld", &tripleList[i + testTotal + trainTotal].h); 218 | tmp = fscanf(f_kb3, "%ld", &tripleList[i + testTotal + trainTotal].t); 219 | tmp = fscanf(f_kb3, "%ld", &tripleList[i + testTotal + trainTotal].r); 220 | validList[i] = tripleList[i + testTotal + trainTotal]; 221 | } 222 | fclose(f_kb1); 223 | fclose(f_kb2); 224 | fclose(f_kb3); 225 | 226 | std::sort(tripleList, tripleList + tripleTotal, Triple::cmp_head); 227 | std::sort(testList, testList + testTotal, Triple::cmp_rel2); 228 | std::sort(validList, validList + validTotal, Triple::cmp_rel2); 229 | printf("The total of test triples is %ld.\n", testTotal); 230 | printf("The total of valid triples is %ld.\n", validTotal); 231 | 232 | testLef = (INT *)calloc(relationTotal, sizeof(INT)); 233 | testRig = (INT *)calloc(relationTotal, sizeof(INT)); 234 | memset(testLef, -1, sizeof(INT) * relationTotal); 235 | memset(testRig, -1, sizeof(INT) * relationTotal); 236 | for (INT i = 1; i < testTotal; i++) { 237 | if (testList[i].r != testList[i-1].r) { 238 | testRig[testList[i-1].r] = i - 1; 239 | testLef[testList[i].r] = i; 240 | } 241 | } 242 | testLef[testList[0].r] = 0; 243 | testRig[testList[testTotal - 1].r] = testTotal - 1; 244 | 245 | validLef = (INT *)calloc(relationTotal, sizeof(INT)); 246 | validRig = (INT *)calloc(relationTotal, sizeof(INT)); 247 | memset(validLef, -1, 
sizeof(INT)*relationTotal); 248 | memset(validRig, -1, sizeof(INT)*relationTotal); 249 | for (INT i = 1; i < validTotal; i++) { 250 | if (validList[i].r != validList[i-1].r) { 251 | validRig[validList[i-1].r] = i - 1; 252 | validLef[validList[i].r] = i; 253 | } 254 | } 255 | validLef[validList[0].r] = 0; 256 | validRig[validList[validTotal - 1].r] = validTotal - 1; 257 | } 258 | 259 | INT* head_lef; 260 | INT* head_rig; 261 | INT* tail_lef; 262 | INT* tail_rig; 263 | INT* head_type; 264 | INT* tail_type; 265 | 266 | extern "C" 267 | void importTypeFiles() { 268 | 269 | head_lef = (INT *)calloc(relationTotal, sizeof(INT)); 270 | head_rig = (INT *)calloc(relationTotal, sizeof(INT)); 271 | tail_lef = (INT *)calloc(relationTotal, sizeof(INT)); 272 | tail_rig = (INT *)calloc(relationTotal, sizeof(INT)); 273 | INT total_lef = 0; 274 | INT total_rig = 0; 275 | FILE* f_type = fopen((inPath + "type_constrain.txt").c_str(),"r"); 276 | INT tmp; 277 | tmp = fscanf(f_type, "%ld", &tmp); 278 | for (INT i = 0; i < relationTotal; i++) { 279 | INT rel, tot; 280 | tmp = fscanf(f_type, "%ld %ld", &rel, &tot); 281 | for (INT j = 0; j < tot; j++) { 282 | tmp = fscanf(f_type, "%ld", &tmp); 283 | total_lef++; 284 | } 285 | tmp = fscanf(f_type, "%ld%ld", &rel, &tot); 286 | for (INT j = 0; j < tot; j++) { 287 | tmp = fscanf(f_type, "%ld", &tmp); 288 | total_rig++; 289 | } 290 | } 291 | fclose(f_type); 292 | head_type = (INT *)calloc(total_lef, sizeof(INT)); 293 | tail_type = (INT *)calloc(total_rig, sizeof(INT)); 294 | total_lef = 0; 295 | total_rig = 0; 296 | f_type = fopen((inPath + "type_constrain.txt").c_str(),"r"); 297 | tmp = fscanf(f_type, "%ld", &tmp); 298 | for (INT i = 0; i < relationTotal; i++) { 299 | INT rel, tot; 300 | tmp = fscanf(f_type, "%ld%ld", &rel, &tot); 301 | head_lef[rel] = total_lef; 302 | for (INT j = 0; j < tot; j++) { 303 | tmp = fscanf(f_type, "%ld", &head_type[total_lef]); 304 | total_lef++; 305 | } 306 | head_rig[rel] = total_lef; 307 | std::sort(head_type + 
head_lef[rel], head_type + head_rig[rel]); 308 | tmp = fscanf(f_type, "%ld%ld", &rel, &tot); 309 | tail_lef[rel] = total_rig; 310 | for (INT j = 0; j < tot; j++) { 311 | tmp = fscanf(f_type, "%ld", &tail_type[total_rig]); 312 | total_rig++; 313 | } 314 | tail_rig[rel] = total_rig; 315 | std::sort(tail_type + tail_lef[rel], tail_type + tail_rig[rel]); 316 | } 317 | fclose(f_type); 318 | } 319 | 320 | 321 | #endif 322 | --------------------------------------------------------------------------------