├── scripts ├── gqe_FB15k.sh ├── gqe_FB15k-237.sh ├── q2b_FB15k.sh ├── gqe_NELL-995.sh ├── q2b_NELL-995.sh ├── q2b_FB15k-237.sh ├── betae_FB15k.sh ├── betae_NELL-995.sh ├── betae_FB15k-237.sh ├── logice_FB15k.sh ├── logice_NELL-995.sh └── logice_FB15k-237.sh ├── util.py ├── README.md ├── dataloader.py ├── type_aggregator.py ├── main.py └── models.py /scripts/gqe_FB15k.sh: -------------------------------------------------------------------------------- 1 | export DATA_PATH=../data/FB15k-betae 2 | export SAVE_PATH=../logs/FB15k/gqe_temp 3 | export LOG_PATH=../logs/FB15k/gqe_temp.out 4 | export MODEL=temp 5 | export FAITHFUL=no_faithful 6 | 7 | export MAX_STEPS=450000 8 | export VALID_STEPS=10000 9 | export SAVE_STEPS=10000 10 | export ENT_TYPE_NEIGHBOR=32 11 | export REL_TYPE_NEIGHBOR=64 12 | 13 | CUDA_VISIBLE_DEVICES=0 nohup python -u ../main.py --cuda --do_train --do_valid --do_test \ 14 | --data_path $DATA_PATH --save_path $SAVE_PATH -n 128 -b 512 -d 800 -g 24 \ 15 | -lr 0.0001 --max_steps $MAX_STEPS --valid_steps $VALID_STEPS --save_checkpoint_steps $SAVE_STEPS \ 16 | --cpu_num 1 --geo vec --test_batch_size 16 --tasks "1p.2p.3p.2i.3i.ip.pi.2u.up" --print_on_screen \ 17 | --faithful $FAITHFUL --model_mode $MODEL --neighbor_ent_type_samples $ENT_TYPE_NEIGHBOR --neighbor_rel_type_samples $REL_TYPE_NEIGHBOR \ 18 | > $LOG_PATH 2>&1 & -------------------------------------------------------------------------------- /scripts/gqe_FB15k-237.sh: -------------------------------------------------------------------------------- 1 | export DATA_PATH=../data/FB15k-237-betae 2 | export SAVE_PATH=../logs/FB15k-237/gqe_temp 3 | export LOG_PATH=../logs/FB15k-237/gqe_temp.out 4 | export MODEL=temp 5 | export FAITHFUL=no_faithful 6 | 7 | export MAX_STEPS=450000 8 | export VALID_STEPS=10000 9 | export SAVE_STEPS=10000 10 | export ENT_TYPE_NEIGHBOR=32 11 | export REL_TYPE_NEIGHBOR=64 12 | 13 | CUDA_VISIBLE_DEVICES=0 nohup python -u ../main.py --cuda --do_train --do_valid --do_test \ 14 | --data_path $DATA_PATH --save_path $SAVE_PATH -n 128 -b 512 -d 800 -g 24 \ 15 | -lr 0.0001 --max_steps $MAX_STEPS --valid_steps $VALID_STEPS --save_checkpoint_steps $SAVE_STEPS \ 16 | --cpu_num 1 --geo vec --test_batch_size 16 --tasks "1p.2p.3p.2i.3i.ip.pi.2u.up" --print_on_screen \ 17 | --faithful $FAITHFUL --model_mode $MODEL --neighbor_ent_type_samples $ENT_TYPE_NEIGHBOR --neighbor_rel_type_samples $REL_TYPE_NEIGHBOR \ 18 | > $LOG_PATH 2>&1 & -------------------------------------------------------------------------------- /scripts/q2b_FB15k.sh: -------------------------------------------------------------------------------- 1 | export DATA_PATH=../data/FB15k-betae 2 | export SAVE_PATH=../logs/FB15k/q2b_temp 3 | export LOG_PATH=../logs/FB15k/q2b_temp.out 4 | export MODEL=temp 5 | export FAITHFUL=no_faithful 6 | 7 | export MAX_STEPS=450000 8 | export VALID_STEPS=10000 9 | export SAVE_STEPS=10000 10 | export ENT_TYPE_NEIGHBOR=32 11 | export REL_TYPE_NEIGHBOR=64 12 | 13 | CUDA_VISIBLE_DEVICES=0 nohup python -u ../main.py --cuda --do_train --do_valid --do_test \ 14 | --data_path $DATA_PATH --save_path $SAVE_PATH -n 128 -b 512 -d 400 -g 24 \ 15 | -lr 0.0001 --max_steps $MAX_STEPS --valid_steps $VALID_STEPS --save_checkpoint_steps $SAVE_STEPS \ 16 | --cpu_num 1 --geo box -boxm "(none,0.02)" --test_batch_size 16 --tasks "1p.2p.3p.2i.3i.ip.pi.2u.up" --print_on_screen \ 17 | --faithful $FAITHFUL --model_mode $MODEL --neighbor_ent_type_samples $ENT_TYPE_NEIGHBOR --neighbor_rel_type_samples $REL_TYPE_NEIGHBOR \ 18 | > $LOG_PATH 2>&1 & -------------------------------------------------------------------------------- /scripts/gqe_NELL-995.sh: -------------------------------------------------------------------------------- 1 | export DATA_PATH=../data/NELL-betae 2 | export SAVE_PATH=../logs/NELL-995/gqe_temp 3 | export LOG_PATH=../logs/NELL-995/gqe_temp.out 4 | export MODEL=temp 5 | export FAITHFUL=no_faithful 6 | 7 | export MAX_STEPS=450000 8 | export VALID_STEPS=10000 9 | export SAVE_STEPS=10000 10 | export ENT_TYPE_NEIGHBOR=6 11 | export REL_TYPE_NEIGHBOR=64 12 | 13 | CUDA_VISIBLE_DEVICES=0 nohup python -u ../main.py --cuda --do_train --do_valid --do_test \ 14 | --data_path $DATA_PATH --save_path $SAVE_PATH -n 128 -b 512 -d 800 -g 24 \ 15 | -lr 0.0001 --max_steps $MAX_STEPS --valid_steps $VALID_STEPS --save_checkpoint_steps $SAVE_STEPS \ 16 | --cpu_num 1 --geo vec --test_batch_size 16 --tasks "1p.2p.3p.2i.3i.ip.pi.2u.up" --print_on_screen --faithful $FAITHFUL \ 17 | --faithful $FAITHFUL --model_mode $MODEL --neighbor_ent_type_samples $ENT_TYPE_NEIGHBOR --neighbor_rel_type_samples $REL_TYPE_NEIGHBOR \ 18 | > $LOG_PATH 2>&1 & -------------------------------------------------------------------------------- /scripts/q2b_NELL-995.sh: -------------------------------------------------------------------------------- 1 | export DATA_PATH=../data/NELL-betae 2 | export SAVE_PATH=../logs/NELL-995/q2b_temp 3 | export LOG_PATH=../logs/NELL-995/q2b_temp.out 4 | export MODEL=temp 5 | export FAITHFUL=no_faithful 6 | 7 | export MAX_STEPS=450000 8 | export VALID_STEPS=10000 9 | export SAVE_STEPS=10000 10 | export ENT_TYPE_NEIGHBOR=6 11 | export REL_TYPE_NEIGHBOR=64 12 | 13 | CUDA_VISIBLE_DEVICES=0 nohup python -u ../main.py --cuda --do_train --do_valid --do_test \ 14 | --data_path $DATA_PATH --save_path $SAVE_PATH -n 128 -b 512 -d 400 -g 24 \ 15 | -lr 0.0001 --max_steps $MAX_STEPS --valid_steps $VALID_STEPS --save_checkpoint_steps $SAVE_STEPS \ 16 | --cpu_num 1 --geo box -boxm "(none,0.02)" --test_batch_size 16 --tasks "1p.2p.3p.2i.3i.ip.pi.2u.up" --print_on_screen \ 17 | --faithful $FAITHFUL --model_mode $MODEL --neighbor_ent_type_samples $ENT_TYPE_NEIGHBOR --neighbor_rel_type_samples $REL_TYPE_NEIGHBOR \ 18 | > $LOG_PATH 2>&1 & -------------------------------------------------------------------------------- /scripts/q2b_FB15k-237.sh: -------------------------------------------------------------------------------- 1 | export DATA_PATH=../data/FB15k-237-betae 2 | export SAVE_PATH=../logs/FB15k-237/q2b_temp 3 | export LOG_PATH=../logs/FB15k-237/q2b_temp.out 4 | export MODEL=temp 5 | export FAITHFUL=no_faithful 6 | 7 | export MAX_STEPS=450000 8 | export VALID_STEPS=10000 9 | export SAVE_STEPS=10000 10 | export ENT_TYPE_NEIGHBOR=32 11 | export REL_TYPE_NEIGHBOR=64 12 | 13 | CUDA_VISIBLE_DEVICES=0 nohup python -u ../main.py --cuda --do_train --do_valid --do_test \ 14 | --data_path $DATA_PATH --save_path $SAVE_PATH -n 128 -b 512 -d 400 -g 24 \ 15 | -lr 0.0001 --max_steps $MAX_STEPS --valid_steps $VALID_STEPS --save_checkpoint_steps $SAVE_STEPS \ 16 | --cpu_num 1 --geo box -boxm "(none,0.02)" --test_batch_size 16 --tasks "1p.2p.3p.2i.3i.ip.pi.2u.up" --print_on_screen \ 17 | --faithful $FAITHFUL --model_mode $MODEL --neighbor_ent_type_samples $ENT_TYPE_NEIGHBOR --neighbor_rel_type_samples $REL_TYPE_NEIGHBOR \ 18 | > $LOG_PATH 2>&1 & -------------------------------------------------------------------------------- /scripts/betae_FB15k.sh: -------------------------------------------------------------------------------- 1 | export DATA_PATH=../data/FB15k-betae 2 | export SAVE_PATH=../logs/FB15k/betae_temp 3 | export LOG_PATH=../logs/FB15k/betae_temp.out 4 | export MODEL=temp 5 | export FAITHFUL=no_faithful 6 | 7 | export MAX_STEPS=450000 8 | export VALID_STEPS=10000 9 | export SAVE_STEPS=10000 10 | export ENT_TYPE_NEIGHBOR=32 11 | export REL_TYPE_NEIGHBOR=64 12 | 13 | CUDA_VISIBLE_DEVICES=0 nohup python -u ../main.py --cuda --do_train --do_valid --do_test \ 14 | --data_path $DATA_PATH --save_path $SAVE_PATH -n 128 -b 512 -d 400 -g 60 \ 15 | -lr 0.0001 --max_steps $MAX_STEPS --valid_steps $VALID_STEPS --save_checkpoint_steps $SAVE_STEPS \ 16 | --cpu_num 1 --geo beta -betam "(1600,2)" --test_batch_size 16 --tasks "1p.2p.3p.2i.3i.ip.pi.2in.3in.inp.pin.pni.2u.up" --print_on_screen \ 17 | --faithful $FAITHFUL --model_mode $MODEL --neighbor_ent_type_samples $ENT_TYPE_NEIGHBOR --neighbor_rel_type_samples $REL_TYPE_NEIGHBOR \ 18 | > $LOG_PATH 2>&1 & -------------------------------------------------------------------------------- /scripts/betae_NELL-995.sh: -------------------------------------------------------------------------------- 1 | export DATA_PATH=../data/NELL-betae 2 | export SAVE_PATH=../logs/NELL-995/betae_temp 3 | export LOG_PATH=../logs/NELL-995/betae_temp.out 4 | export MODEL=temp 5 | export FAITHFUL=no_faithful 6 | 7 | export MAX_STEPS=450000 8 | export VALID_STEPS=10000 9 | export SAVE_STEPS=10000 10 | export ENT_TYPE_NEIGHBOR=6 11 | export REL_TYPE_NEIGHBOR=64 12 | 13 | CUDA_VISIBLE_DEVICES=0 nohup python -u ../main.py --cuda --do_train --do_valid --do_test \ 14 | --data_path $DATA_PATH --save_path $SAVE_PATH -n 128 -b 512 -d 400 -g 60 \ 15 | -lr 0.0001 --max_steps $MAX_STEPS --valid_steps $VALID_STEPS --save_checkpoint_steps $SAVE_STEPS \ 16 | --cpu_num 1 --geo beta -betam "(1600,2)" --test_batch_size 16 --tasks "1p.2p.3p.2i.3i.ip.pi.2in.3in.inp.pin.pni.2u.up" --print_on_screen \ 17 | --faithful $FAITHFUL --model_mode $MODEL --neighbor_ent_type_samples $ENT_TYPE_NEIGHBOR --neighbor_rel_type_samples $REL_TYPE_NEIGHBOR \ 18 | > $LOG_PATH 2>&1 & -------------------------------------------------------------------------------- /scripts/betae_FB15k-237.sh: -------------------------------------------------------------------------------- 1 | export DATA_PATH=../data/FB15k-237-betae 2 | export SAVE_PATH=../logs/FB15k-237/betae_temp 3 | export LOG_PATH=../logs/FB15k-237/betae_temp.out 4 | export MODEL=temp 5 | export FAITHFUL=no_faithful 6 | 7 | export MAX_STEPS=450000 8 | export VALID_STEPS=10000 9 | export SAVE_STEPS=10000 10 | export ENT_TYPE_NEIGHBOR=32 11 | export REL_TYPE_NEIGHBOR=64 12 | 13 | CUDA_VISIBLE_DEVICES=0 nohup python -u ../main.py --cuda --do_train --do_valid --do_test \ 14 | --data_path $DATA_PATH --save_path $SAVE_PATH -n 128 -b 512 -d 400 -g 60 \ 15 | -lr 0.0001 --max_steps $MAX_STEPS --valid_steps $VALID_STEPS --save_checkpoint_steps $SAVE_STEPS \ 16 | --cpu_num 1 --geo beta -betam "(1600,2)" --test_batch_size 16 --tasks "1p.2p.3p.2i.3i.ip.pi.2in.3in.inp.pin.pni.2u.up" --print_on_screen \ 17 | --faithful $FAITHFUL --model_mode $MODEL --neighbor_ent_type_samples $ENT_TYPE_NEIGHBOR --neighbor_rel_type_samples $REL_TYPE_NEIGHBOR \ 18 | > $LOG_PATH 2>&1 & -------------------------------------------------------------------------------- /scripts/logice_FB15k.sh: -------------------------------------------------------------------------------- 1 | export DATA_PATH=../data/FB15k-betae 2 | export SAVE_PATH=../logs/FB15k/logice_temp 3 | export LOG_PATH=../logs/FB15k/logice_temp.out 4 | export MODEL=temp 5 | export FAITHFUL=no_faithful 6 | 7 | export MAX_STEPS=450000 8 | export VALID_STEPS=10000 9 | export SAVE_STEPS=10000 10 | export ENT_TYPE_NEIGHBOR=32 11 | export REL_TYPE_NEIGHBOR=64 12 | 13 | CUDA_VISIBLE_DEVICES=0 nohup python -u ../main.py --cuda --do_train --do_valid --do_test \ 14 | --data_path $DATA_PATH --save_path $SAVE_PATH -n 128 -b 512 -d 400 -g 0.375 \ 15 | -lr 0.0001 --max_steps $MAX_STEPS --valid_steps $VALID_STEPS --save_checkpoint_steps $SAVE_STEPS \ 16 | --cpu_num 1 --geo logic -logicm "(luk,0,1,0,1600,2)" --test_batch_size 16 --tasks "1p.2p.3p.2i.3i.ip.pi.2in.3in.inp.pin.pni.2u.up" --print_on_screen \ 17 | --faithful $FAITHFUL --model_mode $MODEL --neighbor_ent_type_samples $ENT_TYPE_NEIGHBOR --neighbor_rel_type_samples $REL_TYPE_NEIGHBOR \ 18 | > $LOG_PATH 2>&1 & -------------------------------------------------------------------------------- /scripts/logice_NELL-995.sh: -------------------------------------------------------------------------------- 1 | export DATA_PATH=../data/NELL-betae 2 | export SAVE_PATH=../logs/NELL-995/logice_temp 3 | export LOG_PATH=../logs/NELL-995/logice_temp.out 4 | export MODEL=temp 5 | export FAITHFUL=no_faithful 6 | 7 | export MAX_STEPS=450000 8 | export VALID_STEPS=10000 9 | export SAVE_STEPS=10000 10 | export ENT_TYPE_NEIGHBOR=6 11 | export REL_TYPE_NEIGHBOR=64 12 | 13 | CUDA_VISIBLE_DEVICES=0 nohup python -u ../main.py --cuda --do_train --do_valid --do_test \ 14 | --data_path $DATA_PATH --save_path $SAVE_PATH -n 128 -b 512 -d 400 -g 0.375 \ 15 | -lr 0.0001 --max_steps $MAX_STEPS --valid_steps $VALID_STEPS --save_checkpoint_steps $SAVE_STEPS \ 16 | --cpu_num 1 --geo logic -logicm "(luk,0,1,0,1600,2)" --test_batch_size 16 --tasks "1p.2p.3p.2i.3i.ip.pi.2in.3in.inp.pin.pni.2u.up" --print_on_screen \ 17 | --faithful $FAITHFUL --model_mode $MODEL --neighbor_ent_type_samples $ENT_TYPE_NEIGHBOR --neighbor_rel_type_samples $REL_TYPE_NEIGHBOR \ 18 | > $LOG_PATH 2>&1 & -------------------------------------------------------------------------------- /scripts/logice_FB15k-237.sh: -------------------------------------------------------------------------------- 1 | export DATA_PATH=../data/FB15k-237-betae 2 | export SAVE_PATH=../logs/FB15k-237/logice_temp 3 | export LOG_PATH=../logs/FB15k-237/logice_temp.out 4 | export MODEL=temp 5 | export FAITHFUL=no_faithful 6 | 7 | export MAX_STEPS=450000 8 | export VALID_STEPS=10000 9 | export SAVE_STEPS=10000 10 | export ENT_TYPE_NEIGHBOR=32 11 | export REL_TYPE_NEIGHBOR=64 12 | 13 | CUDA_VISIBLE_DEVICES=0 nohup python -u ../main.py --cuda --do_train --do_valid --do_test \ 14 | --data_path $DATA_PATH --save_path $SAVE_PATH -n 128 -b 512 -d 400 -g 0.375 \ 15 | -lr 0.0001 --max_steps $MAX_STEPS --valid_steps $VALID_STEPS --save_checkpoint_steps $SAVE_STEPS \ 16 | --cpu_num 1 --geo logic -logicm "(luk,0,1,0,1600,2)" --test_batch_size 16 --tasks "1p.2p.3p.2i.3i.ip.pi.2in.3in.inp.pin.pni.2u.up" --print_on_screen \ 17 | --faithful $FAITHFUL --model_mode $MODEL --neighbor_ent_type_samples $ENT_TYPE_NEIGHBOR --neighbor_rel_type_samples $REL_TYPE_NEIGHBOR \ 18 | > $LOG_PATH 2>&1 & -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import torch 4 | import time 5 | 6 | def list2tuple(l): 7 | return tuple(list2tuple(x) if type(x)==list else x for x in l) 8 | 9 | def tuple2list(t): 10 | return list(tuple2list(x) if type(x)==tuple else x for x in t) 11 | 12 | flatten=lambda l: sum(map(flatten, l),[]) if isinstance(l,tuple) else [l] 13 | 14 | def parse_time(): 15 | return time.strftime("%Y.%m.%d-%H:%M:%S", time.localtime()) 16 | 17 | def set_global_seed(seed): 18 | torch.manual_seed(seed) 19 | torch.cuda.manual_seed(seed) 20 | np.random.seed(seed) 21 | random.seed(seed) 22 | torch.backends.cudnn.deterministic=True 23 | 24 | def eval_tuple(arg_return): 25 | """Evaluate a tuple string into a tuple.""" 26 | if type(arg_return) == tuple: 27 | return arg_return 28 | if arg_return[0] not in ["(", "["]: 29 | arg_return = eval(arg_return) 30 | else: 31 | splitted = arg_return[1:-1].split(",") 32 | List = [] 33 | for item in splitted: 34 | try: 35 | item = eval(item) 36 | except: 37 | pass 38 | if item == "": 39 | continue 40 | List.append(item) 41 | arg_return = tuple(List) 42 | return arg_return 43 | 44 | def flatten_query(queries): 45 | all_queries = [] 46 | for query_structure in queries: 47 | tmp_queries = list(queries[query_structure]) 48 | all_queries.extend([(query, query_structure) for query in tmp_queries]) 49 | return all_queries -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Type-aware Embeddings for Multi-Hop Reasoning over Knowledge Graphs 2 | #### This repo provides the source code & data of our paper: [Type-aware Embeddings for Multi-Hop Reasoning over Knowledge Graphs (IJCAI 2022)](https://arxiv.org/pdf/2205.00782.pdf). 3 | ## Dependencies 4 | * conda create -n temp python=3.7 -y 5 | * PyTorch 1.8.1 6 | * tensorboardX 2.5.1 7 | * numpy 1.21.6 8 | ## Running the code 9 | ### Dataset 10 | * Download the datasets from [here](https://drive.google.com/drive/folders/15ZJo6zuoj0S3Sx_8nz7TKr3Tq7Ku8JMR?usp=sharing). 11 | * Create the root directory ./data and put the datasets in. 12 | * It should be noted that we only provide the data provided by the BetaE paper (the corresponding dataset in Table 7 of the paper). For the dataset corresponding to Q2B (the corresponding dataset in Table 1 of the paper), you can download it from [here](http://snap.stanford.edu/betae/KG_data.zip). 13 | * You need to move *id2type.pkl*, *type2id.pkl*, *entity_type.npy* and *relation_type.npy* in the corresponding BetaE's dataset to the corresponding Q2B's dataset. 14 | ### Models 15 | - [x] [GQE](https://arxiv.org/abs/1806.01445) 16 | - [x] [Query2Box](https://arxiv.org/abs/1806.01445) 17 | - [x] [BetaE](https://arxiv.org/abs/2010.11465) 18 | - [x] [LogicE](https://arxiv.org/pdf/2103.00418.pdf) 19 | * We added our TEMP module to the above four models. 20 | ### Training Model 21 | * Take the GQE model in the FB15k-237 dataset as an example: 22 | #### Generalization 23 | ``` 24 | export DATA_PATH=../data/FB15k-237-betae 25 | export SAVE_PATH=../logs/FB15k-237/gqe_temp 26 | export LOG_PATH=../logs/FB15k-237/gqe_temp.out 27 | export MODEL=temp 28 | export FAITHFUL=no_faithful 29 | 30 | export MAX_STEPS=450000 31 | export VALID_STEPS=10000 32 | export SAVE_STEPS=10000 33 | export ENT_TYPE_NEIGHBOR=32 34 | export REL_TYPE_NEIGHBOR=64 35 | 36 | CUDA_VISIBLE_DEVICES=0 nohup python -u ../main.py --cuda --do_train --do_valid --do_test \ 37 | --data_path $DATA_PATH --save_path $SAVE_PATH -n 128 -b 512 -d 800 -g 24 \ 38 | -lr 0.0001 --max_steps $MAX_STEPS --valid_steps $VALID_STEPS --save_checkpoint_steps $SAVE_STEPS \ 39 | --cpu_num 1 --geo vec --test_batch_size 16 --tasks "1p.2p.3p.2i.3i.ip.pi.2u.up" --print_on_screen \ 40 | --faithful $FAITHFUL --model_mode $MODEL --neighbor_ent_type_samples $ENT_TYPE_NEIGHBOR --neighbor_rel_type_samples $REL_TYPE_NEIGHBOR \ 41 | > $LOG_PATH 2>&1 & 42 | ``` 43 | #### Deductive 44 | ``` 45 | export DATA_PATH=../data/FB15k-237-betae 46 | export SAVE_PATH=../logs/FB15k-237/gqe_faithful_temp 47 | export LOG_PATH=../logs/FB15k-237/gqe_faithful_temp.out 48 | export MODEL=temp 49 | export FAITHFUL=faithful 50 | 51 | export MAX_STEPS=450000 52 | export VALID_STEPS=10000 53 | export SAVE_STEPS=10000 54 | export ENT_TYPE_NEIGHBOR=32 55 | export REL_TYPE_NEIGHBOR=64 56 | 57 | CUDA_VISIBLE_DEVICES=0 nohup python -u ../main.py --cuda --do_train --do_valid --do_test \ 58 | --data_path $DATA_PATH --save_path $SAVE_PATH -n 128 -b 512 -d 800 -g 24 \ 59 | -lr 0.0001 --max_steps $MAX_STEPS --valid_steps $VALID_STEPS --save_checkpoint_steps $SAVE_STEPS \ 60 | --cpu_num 1 --geo vec --test_batch_size 16 --tasks "1p.2p.3p.2i.3i.ip.pi.2u.up" --print_on_screen \ 61 | --faithful $FAITHFUL --model_mode $MODEL --neighbor_ent_type_samples $ENT_TYPE_NEIGHBOR --neighbor_rel_type_samples $REL_TYPE_NEIGHBOR \ 62 | > $LOG_PATH 2>&1 & 63 | ``` 64 | * Other running scripts can be seen in ./scripts. 65 | ## Citation 66 | If you find this code useful, please consider citing the following paper. 67 | ``` 68 | @article{DBLP:journals/corr/abs-2205-00782, 69 | author = {Zhiwei Hu and Víctor Gutiérrez-Basulto and Zhiliang Xiang and Xiaoli Li and Ru Li and Jeff Z. Pan}, 70 | title = {Type-aware Embeddings for Multi-Hop Reasoning over Knowledge Graphs}, 71 | journal = {CoRR}, 72 | volume = {abs/2205.00782}, 73 | year = {2022}, 74 | url = {https://doi.org/10.48550/arXiv.2205.00782}, 75 | doi = {10.48550/arXiv.2205.00782}, 76 | eprint = {2205.00782}, 77 | } 78 | ``` 79 | 80 | ## Acknowledgement 81 | We refer to the code of [KGReasoning](https://hub.fastgit.xyz/snap-stanford/KGReasoning). Thanks for their contributions. 82 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from torch.utils.data import Dataset 11 | from util import flatten 12 | 13 | class TestDataset(Dataset): 14 | def __init__(self, queries, nentity, nrelation): 15 | # queries is a list of (query, query_structure) pairs 16 | self.len = len(queries) 17 | self.queries = queries 18 | self.nentity = nentity 19 | self.nrelation = nrelation 20 | 21 | def __len__(self): 22 | return self.len 23 | 24 | def __getitem__(self, idx): 25 | query = self.queries[idx][0] 26 | query_structure = self.queries[idx][1] 27 | negative_sample = torch.LongTensor(range(self.nentity)) 28 | return negative_sample, flatten(query), query, query_structure 29 | 30 | @staticmethod 31 | def collate_fn(data): 32 | negative_sample = torch.stack([_[0] for _ in data], dim=0) 33 | query = [_[1] for _ in data] 34 | query_unflatten = [_[2] for _ in data] 35 | query_structure = [_[3] for _ in data] 36 | return negative_sample, query, query_unflatten, query_structure 37 | 38 | class TrainDataset(Dataset): 39 | def __init__(self, queries, nentity, nrelation, negative_sample_size, answer): 40 | # queries is a list of (query, query_structure) pairs 41 | self.len = len(queries) 42 | self.queries = queries 43 | self.nentity = nentity 44 | self.nrelation = nrelation 45 | self.negative_sample_size = negative_sample_size 46 | self.answer = answer 47 | self.count = self.count_frequency(queries, answer) 48 | 49 | def __len__(self): 50 | return self.len 51 | 52 | def __getitem__(self, idx): 53 | query = self.queries[idx][0] 54 | query_structure = self.queries[idx][1] 55 | tail = np.random.choice(list(self.answer[query])) 56 | subsampling_weight = self.count[query] 57 | subsampling_weight = torch.sqrt(1 / torch.Tensor([subsampling_weight])) 58 | negative_sample_list = [] 59 | negative_sample_size = 0 60 | while negative_sample_size < self.negative_sample_size: 61 | negative_sample = np.random.randint(self.nentity, size=self.negative_sample_size*2) 62 | mask = np.in1d( 63 | negative_sample, 64 | self.answer[query], 65 | assume_unique=True, 66 | invert=True 67 | ) 68 | negative_sample = negative_sample[mask] 69 | negative_sample_list.append(negative_sample) 70 | negative_sample_size += negative_sample.size 71 | negative_sample = np.concatenate(negative_sample_list)[:self.negative_sample_size] 72 | negative_sample = torch.from_numpy(negative_sample) 73 | positive_sample = torch.LongTensor([tail]) 74 | return positive_sample, negative_sample, subsampling_weight, flatten(query), query_structure 75 | 76 | @staticmethod 77 | def collate_fn(data): 78 | positive_sample = torch.cat([_[0] for _ in data], dim=0) 79 | negative_sample = torch.stack([_[1] for _ in data], dim=0) 80 | subsample_weight = torch.cat([_[2] for _ in data], dim=0) 81 | query = [_[3] for _ in data] 82 | query_structure = [_[4] for _ in data] 83 | return positive_sample, negative_sample, subsample_weight, query, query_structure 84 | 85 | @staticmethod 86 | def count_frequency(queries, answer, start=4): 87 | count = {} 88 | for query, qtype in queries: 89 | count[query] = start + len(answer[query]) 90 | return count 91 | 92 | class SingledirectionalOneShotIterator(object): 93 | def __init__(self, dataloader): 94 | self.iterator = self.one_shot_iterator(dataloader) 95 | self.step = 0 96 | 97 | def __next__(self): 98 | self.step += 1 99 | data = next(self.iterator) 100 | return data 101 | 102 | @staticmethod 103 | def one_shot_iterator(dataloader): 104 | while True: 105 | for data in dataloader: 106 | yield data -------------------------------------------------------------------------------- /type_aggregator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @Time : 2021/10/12 12:41 4 | @Auth : zhiweihu 5 | """ 6 | import torch 7 | import torch.nn as nn 8 | from abc import abstractmethod 9 | from typing import Optional 10 | 11 | class Aggregator(nn.Module): 12 | def __init__(self, input_dim, output_dim, act, self_included, neighbor_ent_type_samples): 13 | super(Aggregator, self).__init__() 14 | self.input_dim = input_dim 15 | self.output_dim = output_dim 16 | self.act = act 17 | self.self_included = self_included 18 | self.neighbor_ent_type_samples = neighbor_ent_type_samples 19 | 20 | def forward(self, self_vectors, neighbor_vectors): 21 | outputs = self._call(self_vectors, neighbor_vectors) 22 | return outputs 23 | 24 | @abstractmethod 25 | def _call(self, self_vectors, entity_vectors): 26 | pass 27 | 28 | class EntityTypeAggregator(Aggregator): 29 | def __init__(self, input_dim, output_dim, act=lambda x: x, self_included=True, with_sigmoid=False, neighbor_ent_type_samples=32): 30 | super(EntityTypeAggregator, self).__init__(input_dim, output_dim, act, self_included, neighbor_ent_type_samples) 31 | self.proj_layer = HighwayNetwork(neighbor_ent_type_samples, 1, 2, activation=nn.Sigmoid()) 32 | 33 | multiplier = 2 if self_included else 1 34 | self.layer = nn.Linear(self.input_dim * multiplier, self.output_dim) 35 | nn.init.xavier_uniform_(self.layer.weight) 36 | self.with_sigmoid = with_sigmoid 37 | 38 | def _call(self, self_vectors, neighbor_vectors): 39 | neighbor_vectors = torch.transpose(neighbor_vectors, 1, 2) 40 | neighbor_vectors = self.proj_layer(neighbor_vectors) 41 | neighbor_vectors = torch.transpose(neighbor_vectors, 1, 2) 42 | neighbor_vectors = neighbor_vectors.squeeze(1) 43 | 44 | if self.self_included: 45 | self_vectors = self_vectors.view([-1, self.input_dim]) 46 | output = torch.cat([self_vectors, neighbor_vectors], dim=-1) 47 | output = self.layer(output) 48 | output = output.view([-1, self.output_dim]) 49 | if self.with_sigmoid: 50 | output = torch.sigmoid(output) 51 | 52 | return self.act(output) 53 | 54 | class HighwayNetwork(nn.Module): 55 | def __init__(self, 56 | input_dim: int, 57 | output_dim: int, 58 | n_layers: int, 59 | activation: Optional[nn.Module] = None): 60 | super(HighwayNetwork, self).__init__() 61 | self.n_layers = n_layers 62 | self.nonlinear = nn.ModuleList( 63 | [nn.Linear(input_dim, input_dim) for _ in range(n_layers)]) 64 | self.gate = nn.ModuleList( 65 | [nn.Linear(input_dim, input_dim) for _ in range(n_layers)]) 66 | for layer in self.gate: 67 | layer.bias = torch.nn.Parameter(0. * torch.ones_like(layer.bias)) 68 | self.final_linear_layer = nn.Linear(input_dim, output_dim) 69 | self.activation = nn.ReLU() if activation is None else activation 70 | self.sigmoid = nn.Sigmoid() 71 | 72 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 73 | for layer_idx in range(self.n_layers): 74 | gate_values = self.sigmoid(self.gate[layer_idx](inputs)) 75 | nonlinear = self.activation(self.nonlinear[layer_idx](inputs)) 76 | inputs = gate_values * nonlinear + (1. - gate_values) * inputs 77 | return self.final_linear_layer(inputs) 78 | 79 | class Match(nn.Module): 80 | def __init__(self, hidden_size, with_sigmoid=False): 81 | super(Match, self).__init__() 82 | self.map_linear = nn.Linear(2 * hidden_size, 2 * hidden_size) 83 | self.trans_linear = nn.Linear(hidden_size, hidden_size) 84 | self.with_sigmoid = with_sigmoid 85 | 86 | def forward(self, inputs): 87 | proj_p, proj_q = inputs 88 | trans_q = self.trans_linear(proj_q) 89 | att_weights = proj_p.bmm(torch.transpose(trans_q, 1, 2)) 90 | att_norm = torch.nn.functional.softmax(att_weights, dim=-1) 91 | att_vec = att_norm.bmm(proj_q) 92 | elem_min = att_vec - proj_p 93 | elem_mul = att_vec * proj_p 94 | all_con = torch.cat([elem_min, elem_mul], 2) 95 | output = self.map_linear(all_con) 96 | if self.with_sigmoid: 97 | output = torch.sigmoid(output) 98 | return output 99 | 100 | class RelationTypeAggregator(nn.Module): 101 | def __init__(self, hidden_size, with_sigmoid=False): 102 | super(RelationTypeAggregator, self).__init__() 103 | self.linear = nn.Linear(2 * hidden_size, hidden_size) 104 | self.linear2 = nn.Linear(2 * hidden_size, 2 * hidden_size) 105 | self.with_sigmoid = with_sigmoid 106 | 107 | def forward(self, inputs): 108 | p, q = inputs 109 | lq = self.linear2(q) 110 | lp = self.linear2(p) 111 | mid = nn.Sigmoid()(lq+lp) 112 | output = p * mid + q * (1-mid) 113 | output = self.linear(output) 114 | if self.with_sigmoid: 115 | output = torch.sigmoid(output) 116 | return output 117 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import argparse 8 | import json 9 | import logging 10 | import os 11 | 12 | import numpy as np 13 | import torch 14 | from torch.utils.data import DataLoader 15 | from models import KGReasoning 16 | from dataloader import TestDataset, TrainDataset, SingledirectionalOneShotIterator 17 | from tensorboardX import SummaryWriter 18 | import pickle 19 | from collections import defaultdict 20 | from util import flatten_query, parse_time, set_global_seed, eval_tuple 21 | 22 | # import os 23 | # os.environ['CUDA_VISIBLE_DEVICES']='4' 24 | 25 | query_name_dict = {('e',('r',)): '1p', 26 | ('e', ('r', 'r')): '2p', 27 | ('e', ('r', 'r', 'r')): '3p', 28 | ('e', ('r', 'r', 'r', 'r')): '4p', 29 | ('e', ('r', 'r', 'r', 'r', 'r')): '5p', 30 | (('e', ('r',)), ('e', ('r',))): '2i', 31 | (('e', ('r',)), ('e', ('r',)), ('e', ('r',))): '3i', 32 | ((('e', ('r',)), ('e', ('r',))), ('r',)): 'ip', 33 | (('e', ('r', 'r')), ('e', ('r',))): 'pi', 34 | (('e', ('r',)), ('e', ('r', 'n'))): '2in', 35 | (('e', ('r',)), ('e', ('r',)), ('e', ('r', 'n'))): '3in', 36 | ((('e', ('r',)), ('e', ('r', 'n'))), ('r',)): 'inp', 37 | (('e', ('r', 'r')), ('e', ('r', 'n'))): 'pin', 38 | (('e', ('r', 'r', 'n')), ('e', ('r',))): 'pni', 39 | (('e', ('r',)), ('e', ('r',)), ('u',)): '2u-DNF', 40 | ((('e', ('r',)), ('e', ('r',)), ('u',)), ('r',)): 'up-DNF', 41 | ((('e', ('r', 'n')), ('e', ('r', 'n'))), ('n',)): '2u-DM', 42 | ((('e', ('r', 'n')), ('e', ('r', 'n'))), ('n', 'r')): 'up-DM' 43 | } 44 | name_query_dict = {value: key for key, value in query_name_dict.items()} 45 | all_tasks = list(name_query_dict.keys()) # ['1p', '2p', '3p', '4p', '5p'] 46 | 47 | def parse_args(args=None): 48 | parser = argparse.ArgumentParser( 49 | description='Training and Testing Knowledge Graph Embedding Models', 50 | usage='train.py [] [-h | --help]' 51 | ) 52 | 53 | parser.add_argument('--cuda', action='store_true', help='use GPU', default=True) 54 | 55 | parser.add_argument('--do_train', action='store_true', help="do train", default=True) 56 | parser.add_argument('--do_valid', action='store_true', help="do valid", default=True) 57 | parser.add_argument('--do_test', action='store_true', help="do test", default=True) 58 | 59 | parser.add_argument('--data_path', type=str, default='./data/FB15k-237-long_chain', help="KG data path") 60 | parser.add_argument('-n', '--negative_sample_size', default=128, type=int, help="negative entities sampled per query") 61 | parser.add_argument('-d', '--hidden_dim', default=400, type=int, help="embedding dimension") 62 | parser.add_argument('-g', '--gamma', default=0.375, type=float, help="margin in the loss") 63 | parser.add_argument('-b', '--batch_size', default=512, type=int, help="batch size of queries") 64 | parser.add_argument('--drop', type=float, default=0.1, help='dropout rate') 65 | parser.add_argument('--test_batch_size', default=16, type=int, help='valid/test batch size') 66 | parser.add_argument('-lr', '--learning_rate', default=0.0001, type=float) 67 | parser.add_argument('-cpu', '--cpu_num', default=1, type=int, help="used to speed up torch.dataloader") 68 | parser.add_argument('-save', '--save_path', default='./logs/FB15k-237/gqe_baseline_test', type=str, help="no need to set manually, will configure automatically") 69 | parser.add_argument('--max_steps', default=300000, type=int, help="maximum iterations to train") 70 | parser.add_argument('--warm_up_steps', default=None, type=int, help="no need to set manually, will configure automatically") 71 | 72 | parser.add_argument('--save_checkpoint_steps', default=10000, type=int, help="save checkpoints every xx steps") 73 | parser.add_argument('--valid_steps', default=10000, type=int, help="evaluate validation queries every xx steps") 74 | parser.add_argument('--log_steps', default=100, type=int, help='train log every xx steps') 75 | parser.add_argument('--test_log_steps', default=1000, type=int, help='valid/test log every xx steps') 76 | 77 | parser.add_argument('--nentity', type=int, default=0, help='DO NOT MANUALLY SET') 78 | parser.add_argument('--nrelation', type=int, default=0, help='DO NOT MANUALLY SET') 79 | 80 | parser.add_argument('--geo', default='logic', type=str, choices=['vec', 'box', 'beta', 'cone', 'logic'], help='the reasoning model, vec for GQE, box for Query2box, beta for BetaE') 81 | parser.add_argument('--print_on_screen', action='store_true', default=True) 82 | 83 | parser.add_argument('--tasks', default='1p.2p.3p.4p.5p', type=str, help="tasks connected by dot, refer to the BetaE paper for detailed meaning and structure of each task") 84 | parser.add_argument('--seed', default=0, type=int, help="random seed") 85 | parser.add_argument('-betam', '--beta_mode', default="(1600,2)", type=str, help='(hidden_dim,num_layer) for BetaE relational projection') 86 | parser.add_argument('-boxm', '--box_mode', default="(none,0.02)", type=str, help='(offset activation,center_reg) for Query2box, center_reg balances the in_box dist and out_box dist') 87 | parser.add_argument('-cenr', '--center_reg', default=0.02, type=float, help='center_reg for ConE, center_reg balances the in_cone dist and out_cone dist') 88 | parser.add_argument('-logicm', '--logic_mode', default="(luk,1,1,0,1600,2)", type=str, help='(tnorm,bounded,use_att,use_gtrans,hidden_dim,num_layer)') 89 | parser.add_argument('--prefix', default=None, type=str, help='prefix of the log path') 90 | parser.add_argument('--checkpoint_path', default=None, type=str, help='path for loading the checkpoints') 91 | parser.add_argument('-evu', '--evaluate_union', default="DNF", type=str, choices=['DNF', 'DM'], help='the way to evaluate union queries, transform it to disjunctive normal form (DNF) or use the De Morgan\'s laws (DM)') 92 | 93 | parser.add_argument('--model_mode', default="baseline", type=str, choices=['baseline', 'temp'], help='the type of model') 94 | parser.add_argument('--faithful', default="no_faithful", type=str, choices=['faithful', 'no_faithful'], help='faithful or not') 95 | parser.add_argument('--neighbor_ent_type_samples', type=int, default=32, help='number of sampled entity type neighbors') 96 | parser.add_argument('--neighbor_rel_type_samples', type=int, default=64, help='number of sampled relation type neighbors') 97 | return parser.parse_args(args) 98 | 99 | def save_model(model, optimizer, save_variable_list, args): 100 | ''' 101 | Save the parameters of the model and the optimizer, 102 | as well as some other variables such as step and learning_rate 103 | ''' 104 | 105 | argparse_dict = vars(args) 106 | with open(os.path.join(args.save_path, 'config.json'), 'w') as fjson: 107 | json.dump(argparse_dict, fjson) 108 | 109 | torch.save({ 110 | **save_variable_list, 111 | 'model_state_dict': model.state_dict(), 112 | 'optimizer_state_dict': optimizer.state_dict()}, 113 | os.path.join(args.save_path, 'checkpoint') 114 | ) 115 | 116 | def set_logger(args): 117 | ''' 118 | Write logs to console and log file 119 | ''' 120 | if args.do_train: 121 | log_file = os.path.join(args.save_path, 'train.log') 122 | else: 123 | log_file = os.path.join(args.save_path, 'test.log') 124 | 125 | logging.basicConfig( 126 | format='%(asctime)s %(levelname)-8s %(message)s', 127 | level=logging.INFO, 128 | datefmt='%Y-%m-%d %H:%M:%S', 129 | filename=log_file, 130 | filemode='a+' 131 | ) 132 | if args.print_on_screen: 133 | console = logging.StreamHandler() 134 | console.setLevel(logging.INFO) 135 | formatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s') 136 | console.setFormatter(formatter) 137 | logging.getLogger('').addHandler(console) 138 | 139 | def log_metrics(mode, step, metrics): 140 | ''' 141 | Print the evaluation logs 142 | ''' 143 | for metric in metrics: 144 | logging.info('%s %s at step %d: %f' % (mode, metric, step, metrics[metric])) 145 | 146 | def evaluate(model, tp_answers, fn_answers, args, dataloader, query_name_dict, mode, step, writer): 147 | ''' 148 | Evaluate queries in dataloader 149 | ''' 150 | average_metrics = defaultdict(float) 151 | all_metrics = defaultdict(float) 152 | 153 | metrics = model.test_step(model, tp_answers, fn_answers, args, dataloader, query_name_dict) 154 | num_query_structures = 0 155 | num_queries = 0 156 | for query_structure in metrics: 157 | log_metrics(mode+" "+query_name_dict[query_structure], step, metrics[query_structure]) 158 | for metric in metrics[query_structure]: 159 | writer.add_scalar("_".join([mode, query_name_dict[query_structure], metric]), metrics[query_structure][metric], step) 160 | all_metrics["_".join([query_name_dict[query_structure], metric])] = metrics[query_structure][metric] 161 | if metric != 'num_queries': 162 | average_metrics[metric] += metrics[query_structure][metric] 163 | num_queries += metrics[query_structure]['num_queries'] 164 | num_query_structures += 1 165 | 166 | for metric in average_metrics: 167 | average_metrics[metric] /= num_query_structures 168 | writer.add_scalar("_".join([mode, 'average', metric]), average_metrics[metric], step) 169 | all_metrics["_".join(["average", metric])] = average_metrics[metric] 170 | log_metrics('%s average'%mode, step, average_metrics) 171 | 172 | return all_metrics 173 | 174 | def load_data(args, tasks): 175 | ''' 176 | Load queries and remove queries not in tasks 177 | ''' 178 | logging.info("loading data") 179 | train_queries = pickle.load(open(os.path.join(args.data_path, "train-queries.pkl"), 'rb')) 180 | train_answers = pickle.load(open(os.path.join(args.data_path, "train-answers.pkl"), 'rb')) 181 | valid_queries = pickle.load(open(os.path.join(args.data_path, "valid-queries.pkl"), 'rb')) 182 | valid_hard_answers = pickle.load(open(os.path.join(args.data_path, "valid-hard-answers.pkl"), 'rb')) 183 | valid_easy_answers = pickle.load(open(os.path.join(args.data_path, "valid-easy-answers.pkl"), 'rb')) 184 | test_queries = pickle.load(open(os.path.join(args.data_path, "test-queries.pkl"), 'rb')) 185 | test_hard_answers = pickle.load(open(os.path.join(args.data_path, "test-hard-answers.pkl"), 'rb')) 186 | test_easy_answers = pickle.load(open(os.path.join(args.data_path, "test-easy-answers.pkl"), 'rb')) 187 | 188 | if args.faithful == 'faithful': 189 | # Train on all splits to evaluate reasoning faithfulness (entailment) 190 | for queries in [valid_queries, test_queries]: 191 | for query_structure in queries: 192 | train_queries[query_structure] |= queries[query_structure] 193 | 194 | for answers in [valid_hard_answers, valid_easy_answers, test_hard_answers, test_easy_answers]: 195 | for query in answers: 196 | train_answers.setdefault(query, set()) 197 | train_answers[query] |= answers[query] 198 | 199 | # remove tasks not in args.tasks 200 | for name in all_tasks: 201 | if 'u' in name: 202 | name, evaluate_union = name.split('-') 203 | else: 204 | evaluate_union = args.evaluate_union 205 | if name not in tasks or evaluate_union != args.evaluate_union: 206 | query_structure = name_query_dict[name if 'u' not in name else '-'.join([name, evaluate_union])] 207 | if query_structure in train_queries: 208 | del train_queries[query_structure] 209 | if query_structure in valid_queries: 210 | del valid_queries[query_structure] 211 | if query_structure in test_queries: 212 | del test_queries[query_structure] 213 | 214 | return train_queries, train_answers, valid_queries, valid_hard_answers, valid_easy_answers, test_queries, test_hard_answers, test_easy_answers 215 | 216 | def main(args): 217 | set_global_seed(args.seed) 218 | tasks = args.tasks.split('.') 219 | for task in tasks: 220 | if 'n' in task and args.geo in ['box', 'vec']: 221 | assert False, "Q2B and GQE cannot handle queries with negation" 222 | if args.evaluate_union == 'DM': 223 | assert args.geo == 'beta' or args.geo == 'cone' or args.geo == 'logic', "only BetaE supports modeling union using De Morgan's Laws" 224 | 225 | cur_time = parse_time() 226 | if args.prefix is None: 227 | prefix = 'logs' 228 | else: 229 | prefix = args.prefix 230 | 231 | print ("overwritting args.save_path") 232 | args.save_path = os.path.join(args.save_path, args.tasks, args.geo) 233 | if args.geo in ['box']: 234 | tmp_str = "g-{}-mode-{}".format(args.gamma, args.box_mode) 235 | elif args.geo in ['vec']: 236 | tmp_str = "g-{}".format(args.gamma) 237 | elif args.geo == 'beta': 238 | tmp_str = "g-{}-mode-{}".format(args.gamma, args.beta_mode) 239 | elif args.geo == 'cone': 240 | tmp_str = "g-{}-mode-{}".format(args.gamma, args.center_reg) 241 | elif args.geo == 'logic': 242 | tmp_str = "g-{}-mode-{}".format(args.gamma, args.logic_mode) 243 | 244 | if args.checkpoint_path is not None: 245 | args.save_path = args.checkpoint_path 246 | else: 247 | args.save_path = os.path.join(args.save_path, tmp_str, cur_time) 248 | 249 | if not os.path.exists(args.save_path): 250 | os.makedirs(args.save_path) 251 | 252 | print ("logging to", args.save_path) 253 | if not args.do_train: # if not training, then create tensorboard files in some tmp location 254 | writer = SummaryWriter('./logs-debug/unused-tb') 255 | else: 256 | writer = SummaryWriter(args.save_path) 257 | set_logger(args) 258 | 259 | with open('%s/stats.txt'%args.data_path) as f: 260 | entrel = f.readlines() 261 | nentity = int(entrel[0].split(' ')[-1]) 262 | nrelation = int(entrel[1].split(' ')[-1]) 263 | ntype = int(entrel[2].split(' ')[-1]) 264 | 265 | args.nentity = nentity 266 | args.nrelation = nrelation 267 | 268 | logging.info('-------------------------------'*3) 269 | logging.info('Geo: %s' % args.geo) 270 | logging.info('Data Path: %s' % args.data_path) 271 | logging.info('#entity: %d' % nentity) 272 | logging.info('#relation: %d' % nrelation) 273 | logging.info('#max steps: %d' % args.max_steps) 274 | logging.info('Evaluate unoins using: %s' % args.evaluate_union) 275 | 276 | train_queries, train_answers, valid_queries, valid_hard_answers, valid_easy_answers, test_queries, test_hard_answers, test_easy_answers = load_data(args, tasks) 277 | 278 | logging.info("Training info:") 279 | if args.do_train: 280 | for query_structure in train_queries: 281 | logging.info(query_name_dict[query_structure]+": "+str(len(train_queries[query_structure]))) 282 | train_path_queries = defaultdict(set) 283 | train_other_queries = defaultdict(set) 284 | path_list = ['1p', '2p', '3p', '4p', '5p'] 285 | for query_structure in train_queries: 286 | if query_name_dict[query_structure] in path_list: 287 | train_path_queries[query_structure] = train_queries[query_structure] 288 | else: 289 | train_other_queries[query_structure] = train_queries[query_structure] 290 | train_path_queries = flatten_query(train_path_queries) 291 | train_path_iterator = SingledirectionalOneShotIterator(DataLoader( 292 | TrainDataset(train_path_queries, nentity, nrelation, args.negative_sample_size, train_answers), 293 | batch_size=args.batch_size, 294 | shuffle=True, 295 | num_workers=args.cpu_num, 296 | collate_fn=TrainDataset.collate_fn 297 | )) 298 | if len(train_other_queries) > 0: 299 | train_other_queries = flatten_query(train_other_queries) 300 | train_other_iterator = SingledirectionalOneShotIterator(DataLoader( 301 | TrainDataset(train_other_queries, nentity, nrelation, args.negative_sample_size, train_answers), 302 | batch_size=args.batch_size, 303 | shuffle=True, 304 | num_workers=args.cpu_num, 305 | collate_fn=TrainDataset.collate_fn 306 | )) 307 | else: 308 | train_other_iterator = None 309 | 310 | logging.info("Validation info:") 311 | if args.do_valid: 312 | for query_structure in valid_queries: 313 | logging.info(query_name_dict[query_structure]+": "+str(len(valid_queries[query_structure]))) 314 | valid_queries = flatten_query(valid_queries) 315 | valid_dataloader = DataLoader( 316 | TestDataset( 317 | valid_queries, 318 | args.nentity, 319 | args.nrelation, 320 | ), 321 | batch_size=args.test_batch_size, 322 | num_workers=args.cpu_num, 323 | collate_fn=TestDataset.collate_fn 324 | ) 325 | 326 | 327 | logging.info("Test info:") 328 | if args.do_test: 329 | for query_structure in test_queries: 330 | logging.info(query_name_dict[query_structure]+": "+str(len(test_queries[query_structure]))) 331 | test_queries = flatten_query(test_queries) 332 | test_dataloader = DataLoader( 333 | TestDataset( 334 | test_queries, 335 | args.nentity, 336 | args.nrelation, 337 | ), 338 | batch_size=args.test_batch_size, 339 | num_workers=args.cpu_num, 340 | collate_fn=TestDataset.collate_fn 341 | ) 342 | 343 | entity2type, relation2type = build_kg(args.data_path, args.neighbor_ent_type_samples, args.neighbor_rel_type_samples) 344 | 345 | model = KGReasoning( 346 | nentity=nentity, 347 | nrelation=nrelation, 348 | ntype=ntype, 349 | entity2type=entity2type, 350 | relation2type=relation2type, 351 | hidden_dim=args.hidden_dim, 352 | gamma=args.gamma, 353 | geo=args.geo, 354 | use_cuda = args.cuda, 355 | box_mode=eval_tuple(args.box_mode), 356 | beta_mode = eval_tuple(args.beta_mode), 357 | center_reg=args.center_reg, 358 | logic_mode=eval_tuple(args.logic_mode), 359 | model_mode = args.model_mode, 360 | test_batch_size=args.test_batch_size, 361 | query_name_dict = query_name_dict, 362 | drop=args.drop, 363 | neighbor_ent_type_samples=args.neighbor_ent_type_samples 364 | ) 365 | 366 | logging.info('Model Parameter Configuration:') 367 | num_params = 0 368 | for name, param in model.named_parameters(): 369 | logging.info('Parameter %s: %s, require_grad = %s' % (name, str(param.size()), str(param.requires_grad))) 370 | if param.requires_grad: 371 | num_params += np.prod(param.size()) 372 | logging.info('Parameter Number: %d' % num_params) 373 | 374 | if args.cuda: 375 | model = model.cuda() 376 | 377 | if args.do_train: 378 | current_learning_rate = args.learning_rate 379 | optimizer = torch.optim.Adam( 380 | filter(lambda p: p.requires_grad, model.parameters()), 381 | lr=current_learning_rate 382 | ) 383 | warm_up_steps = args.max_steps // 2 384 | 385 | if args.checkpoint_path is not None: 386 | logging.info('Loading checkpoint %s...' % args.checkpoint_path) 387 | checkpoint = torch.load(os.path.join(args.checkpoint_path, 'checkpoint')) 388 | init_step = checkpoint['step'] 389 | model.load_state_dict(checkpoint['model_state_dict']) 390 | 391 | if args.do_train: 392 | current_learning_rate = checkpoint['current_learning_rate'] 393 | warm_up_steps = checkpoint['warm_up_steps'] 394 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 395 | else: 396 | logging.info('Ramdomly Initializing %s Model...' % args.geo) 397 | init_step = 0 398 | 399 | step = init_step 400 | if args.geo == 'box': 401 | logging.info('box mode = %s' % args.box_mode) 402 | elif args.geo == 'beta': 403 | logging.info('beta mode = %s' % args.beta_mode) 404 | elif args.geo == 'cone': 405 | logging.info('cone mode = %s' % args.center_reg) 406 | elif args.geo == 'logic': 407 | logging.info('logic mode = %s (tnorm,bounded,use_att,use_gtrans,hidden_dim,num_layer)' % args.logic_mode) 408 | logging.info('tasks = %s' % args.tasks) 409 | logging.info('init_step = %d' % init_step) 410 | if args.do_train: 411 | logging.info('Start Training...') 412 | logging.info('learning_rate = %d' % current_learning_rate) 413 | logging.info('batch_size = %d' % args.batch_size) 414 | logging.info('hidden_dim = %d' % args.hidden_dim) 415 | logging.info('gamma = %f' % args.gamma) 416 | 417 | if args.do_train: 418 | training_logs = [] 419 | # #Training Loop 420 | for step in range(init_step, args.max_steps): 421 | if step == 2*args.max_steps//3: 422 | args.valid_steps *= 4 423 | 424 | log = model.train_step(model, optimizer, train_path_iterator, args, step) 425 | for metric in log: 426 | writer.add_scalar('path_'+metric, log[metric], step) 427 | if train_other_iterator is not None: 428 | log = model.train_step(model, optimizer, train_other_iterator, args, step) 429 | for metric in log: 430 | writer.add_scalar('other_'+metric, log[metric], step) 431 | log = model.train_step(model, optimizer, train_path_iterator, args, step) 432 | 433 | training_logs.append(log) 434 | 435 | if step >= warm_up_steps: 436 | current_learning_rate = current_learning_rate / 5 437 | logging.info('Change learning_rate to %f at step %d' % (current_learning_rate, step)) 438 | optimizer = torch.optim.Adam( 439 | filter(lambda p: p.requires_grad, model.parameters()), 440 | lr=current_learning_rate 441 | ) 442 | warm_up_steps = warm_up_steps * 1.5 443 | 444 | if step % args.save_checkpoint_steps == 0: 445 | save_variable_list = { 446 | 'step': step, 447 | 'current_learning_rate': current_learning_rate, 448 | 'warm_up_steps': warm_up_steps 449 | } 450 | save_model(model, optimizer, save_variable_list, args) 451 | 452 | if step % args.valid_steps == 0 and step > 0: 453 | if args.do_valid: 454 | logging.info('Evaluating on Valid Dataset...') 455 | valid_all_metrics = evaluate(model, valid_easy_answers, valid_hard_answers, args, valid_dataloader, query_name_dict, 'Valid', step, writer) 456 | 457 | if args.do_test: 458 | logging.info('Evaluating on Test Dataset...') 459 | test_all_metrics = evaluate(model, test_easy_answers, test_hard_answers, args, test_dataloader, query_name_dict, 'Test', step, writer) 460 | 461 | if step % args.log_steps == 0: 462 | metrics = {} 463 | for metric in training_logs[0].keys(): 464 | metrics[metric] = sum([log[metric] for log in training_logs])/len(training_logs) 465 | 466 | log_metrics('Training average', step, metrics) 467 | training_logs = [] 468 | 469 | save_variable_list = { 470 | 'step': step, 471 | 'current_learning_rate': current_learning_rate, 472 | 'warm_up_steps': warm_up_steps 473 | } 474 | save_model(model, optimizer, save_variable_list, args) 475 | 476 | try: 477 | print (step) 478 | except: 479 | step = 0 480 | 481 | if args.do_test: 482 | logging.info('Evaluating on Test Dataset...') 483 | test_all_metrics = evaluate(model, test_easy_answers, test_hard_answers, args, test_dataloader, query_name_dict, 'Test', step, writer) 484 | 485 | logging.info("Training finished!!") 486 | 487 | def build_kg(data_path, neighbor_ent_type_samples, neighbor_rel_type_samples): 488 | entity_type_mapping = np.load(data_path + '/entity_type.npy', allow_pickle=True) 489 | entity2types = [] 490 | for i in range(len(entity_type_mapping)): 491 | sampled_types = np.random.choice(entity_type_mapping[i], size=neighbor_ent_type_samples, 492 | replace=len(entity_type_mapping[i]) < neighbor_ent_type_samples) 493 | entity2types.append(sampled_types) 494 | 495 | relation_type_mapping = np.load(data_path + '/relation_type.npy', allow_pickle=True) 496 | relation2types = [] 497 | for i in range(len(relation_type_mapping)): 498 | sampled_types = np.random.choice(relation_type_mapping[i], size=neighbor_rel_type_samples, 499 | replace=len(relation_type_mapping[i]) < neighbor_rel_type_samples) 500 | relation2types.append(sampled_types) 501 | return entity2types, relation2types 502 | 503 | if __name__ == '__main__': 504 | main(parse_args()) -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import logging 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import collections 12 | from tqdm import tqdm 13 | from type_aggregator import EntityTypeAggregator, RelationTypeAggregator, Match 14 | 15 | pi = 3.14159265358979323846 16 | eps = 1e-6 17 | 18 | query_name_dict = {('e',('r',)): '1p', 19 | ('e', ('r', 'r')): '2p', 20 | ('e', ('r', 'r', 'r')): '3p', 21 | ('e', ('r', 'r', 'r', 'r')): '4p', 22 | ('e', ('r', 'r', 'r', 'r', 'r')): '5p', 23 | (('e', ('r',)), ('e', ('r',))): '2i', 24 | (('e', ('r',)), ('e', ('r',)), ('e', ('r',))): '3i', 25 | ((('e', ('r',)), ('e', ('r',))), ('r',)): 'ip', 26 | (('e', ('r', 'r')), ('e', ('r',))): 'pi', 27 | (('e', ('r',)), ('e', ('r', 'n'))): '2in', 28 | (('e', ('r',)), ('e', ('r',)), ('e', ('r', 'n'))): '3in', 29 | ((('e', ('r',)), ('e', ('r', 'n'))), ('r',)): 'inp', 30 | (('e', ('r', 'r')), ('e', ('r', 'n'))): 'pin', 31 | (('e', ('r', 'r', 'n')), ('e', ('r',))): 'pni', 32 | (('e', ('r',)), ('e', ('r',)), ('u',)): '2u-DNF', 33 | ((('e', ('r',)), ('e', ('r',)), ('u',)), ('r',)): 'up-DNF', 34 | ((('e', ('r', 'n')), ('e', ('r', 'n'))), ('n',)): '2u-DM', 35 | ((('e', ('r', 'n')), ('e', ('r', 'n'))), ('n', 'r')): 'up-DM' 36 | } 37 | 38 | def Identity(x): 39 | return x 40 | 41 | class BoxOffsetIntersection(nn.Module): 42 | 43 | def __init__(self, dim): 44 | super(BoxOffsetIntersection, self).__init__() 45 | self.dim = dim 46 | self.layer1 = nn.Linear(self.dim, self.dim) 47 | self.layer2 = nn.Linear(self.dim, self.dim) 48 | 49 | nn.init.xavier_uniform_(self.layer1.weight) 50 | nn.init.xavier_uniform_(self.layer2.weight) 51 | 52 | def forward(self, embeddings): 53 | layer1_act = F.relu(self.layer1(embeddings)) 54 | layer1_mean = torch.mean(layer1_act, dim=0) 55 | gate = torch.sigmoid(self.layer2(layer1_mean)) 56 | offset, _ = torch.min(embeddings, dim=0) 57 | 58 | return offset * gate 59 | 60 | class CenterIntersection(nn.Module): 61 | 62 | def __init__(self, dim): 63 | super(CenterIntersection, self).__init__() 64 | self.dim = dim 65 | self.layer1 = nn.Linear(self.dim, self.dim) 66 | self.layer2 = nn.Linear(self.dim, self.dim) 67 | 68 | nn.init.xavier_uniform_(self.layer1.weight) 69 | nn.init.xavier_uniform_(self.layer2.weight) 70 | 71 | def forward(self, embeddings): 72 | layer1_act = F.relu(self.layer1(embeddings)) # (num_conj, dim) 73 | attention = F.softmax(self.layer2(layer1_act), dim=0) # (num_conj, dim) 74 | embedding = torch.sum(attention * embeddings, dim=0) 75 | 76 | return embedding 77 | 78 | class BetaIntersection(nn.Module): 79 | 80 | def __init__(self, dim): 81 | super(BetaIntersection, self).__init__() 82 | self.dim = dim 83 | self.layer1 = nn.Linear(2 * self.dim, 2 * self.dim) 84 | self.layer2 = nn.Linear(2 * self.dim, self.dim) 85 | 86 | nn.init.xavier_uniform_(self.layer1.weight) 87 | nn.init.xavier_uniform_(self.layer2.weight) 88 | 89 | def forward(self, alpha_embeddings, beta_embeddings): 90 | all_embeddings = torch.cat([alpha_embeddings, beta_embeddings], dim=-1) 91 | layer1_act = F.relu(self.layer1(all_embeddings)) # (num_conj, batch_size, 2 * dim) 92 | attention = F.softmax(self.layer2(layer1_act), dim=0) # (num_conj, batch_size, dim) 93 | 94 | alpha_embedding = torch.sum(attention * alpha_embeddings, dim=0) 95 | beta_embedding = torch.sum(attention * beta_embeddings, dim=0) 96 | 97 | return alpha_embedding, beta_embedding 98 | 99 | class BetaProjection(nn.Module): 100 | def __init__(self, entity_dim, relation_dim, hidden_dim, projection_regularizer, num_layers, with_regular=True): 101 | super(BetaProjection, self).__init__() 102 | self.entity_dim = entity_dim 103 | self.relation_dim = relation_dim 104 | self.hidden_dim = hidden_dim 105 | self.num_layers = num_layers 106 | self.layer1 = nn.Linear(self.entity_dim + self.relation_dim, self.hidden_dim) # 1st layer 107 | self.layer0 = nn.Linear(self.hidden_dim, self.entity_dim) # final layer 108 | for nl in range(2, num_layers + 1): 109 | setattr(self, "layer{}".format(nl), nn.Linear(self.hidden_dim, self.hidden_dim)) 110 | for nl in range(num_layers + 1): 111 | nn.init.xavier_uniform_(getattr(self, "layer{}".format(nl)).weight) 112 | self.projection_regularizer = projection_regularizer 113 | self.with_regular = with_regular 114 | 115 | def forward(self, e_embedding, r_embedding): 116 | x = torch.cat([e_embedding, r_embedding], dim=-1) 117 | for nl in range(1, self.num_layers + 1): 118 | x = F.relu(getattr(self, "layer{}".format(nl))(x)) 119 | x = self.layer0(x) 120 | x = self.projection_regularizer(x) 121 | if self.with_regular == True: 122 | x = self.projection_regularizer(x) 123 | 124 | return x 125 | 126 | class Regularizer(): 127 | def __init__(self, base_add, min_val, max_val): 128 | self.base_add = base_add 129 | self.min_val = min_val 130 | self.max_val = max_val 131 | 132 | def __call__(self, entity_embedding): 133 | return torch.clamp(entity_embedding + self.base_add, self.min_val, self.max_val) 134 | 135 | def convert_to_arg(x): 136 | y = torch.tanh(2 * x) * pi / 2 + pi / 2 137 | return y 138 | 139 | def convert_to_axis(x): 140 | y = torch.tanh(x) * pi 141 | return y 142 | 143 | class AngleScale: 144 | def __init__(self, embedding_range): 145 | self.embedding_range = embedding_range 146 | 147 | def __call__(self, axis_embedding, scale=None): 148 | if scale is None: 149 | scale = pi 150 | return axis_embedding / self.embedding_range * scale 151 | 152 | class ConeProjection(nn.Module): 153 | def __init__(self, dim, hidden_dim, num_layers, with_regular=True): 154 | super(ConeProjection, self).__init__() 155 | self.entity_dim = dim 156 | self.relation_dim = dim 157 | self.hidden_dim = hidden_dim 158 | self.num_layers = num_layers 159 | self.layer1 = nn.Linear(self.entity_dim + self.relation_dim, self.hidden_dim) 160 | self.layer0 = nn.Linear(self.hidden_dim, self.entity_dim + self.relation_dim) 161 | self.with_regular = with_regular 162 | for nl in range(2, num_layers + 1): 163 | setattr(self, "layer{}".format(nl), nn.Linear(self.hidden_dim, self.hidden_dim)) 164 | for nl in range(num_layers + 1): 165 | nn.init.xavier_uniform_(getattr(self, "layer{}".format(nl)).weight) 166 | 167 | def forward(self, source_embedding_axis, source_embedding_arg, r_embedding_axis, r_embedding_arg): 168 | x = torch.cat([source_embedding_axis + r_embedding_axis, source_embedding_arg + r_embedding_arg], dim=-1) 169 | for nl in range(1, self.num_layers + 1): 170 | x = F.relu(getattr(self, "layer{}".format(nl))(x)) 171 | x = self.layer0(x) 172 | 173 | axis, arg = torch.chunk(x, 2, dim=-1) 174 | if self.with_regular: 175 | axis_embeddings = convert_to_axis(axis) 176 | arg_embeddings = convert_to_arg(arg) 177 | else: 178 | axis_embeddings = axis 179 | arg_embeddings = arg 180 | return axis_embeddings, arg_embeddings 181 | 182 | class ConeIntersection(nn.Module): 183 | def __init__(self, dim, drop): 184 | super(ConeIntersection, self).__init__() 185 | self.dim = dim 186 | self.layer_axis1 = nn.Linear(self.dim * 2, self.dim) 187 | self.layer_arg1 = nn.Linear(self.dim * 2, self.dim) 188 | self.layer_axis2 = nn.Linear(self.dim, self.dim) 189 | self.layer_arg2 = nn.Linear(self.dim, self.dim) 190 | 191 | nn.init.xavier_uniform_(self.layer_axis1.weight) 192 | nn.init.xavier_uniform_(self.layer_arg1.weight) 193 | nn.init.xavier_uniform_(self.layer_axis2.weight) 194 | nn.init.xavier_uniform_(self.layer_arg2.weight) 195 | 196 | self.drop = nn.Dropout(p=drop) 197 | 198 | def forward(self, axis_embeddings, arg_embeddings): 199 | logits = torch.cat([axis_embeddings - arg_embeddings, axis_embeddings + arg_embeddings], dim=-1) 200 | axis_layer1_act = F.relu(self.layer_axis1(logits)) 201 | 202 | axis_attention = F.softmax(self.layer_axis2(axis_layer1_act), dim=0) 203 | 204 | x_embeddings = torch.cos(axis_embeddings) 205 | y_embeddings = torch.sin(axis_embeddings) 206 | x_embeddings = torch.sum(axis_attention * x_embeddings, dim=0) 207 | y_embeddings = torch.sum(axis_attention * y_embeddings, dim=0) 208 | 209 | # when x_embeddings are very closed to zero, the tangent may be nan 210 | # no need to consider the sign of x_embeddings 211 | x_embeddings[torch.abs(x_embeddings) < 1e-3] = 1e-3 212 | 213 | axis_embeddings = torch.atan(y_embeddings / x_embeddings) 214 | 215 | indicator_x = x_embeddings < 0 216 | indicator_y = y_embeddings < 0 217 | indicator_two = indicator_x & torch.logical_not(indicator_y) 218 | indicator_three = indicator_x & indicator_y 219 | 220 | axis_embeddings[indicator_two] = axis_embeddings[indicator_two] + pi 221 | axis_embeddings[indicator_three] = axis_embeddings[indicator_three] - pi 222 | 223 | # DeepSets 224 | arg_layer1_act = F.relu(self.layer_arg1(logits)) 225 | arg_layer1_mean = torch.mean(arg_layer1_act, dim=0) 226 | gate = torch.sigmoid(self.layer_arg2(arg_layer1_mean)) 227 | 228 | arg_embeddings = self.drop(arg_embeddings) 229 | arg_embeddings, _ = torch.min(arg_embeddings, dim=0) 230 | arg_embeddings = arg_embeddings * gate 231 | 232 | return axis_embeddings, arg_embeddings 233 | 234 | class ConeNegation(nn.Module): 235 | def __init__(self): 236 | super(ConeNegation, self).__init__() 237 | 238 | def forward(self, axis_embedding, arg_embedding): 239 | indicator_positive = axis_embedding >= 0 240 | indicator_negative = axis_embedding < 0 241 | axis_embedding[indicator_positive] = axis_embedding[indicator_positive] - pi 242 | axis_embedding[indicator_negative] = axis_embedding[indicator_negative] + pi 243 | arg_embedding = pi - arg_embedding 244 | return axis_embedding, arg_embedding 245 | 246 | def order_bounds(embedding): # ensure lower < upper truth bound for logic embedding 247 | embedding = torch.clamp(embedding, 0, 1) 248 | lower, upper = torch.chunk(embedding, 2, dim=-1) 249 | contra = lower > upper 250 | if contra.any(): # contradiction 251 | mean = (lower + upper) / 2 252 | lower = torch.where(lower > upper, mean, lower) 253 | upper = torch.where(lower > upper, mean, upper) 254 | ordered_embedding = torch.cat([lower, upper], dim=-1) 255 | return ordered_embedding 256 | 257 | def valclamp(x, a=1, b=6, lo=0, hi=1): # relu1 with gradient-transparent clamp on negative 258 | elu_neg = a * (torch.exp(b * x) - 1) 259 | return ((x < lo).float() * (lo + elu_neg - elu_neg.detach()) + 260 | (lo <= x).float() * (x <= hi).float() * x + 261 | (hi < x).float()) 262 | 263 | class LogicIntersection(nn.Module): 264 | 265 | def __init__(self, dim, tnorm, bounded, use_att, use_gtrans): 266 | super(LogicIntersection, self).__init__() 267 | self.dim = dim 268 | self.tnorm = tnorm 269 | self.bounded = bounded 270 | self.use_att = use_att 271 | self.use_gtrans = use_gtrans # gradient transparency 272 | 273 | if use_att: # use attention with weighted t-norm 274 | self.layer1 = nn.Linear(2 * self.dim, 2 * self.dim) 275 | 276 | if bounded: 277 | self.layer2 = nn.Linear(2 * self.dim, self.dim) # same weight for bound pair 278 | else: 279 | self.layer2 = nn.Linear(2 * self.dim, 2 * self.dim) 280 | 281 | nn.init.xavier_uniform_(self.layer1.weight) 282 | nn.init.xavier_uniform_(self.layer2.weight) 283 | 284 | def forward(self, embeddings): 285 | if self.use_att: # use attention with weighted t-norm 286 | layer1_act = F.relu(self.layer1(embeddings)) # (num_conj, batch_size, 2 * dim) 287 | attention = F.softmax(self.layer2(layer1_act), dim=0) # (num_conj, batch_size, dim) 288 | attention = attention / torch.max(attention, dim=0, keepdim=True).values 289 | 290 | if self.bounded: # same weight for bound pair 291 | attention = torch.cat([attention, attention], dim=-1) 292 | 293 | if self.tnorm == 'mins': # minimum / Godel t-norm 294 | smooth_param = -10 # smooth minimum 295 | min_weights = attention * torch.exp(smooth_param * embeddings) 296 | embedding = torch.sum(min_weights * embeddings, dim=0) / torch.sum(min_weights, dim=0) 297 | if self.bounded: 298 | embedding = order_bounds(embedding) 299 | 300 | elif self.tnorm == 'luk': # Lukasiewicz t-norm 301 | embedding = 1 - torch.sum(attention * (1 - embeddings), dim=0) 302 | if self.use_gtrans: 303 | embedding = valclamp(embedding, b=6. / embedding.shape[0]) 304 | else: 305 | embedding = torch.clamp(embedding, 0, 1) 306 | 307 | elif self.tnorm == 'prod': # product t-norm 308 | embedding = torch.prod(torch.pow(torch.clamp(embeddings, 0, 1) + eps, attention), dim=0) 309 | 310 | else: # no attention 311 | if self.tnorm == 'mins': # minimum / Godel t-norm 312 | smooth_param = -10 # smooth minimum 313 | min_weights = torch.exp(smooth_param * embeddings) 314 | embedding = torch.sum(min_weights * embeddings, dim=0) / torch.sum(min_weights, dim=0) 315 | if self.bounded: 316 | embedding = order_bounds(embedding) 317 | 318 | elif self.tnorm == 'luk': # Lukasiewicz t-norm 319 | embedding = 1 - torch.sum(1 - embeddings, dim=0) 320 | if self.use_gtrans: 321 | embedding = valclamp(embedding, b=6. / embedding.shape[0]) 322 | else: 323 | embedding = torch.clamp(embedding, 0, 1) 324 | 325 | elif self.tnorm == 'prod': # product t-norm 326 | embedding = torch.prod(embeddings, dim=0) 327 | 328 | return embedding 329 | 330 | class LogicProjection(nn.Module): 331 | def __init__(self, entity_dim, relation_dim, hidden_dim, num_layers, bounded, with_sigmoid=False): 332 | super(LogicProjection, self).__init__() 333 | self.entity_dim = entity_dim 334 | self.relation_dim = relation_dim 335 | self.hidden_dim = hidden_dim 336 | self.num_layers = num_layers 337 | self.bounded = bounded 338 | self.layer1 = nn.Linear(self.entity_dim + self.relation_dim, self.hidden_dim) # 1st layer 339 | self.layer0 = nn.Linear(self.hidden_dim, self.entity_dim) # final layer 340 | self.with_sigmoid = with_sigmoid 341 | for nl in range(2, num_layers + 1): 342 | setattr(self, "layer{}".format(nl), nn.Linear(self.hidden_dim, self.hidden_dim)) 343 | for nl in range(num_layers + 1): 344 | nn.init.xavier_uniform_(getattr(self, "layer{}".format(nl)).weight) 345 | 346 | def forward(self, e_embedding, r_embedding): 347 | x = torch.cat([e_embedding, r_embedding], dim=-1) 348 | for nl in range(1, self.num_layers + 1): 349 | x = F.relu(getattr(self, "layer{}".format(nl))(x)) 350 | x = self.layer0(x) 351 | if self.with_sigmoid: 352 | x = torch.sigmoid(x) 353 | 354 | if self.bounded: 355 | lower, upper = torch.chunk(x, 2, dim=-1) 356 | upper = lower + upper * (1 - lower) 357 | x = torch.cat([lower, upper], dim=-1) 358 | 359 | return x 360 | 361 | class SizePredict(nn.Module): 362 | def __init__(self, entity_dim): 363 | super(SizePredict, self).__init__() 364 | 365 | self.layer2 = nn.Linear(entity_dim, entity_dim // 4) 366 | self.layer1 = nn.Linear(entity_dim // 4, entity_dim // 16) 367 | self.layer0 = nn.Linear(entity_dim // 16, 1) 368 | 369 | nn.init.xavier_uniform_(self.layer2.weight) 370 | nn.init.xavier_uniform_(self.layer1.weight) 371 | nn.init.xavier_uniform_(self.layer0.weight) 372 | 373 | def forward(self, entropy_embedding): 374 | x = self.layer2(entropy_embedding) 375 | x = F.relu(x) 376 | x = self.layer1(x) 377 | x = F.relu(x) 378 | x = self.layer0(x) 379 | x = torch.sigmoid(x) 380 | 381 | return x.squeeze() 382 | 383 | class KGReasoning(nn.Module): 384 | def __init__(self, nentity, nrelation, ntype, hidden_dim, entity2type, relation2type, gamma, 385 | geo, test_batch_size=1, 386 | box_mode=None, use_cuda=False, 387 | query_name_dict=None, beta_mode=None, center_reg=None, logic_mode=None, model_mode='baseline', drop=0., neighbor_ent_type_samples=32): 388 | super(KGReasoning, self).__init__() 389 | self.nentity = nentity 390 | self.nrelation = nrelation 391 | self.ntype = ntype 392 | self.hidden_dim = hidden_dim 393 | self.epsilon = 2.0 394 | self.geo = geo 395 | self.use_cuda = use_cuda 396 | self.batch_entity_range = torch.arange(nentity).to(torch.float).repeat(test_batch_size, 1).cuda() if self.use_cuda else torch.arange(nentity).to(torch.float).repeat(test_batch_size, 1) # used in test_step 397 | self.query_name_dict = query_name_dict 398 | 399 | self.entity2type = torch.tensor(entity2type).cuda() 400 | self.relation2type = torch.tensor(relation2type).cuda() 401 | self.neighbor_ent_type_samples = neighbor_ent_type_samples 402 | if self.geo == 'vec' or self.geo == 'box' or self.geo == 'cone': 403 | self.ent_neighbor_type_agg = EntityTypeAggregator(input_dim=hidden_dim, output_dim=hidden_dim, self_included=True, neighbor_ent_type_samples=self.neighbor_ent_type_samples) 404 | elif self.geo == 'beta' or self.geo == 'logic': 405 | self.ent_neighbor_type_agg = EntityTypeAggregator(input_dim=hidden_dim*2, output_dim=hidden_dim*2, self_included=True, neighbor_ent_type_samples=self.neighbor_ent_type_samples) 406 | 407 | if self.geo == 'logic': 408 | self.rel_neighbor_type_agg = RelationTypeAggregator(hidden_dim * 2) 409 | else: 410 | self.rel_neighbor_type_agg = RelationTypeAggregator(hidden_dim) 411 | if self.geo == 'logic': 412 | self.match = Match(hidden_dim * 2) 413 | else: 414 | self.match = Match(hidden_dim) 415 | self.model_mode = model_mode 416 | 417 | self.gamma = nn.Parameter( 418 | torch.Tensor([gamma]), 419 | requires_grad=False 420 | ) 421 | 422 | self.embedding_range = nn.Parameter( 423 | torch.Tensor([(self.gamma.item() + self.epsilon) / hidden_dim]), 424 | requires_grad=False 425 | ) 426 | 427 | self.entity_dim = hidden_dim 428 | self.relation_dim = hidden_dim 429 | self.type_dim = hidden_dim 430 | 431 | self.cen = center_reg 432 | 433 | if self.geo == 'box': 434 | self.entity_embedding = nn.Parameter(torch.zeros(nentity, self.entity_dim)) # centor for entities 435 | activation, cen = box_mode 436 | self.cen = cen # hyperparameter that balances the in-box distance and the out-box distance 437 | if activation == 'none': 438 | self.func = Identity 439 | elif activation == 'relu': 440 | self.func = F.relu 441 | elif activation == 'softplus': 442 | self.func = F.softplus 443 | elif self.geo == 'vec': 444 | self.entity_embedding = nn.Parameter(torch.zeros(nentity, self.entity_dim)) # center for entities 445 | elif self.geo == 'beta': 446 | self.entity_embedding = nn.Parameter(torch.zeros(nentity, self.entity_dim * 2)) # alpha and beta 447 | self.entity_regularizer = Regularizer(1, 0.05, 1e9) # make sure the parameters of beta embeddings are positive 448 | self.projection_regularizer = Regularizer(1, 0.05, 1e9) # make sure the parameters of beta embeddings after relation projection are positive 449 | elif self.geo == 'cone': 450 | self.entity_embedding = nn.Parameter(torch.zeros(nentity, self.entity_dim), requires_grad=True) # axis for entities 451 | self.angle_scale = AngleScale(self.embedding_range.item()) # scale axis embeddings to [-pi, pi] 452 | self.modulus = nn.Parameter(torch.Tensor([0.5 * self.embedding_range.item()]), requires_grad=True) 453 | self.axis_scale = 1.0 454 | self.arg_scale = 1.0 455 | elif self.geo == 'logic': 456 | self.tnorm, self.bounded, use_att, use_gtrans, hidden_dim, num_layers = logic_mode 457 | if self.bounded: 458 | lower = torch.rand((nentity, self.entity_dim)) 459 | upper = lower + torch.rand((nentity, self.entity_dim)) * (1 - lower) 460 | self.entity_embedding = nn.Parameter(torch.cat([lower, upper], dim=-1)) 461 | else: 462 | self.entity_embedding = nn.Parameter(torch.rand((nentity, self.entity_dim * 2))) 463 | 464 | if self.geo in ['box', 'vec', 'beta', 'cone']: 465 | nn.init.uniform_( 466 | tensor=self.entity_embedding, 467 | a=-self.embedding_range.item(), 468 | b=self.embedding_range.item() 469 | ) 470 | 471 | if self.geo == 'beta' or self.geo == 'logic': 472 | self.relation_embedding = nn.Parameter(torch.zeros(nrelation, self.relation_dim * 2)) 473 | else: 474 | self.relation_embedding = nn.Parameter(torch.zeros(nrelation, self.relation_dim)) 475 | 476 | if self.geo in ['box', 'vec', 'beta', 'cone']: 477 | nn.init.uniform_( 478 | tensor=self.relation_embedding, 479 | a=-self.embedding_range.item(), 480 | b=self.embedding_range.item() 481 | ) 482 | 483 | if self.geo == 'box': 484 | self.offset_embedding = nn.Parameter(torch.zeros(nrelation, self.entity_dim)) 485 | nn.init.uniform_( 486 | tensor=self.offset_embedding, 487 | a=0., 488 | b=self.embedding_range.item() 489 | ) 490 | self.center_net = CenterIntersection(self.entity_dim) 491 | self.offset_net = BoxOffsetIntersection(self.entity_dim) 492 | 493 | self.type_embedding = nn.Parameter(torch.zeros(self.ntype, self.type_dim)) 494 | nn.init.uniform_( 495 | tensor=self.type_embedding, 496 | a=-self.embedding_range.item(), 497 | b=self.embedding_range.item() 498 | ) 499 | elif self.geo == 'vec': 500 | self.center_net = CenterIntersection(self.entity_dim) 501 | self.type_embedding = nn.Parameter(torch.zeros(self.ntype, self.type_dim)) 502 | nn.init.uniform_( 503 | tensor=self.type_embedding, 504 | a=-self.embedding_range.item(), 505 | b=self.embedding_range.item() 506 | ) 507 | elif self.geo == 'beta': 508 | hidden_dim, num_layers = beta_mode 509 | self.center_net = BetaIntersection(self.entity_dim) 510 | self.relation_center_net = CenterIntersection(self.relation_dim) 511 | self.projection_net = BetaProjection(self.entity_dim * 2, 512 | self.relation_dim * 2, 513 | hidden_dim, 514 | self.projection_regularizer, 515 | num_layers) 516 | self.projection_without_net = BetaProjection(self.entity_dim * 2, 517 | self.relation_dim * 2, 518 | hidden_dim, 519 | self.projection_regularizer, 520 | num_layers, with_regular=False) 521 | self.type_embedding = nn.Parameter(torch.zeros(self.ntype, self.type_dim * 2)) 522 | nn.init.uniform_( 523 | tensor=self.type_embedding, 524 | a=-self.embedding_range.item(), 525 | b=self.embedding_range.item() 526 | ) 527 | elif self.geo == 'cone': 528 | self.axis_embedding = nn.Parameter(torch.zeros(nrelation, self.relation_dim), requires_grad=True) 529 | nn.init.uniform_( 530 | tensor=self.axis_embedding, 531 | a=-self.embedding_range.item(), 532 | b=self.embedding_range.item() 533 | ) 534 | self.arg_embedding = nn.Parameter(torch.zeros(nrelation, self.relation_dim), requires_grad=True) 535 | nn.init.uniform_( 536 | tensor=self.arg_embedding, 537 | a=-self.embedding_range.item(), 538 | b=self.embedding_range.item() 539 | ) 540 | self.cone_proj = ConeProjection(self.entity_dim, 1600, 2) 541 | self.cone_intersection = ConeIntersection(self.entity_dim, drop) 542 | self.cone_negation = ConeNegation() 543 | self.relation_center_net = CenterIntersection(self.entity_dim) 544 | self.cone_without_proj = ConeProjection(self.entity_dim, 1600, 2, with_regular=False) 545 | self.type_embedding = nn.Parameter(torch.zeros(self.ntype, self.type_dim), requires_grad=True) 546 | nn.init.uniform_( 547 | tensor=self.type_embedding, 548 | a=-self.embedding_range.item(), 549 | b=self.embedding_range.item() 550 | ) 551 | elif self.geo == 'logic': 552 | tnorm, bounded, use_att, use_gtrans, hidden_dim, num_layers = logic_mode 553 | self.center_net = LogicIntersection(self.entity_dim, tnorm, bounded, use_att, use_gtrans) 554 | self.relation_center_net = CenterIntersection(self.entity_dim * 2) 555 | self.projection_net = LogicProjection(self.entity_dim * 2, 556 | self.relation_dim * 2, 557 | hidden_dim, 558 | num_layers, 559 | bounded, with_sigmoid=True) 560 | self.projection_without_net = LogicProjection(self.entity_dim, 561 | self.relation_dim, 562 | hidden_dim, 563 | num_layers, 564 | bounded, with_sigmoid=False) 565 | self.type_embedding = nn.Parameter(torch.rand(self.ntype, self.type_dim * 2)) 566 | 567 | def forward(self, positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict): 568 | if self.geo == 'box': 569 | return self.forward_box(positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict) 570 | elif self.geo == 'vec': 571 | return self.forward_vec(positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict) 572 | elif self.geo == 'beta': 573 | return self.forward_beta(positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict) 574 | elif self.geo == 'cone': 575 | return self.forward_cone(positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict) 576 | elif self.geo == 'logic': 577 | return self.forward_logic(positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict) 578 | 579 | def embed_query_vec(self, queries, query_structure, idx): 580 | ''' 581 | Iterative embed a batch of queries with same structure using GQE 582 | queries: a flattened batch of queries 583 | ''' 584 | all_relation_flag = True 585 | for ele in query_structure[-1]: # whether the current query tree has merged to one branch and only need to do relation traversal, e.g., path queries or conjunctive queries after the intersection 586 | if ele not in ['r', 'n']: 587 | all_relation_flag = False 588 | break 589 | if all_relation_flag: 590 | if query_structure[0] == 'e': 591 | if self.model_mode == 'baseline': 592 | embedding = torch.index_select(self.entity_embedding, dim=0, index=queries[:, idx]) 593 | elif self.model_mode == 'temp': 594 | embedding = torch.index_select(self.entity_embedding, dim=0, index=queries[:, idx]) 595 | ent_type_id = torch.index_select(self.entity2type, dim=0, index=queries[:, idx]) 596 | entity_neighbor_type_embedding = torch.index_select(self.type_embedding, dim=0, index=ent_type_id.view(-1)).view(ent_type_id.shape[0], ent_type_id.shape[1], -1) 597 | embedding = self.ent_neighbor_type_agg(embedding, entity_neighbor_type_embedding) 598 | 599 | idx += 1 600 | else: 601 | embedding, idx = self.embed_query_vec(queries, query_structure[0], idx) 602 | for i in range(len(query_structure[-1])): 603 | if query_structure[-1][i] == 'n': 604 | assert False, "vec cannot handle queries with negation" 605 | else: 606 | if self.model_mode == 'baseline': 607 | r_embedding = torch.index_select(self.relation_embedding, dim=0, index=queries[:, idx]) 608 | embedding += r_embedding 609 | elif self.model_mode == 'temp': 610 | r_embedding = torch.index_select(self.relation_embedding, dim=0, index=queries[:, idx]) 611 | rel_type_id = torch.index_select(self.relation2type, dim=0, index=queries[:, idx]) 612 | relation_neighbor_type_embedding = torch.index_select(self.type_embedding, dim=0, index=rel_type_id.view(-1)).view(rel_type_id.shape[0], rel_type_id.shape[1], -1) 613 | relation_neighbor_type_embedding = torch.transpose(relation_neighbor_type_embedding, 0, 1) 614 | rel_type_cent = self.center_net(relation_neighbor_type_embedding) 615 | 616 | embedding = embedding.unsqueeze(1) 617 | r_embedding = r_embedding.unsqueeze(1) 618 | rel_type_cent = rel_type_cent.unsqueeze(1) 619 | ent_rel = self.match([embedding, r_embedding]) 620 | ent_rel_type = self.match([embedding, rel_type_cent]) 621 | rel_ent = self.match([r_embedding, embedding]) 622 | rel_rel_type = self.match([r_embedding, rel_type_cent]) 623 | 624 | embedding = self.rel_neighbor_type_agg([ent_rel.squeeze(1), ent_rel_type.squeeze(1)]) 625 | r_embedding = self.rel_neighbor_type_agg([rel_ent.squeeze(1), rel_rel_type.squeeze(1)]) 626 | embedding += r_embedding 627 | 628 | idx += 1 629 | else: 630 | embedding_list = [] 631 | for i in range(len(query_structure)): 632 | embedding, idx = self.embed_query_vec(queries, query_structure[i], idx) 633 | embedding_list.append(embedding) 634 | embedding = self.center_net(torch.stack(embedding_list)) 635 | 636 | return embedding, idx 637 | 638 | def embed_query_box(self, queries, query_structure, idx): 639 | ''' 640 | Iterative embed a batch of queries with same structure using Query2box 641 | queries: a flattened batch of queries 642 | ''' 643 | all_relation_flag = True 644 | for ele in query_structure[-1]: # whether the current query tree has merged to one branch and only need to do relation traversal, e.g., path queries or conjunctive queries after the intersection 645 | if ele not in ['r', 'n']: 646 | all_relation_flag = False 647 | break 648 | if all_relation_flag: 649 | if query_structure[0] == 'e': 650 | if self.model_mode == 'baseline': 651 | embedding = torch.index_select(self.entity_embedding, dim=0, index=queries[:, idx]) 652 | elif self.model_mode == 'temp': 653 | embedding = torch.index_select(self.entity_embedding, dim=0, index=queries[:, idx]) 654 | ent_type_id = torch.index_select(self.entity2type, dim=0, index=queries[:, idx]) 655 | entity_neighbor_type_embedding = torch.index_select(self.type_embedding, dim=0, index=ent_type_id.view(-1)).view(ent_type_id.shape[0], ent_type_id.shape[1], -1) 656 | embedding = self.ent_neighbor_type_agg(embedding, entity_neighbor_type_embedding) 657 | 658 | if self.use_cuda: 659 | offset_embedding = torch.zeros_like(embedding).cuda() 660 | else: 661 | offset_embedding = torch.zeros_like(embedding) 662 | idx += 1 663 | else: 664 | embedding, offset_embedding, idx = self.embed_query_box(queries, query_structure[0], idx) 665 | for i in range(len(query_structure[-1])): 666 | if query_structure[-1][i] == 'n': 667 | assert False, "box cannot handle queries with negation" 668 | else: 669 | if self.model_mode == 'baseline': 670 | r_embedding = torch.index_select(self.relation_embedding, dim=0, index=queries[:, idx]) 671 | r_offset_embedding = torch.index_select(self.offset_embedding, dim=0, index=queries[:, idx]) 672 | embedding += r_embedding 673 | offset_embedding += self.func(r_offset_embedding) 674 | elif self.model_mode == 'temp': 675 | r_embedding = torch.index_select(self.relation_embedding, dim=0, index=queries[:, idx]) 676 | r_offset_embedding = torch.index_select(self.offset_embedding, dim=0, index=queries[:, idx]) 677 | rel_type_id = torch.index_select(self.relation2type, dim=0, index=queries[:, idx]) 678 | relation_neighbor_type_embedding = torch.index_select(self.type_embedding, dim=0, index=rel_type_id.view(-1)).view(rel_type_id.shape[0], rel_type_id.shape[1], -1) 679 | relation_neighbor_type_embedding = torch.transpose(relation_neighbor_type_embedding, 0, 1) 680 | rel_type_cent = self.center_net(relation_neighbor_type_embedding) 681 | 682 | embedding = embedding.unsqueeze(1) 683 | r_embedding = r_embedding.unsqueeze(1) 684 | rel_type_cent = rel_type_cent.unsqueeze(1) 685 | ent_rel = self.match([embedding, r_embedding]) 686 | ent_rel_type = self.match([embedding, rel_type_cent]) 687 | rel_ent = self.match([r_embedding, embedding]) 688 | rel_rel_type = self.match([r_embedding, rel_type_cent]) 689 | 690 | embedding = self.rel_neighbor_type_agg([ent_rel.squeeze(1), ent_rel_type.squeeze(1)]) 691 | r_embedding = self.rel_neighbor_type_agg([rel_ent.squeeze(1), rel_rel_type.squeeze(1)]) 692 | 693 | embedding += r_embedding 694 | offset_embedding += self.func(r_offset_embedding) 695 | 696 | idx += 1 697 | else: 698 | embedding_list = [] 699 | offset_embedding_list = [] 700 | for i in range(len(query_structure)): 701 | embedding, offset_embedding, idx = self.embed_query_box(queries, query_structure[i], idx) 702 | embedding_list.append(embedding) 703 | offset_embedding_list.append(offset_embedding) 704 | embedding = self.center_net(torch.stack(embedding_list)) 705 | offset_embedding = self.offset_net(torch.stack(offset_embedding_list)) 706 | 707 | return embedding, offset_embedding, idx 708 | 709 | def embed_query_beta(self, queries, query_structure, idx, filter_flag=False): 710 | ''' 711 | Iterative embed a batch of queries with same structure using BetaE 712 | queries: a flattened batch of queries 713 | ''' 714 | all_relation_flag = True 715 | if self.model_mode == 'temp': 716 | if filter_flag == False: 717 | if query_structure in query_name_dict and query_name_dict[query_structure] == 'ip': 718 | filter_flag = True 719 | for ele in query_structure[-1]: # whether the current query tree has merged to one branch and only need to do relation traversal, e.g., path queries or conjunctive queries after the intersection 720 | if ele not in ['r', 'n']: 721 | all_relation_flag = False 722 | break 723 | if all_relation_flag: 724 | if query_structure[0] == 'e': 725 | if self.model_mode == 'baseline': 726 | embedding = self.entity_regularizer(torch.index_select(self.entity_embedding, dim=0, index=queries[:, idx])) 727 | elif self.model_mode == 'temp': 728 | embedding = torch.index_select(self.entity_embedding, dim=0, index=queries[:, idx]) 729 | ent_type_id = torch.index_select(self.entity2type, dim=0, index=queries[:, idx]) 730 | entity_neighbor_type_embedding = torch.index_select(self.type_embedding, dim=0, index=ent_type_id.view(-1)).view(ent_type_id.shape[0], ent_type_id.shape[1], -1) 731 | embedding = self.ent_neighbor_type_agg(embedding, entity_neighbor_type_embedding) 732 | 733 | idx += 1 734 | else: 735 | alpha_embedding, beta_embedding, idx = self.embed_query_beta(queries, query_structure[0], idx) 736 | embedding = torch.cat([alpha_embedding, beta_embedding], dim=-1) 737 | for i in range(len(query_structure[-1])): 738 | if query_structure[-1][i] == 'n': 739 | assert (queries[:, idx] == -2).all() 740 | if self.model_mode == 'temp': 741 | embedding = self.entity_regularizer(embedding) 742 | embedding = 1./embedding 743 | else: 744 | if self.model_mode == 'baseline': 745 | r_embedding = torch.index_select(self.relation_embedding, dim=0, index=queries[:, idx]) 746 | embedding = self.projection_net(embedding, r_embedding) 747 | elif self.model_mode == 'temp': 748 | r_embedding = torch.index_select(self.relation_embedding, dim=0, index=queries[:, idx]) 749 | rel_type_id = torch.index_select(self.relation2type, dim=0, index=queries[:, idx]) 750 | relation_neighbor_type_embedding = torch.index_select(self.type_embedding, dim=0, index=rel_type_id.view(-1)).view(rel_type_id.shape[0], rel_type_id.shape[1], -1) 751 | relation_neighbor_type_embedding = torch.transpose(relation_neighbor_type_embedding, 0, 1) 752 | alpha_relation_neighbor_type_embedding, beta_relation_neighbor_type_embedding = torch.chunk(relation_neighbor_type_embedding, 2, dim=-1) 753 | alpha_embedding, beta_embedding = torch.chunk(embedding, 2, dim=-1) 754 | alpha_r_embedding, beta_r_embedding = torch.chunk(r_embedding, 2, dim=-1) 755 | alpha_rel_type_cent = self.relation_center_net(alpha_relation_neighbor_type_embedding) 756 | beta_rel_type_cent = self.relation_center_net(beta_relation_neighbor_type_embedding) 757 | 758 | alpha_embedding = alpha_embedding.unsqueeze(1) 759 | alpha_r_embedding = alpha_r_embedding.unsqueeze(1) 760 | alpha_rel_type_cent = alpha_rel_type_cent.unsqueeze(1) 761 | alpha_ent_rel = self.match([alpha_embedding, alpha_r_embedding]) 762 | alpha_ent_rel_type = self.match([alpha_embedding, alpha_rel_type_cent]) 763 | alpha_rel_ent = self.match([alpha_r_embedding, alpha_embedding]) 764 | alpha_rel_rel_type = self.match([alpha_r_embedding, alpha_rel_type_cent]) 765 | 766 | alpha_embedding = self.rel_neighbor_type_agg([alpha_ent_rel.squeeze(1), alpha_ent_rel_type.squeeze(1)]) 767 | alpha_r_embedding = self.rel_neighbor_type_agg([alpha_rel_ent.squeeze(1), alpha_rel_rel_type.squeeze(1)]) 768 | 769 | beta_embedding = beta_embedding.unsqueeze(1) 770 | beta_r_embedding = beta_r_embedding.unsqueeze(1) 771 | beta_rel_type_cent = beta_rel_type_cent.unsqueeze(1) 772 | beta_ent_rel = self.match([beta_embedding, beta_r_embedding]) 773 | beta_ent_rel_type = self.match([beta_embedding, beta_rel_type_cent]) 774 | beta_rel_ent = self.match([beta_r_embedding, beta_embedding]) 775 | beta_rel_rel_type = self.match([beta_r_embedding, beta_rel_type_cent]) 776 | 777 | beta_embedding = self.rel_neighbor_type_agg([beta_ent_rel.squeeze(1), beta_ent_rel_type.squeeze(1)]) 778 | beta_r_embedding = self.rel_neighbor_type_agg([beta_rel_ent.squeeze(1), beta_rel_rel_type.squeeze(1)]) 779 | 780 | embedding = torch.cat([alpha_embedding, beta_embedding], dim=-1) 781 | r_embedding = torch.cat([alpha_r_embedding, beta_r_embedding], dim=-1) 782 | embedding = self.projection_without_net(embedding, r_embedding) 783 | idx += 1 784 | if self.model_mode == 'temp': 785 | if filter_flag == False: 786 | embedding = self.entity_regularizer(embedding) 787 | alpha_embedding, beta_embedding = torch.chunk(embedding, 2, dim=-1) 788 | else: 789 | alpha_embedding_list = [] 790 | beta_embedding_list = [] 791 | for i in range(len(query_structure)): 792 | alpha_embedding, beta_embedding, idx = self.embed_query_beta(queries, query_structure[i], idx) 793 | alpha_embedding_list.append(alpha_embedding) 794 | beta_embedding_list.append(beta_embedding) 795 | alpha_embedding, beta_embedding = self.center_net(torch.stack(alpha_embedding_list), torch.stack(beta_embedding_list)) 796 | 797 | return alpha_embedding, beta_embedding, idx 798 | 799 | def embed_query_cone(self, queries, query_structure, idx, filter_flag=False): 800 | all_relation_flag = True 801 | if self.model_mode == 'temp': 802 | if filter_flag == False: 803 | if query_structure in query_name_dict and query_name_dict[query_structure] == 'ip': 804 | filter_flag = True 805 | for ele in query_structure[-1]: 806 | if ele not in ['r', 'n']: 807 | all_relation_flag = False 808 | break 809 | if all_relation_flag: 810 | if query_structure[0] == 'e': 811 | if self.model_mode == 'baseline': 812 | axis_entity_embedding = torch.index_select(self.entity_embedding, dim=0, index=queries[:, idx]) 813 | axis_entity_embedding = self.angle_scale(axis_entity_embedding, self.axis_scale) 814 | axis_entity_embedding = convert_to_axis(axis_entity_embedding) 815 | elif self.model_mode == 'temp': 816 | axis_entity_embedding = torch.index_select(self.entity_embedding, dim=0, index=queries[:, idx]) 817 | ent_type_id = torch.index_select(self.entity2type, dim=0, index=queries[:, idx]) 818 | entity_neighbor_type_embedding = torch.index_select(self.type_embedding, dim=0, index=ent_type_id.view(-1)).view(ent_type_id.shape[0], ent_type_id.shape[1], -1) 819 | axis_entity_embedding = self.ent_neighbor_type_agg(axis_entity_embedding, entity_neighbor_type_embedding) 820 | axis_entity_embedding = self.angle_scale(axis_entity_embedding, self.axis_scale) 821 | 822 | if self.use_cuda: 823 | arg_entity_embedding = torch.zeros_like(axis_entity_embedding).cuda() 824 | else: 825 | arg_entity_embedding = torch.zeros_like(axis_entity_embedding) 826 | idx += 1 827 | 828 | axis_embedding = axis_entity_embedding 829 | arg_embedding = arg_entity_embedding 830 | else: 831 | axis_embedding, arg_embedding, idx = self.embed_query_cone(queries, query_structure[0], idx, filter_flag) 832 | 833 | for i in range(len(query_structure[-1])): 834 | # negation 835 | if query_structure[-1][i] == 'n': 836 | assert (queries[:, idx] == -2).all() 837 | if self.model_mode == 'temp': 838 | axis_embedding = self.angle_scale(axis_embedding, self.axis_scale) 839 | axis_embedding = convert_to_axis(axis_embedding) 840 | axis_embedding, arg_embedding = self.cone_negation(axis_embedding, arg_embedding) 841 | 842 | # projection 843 | else: 844 | if self.model_mode == 'baseline': 845 | axis_r_embedding = torch.index_select(self.axis_embedding, dim=0, index=queries[:, idx]) 846 | arg_r_embedding = torch.index_select(self.arg_embedding, dim=0, index=queries[:, idx]) 847 | 848 | axis_r_embedding = self.angle_scale(axis_r_embedding, self.axis_scale) 849 | arg_r_embedding = self.angle_scale(arg_r_embedding, self.arg_scale) 850 | 851 | axis_r_embedding = convert_to_axis(axis_r_embedding) 852 | arg_r_embedding = convert_to_axis(arg_r_embedding) 853 | 854 | axis_embedding, arg_embedding = self.cone_proj(axis_embedding, arg_embedding, axis_r_embedding, arg_r_embedding) 855 | elif self.model_mode == 'temp': 856 | axis_r_embedding = torch.index_select(self.axis_embedding, dim=0, index=queries[:, idx]) 857 | axis_rel_type_id = torch.index_select(self.relation2type, dim=0, index=queries[:, idx]) 858 | axis_relation_neighbor_type_embedding = torch.index_select(self.type_embedding, dim=0, index=axis_rel_type_id.view(-1)).view(axis_rel_type_id.shape[0], axis_rel_type_id.shape[1], -1) 859 | axis_relation_neighbor_type_embedding = torch.transpose(axis_relation_neighbor_type_embedding, 0, 1) 860 | axis_rel_type_cent = self.relation_center_net(axis_relation_neighbor_type_embedding) 861 | axis_r_embedding = self.angle_scale(axis_r_embedding, self.axis_scale) 862 | axis_rel_type_cent = self.angle_scale(axis_rel_type_cent, self.axis_scale) 863 | 864 | arg_r_embedding = torch.index_select(self.arg_embedding, dim=0, index=queries[:, idx]) 865 | arg_rel_type_id = torch.index_select(self.relation2type, dim=0, index=queries[:, idx]) 866 | arg_relation_neighbor_type_embedding = torch.index_select(self.type_embedding, dim=0, index=arg_rel_type_id.view(-1)).view(arg_rel_type_id.shape[0], arg_rel_type_id.shape[1], -1) 867 | arg_relation_neighbor_type_embedding = torch.transpose(arg_relation_neighbor_type_embedding, 0, 1) 868 | arg_rel_type_cent = self.relation_center_net(arg_relation_neighbor_type_embedding) 869 | arg_r_embedding = self.angle_scale(arg_r_embedding, self.axis_scale) 870 | arg_rel_type_cent = self.angle_scale(arg_rel_type_cent, self.axis_scale) 871 | 872 | axis_embedding = axis_embedding.unsqueeze(1) 873 | axis_r_embedding = axis_r_embedding.unsqueeze(1) 874 | axis_rel_type_cent = axis_rel_type_cent.unsqueeze(1) 875 | axis_ent_rel = self.match([axis_embedding, axis_r_embedding]) 876 | axis_ent_rel_type = self.match([axis_embedding, axis_rel_type_cent]) 877 | axis_rel_ent = self.match([axis_r_embedding, axis_embedding]) 878 | axis_rel_rel_type = self.match([axis_r_embedding, axis_rel_type_cent]) 879 | 880 | axis_embedding = self.rel_neighbor_type_agg([axis_ent_rel.squeeze(1), axis_ent_rel_type.squeeze(1)]) 881 | axis_r_embedding = self.rel_neighbor_type_agg([axis_rel_ent.squeeze(1), axis_rel_rel_type.squeeze(1)]) 882 | 883 | arg_embedding = arg_embedding.unsqueeze(1) 884 | arg_r_embedding = arg_r_embedding.unsqueeze(1) 885 | arg_rel_type_cent = arg_rel_type_cent.unsqueeze(1) 886 | arg_ent_rel = self.match([arg_embedding, arg_r_embedding]) 887 | arg_ent_rel_type = self.match([arg_embedding, arg_rel_type_cent]) 888 | arg_rel_ent = self.match([arg_r_embedding, arg_embedding]) 889 | arg_rel_rel_type = self.match([arg_r_embedding, arg_rel_type_cent]) 890 | 891 | arg_embedding = self.rel_neighbor_type_agg([arg_ent_rel.squeeze(1), arg_ent_rel_type.squeeze(1)]) 892 | arg_r_embedding = self.rel_neighbor_type_agg([arg_rel_ent.squeeze(1), arg_rel_rel_type.squeeze(1)]) 893 | 894 | axis_embedding, arg_embedding = self.cone_without_proj(axis_embedding, arg_embedding, axis_r_embedding, arg_r_embedding) 895 | 896 | idx += 1 897 | if self.model_mode == 'temp': 898 | if filter_flag == False: 899 | axis_embedding = convert_to_axis(axis_embedding) 900 | arg_embedding = convert_to_axis(arg_embedding) 901 | else: 902 | # intersection 903 | axis_embedding_list = [] 904 | arg_embedding_list = [] 905 | for i in range(len(query_structure)): 906 | axis_embedding, arg_embedding, idx = self.embed_query_cone(queries, query_structure[i], idx, filter_flag) 907 | axis_embedding_list.append(axis_embedding) 908 | arg_embedding_list.append(arg_embedding) 909 | 910 | stacked_axis_embeddings = torch.stack(axis_embedding_list) 911 | stacked_arg_embeddings = torch.stack(arg_embedding_list) 912 | 913 | axis_embedding, arg_embedding = self.cone_intersection(stacked_axis_embeddings, stacked_arg_embeddings) 914 | 915 | return axis_embedding, arg_embedding, idx 916 | 917 | def embed_query_logic(self, queries, query_structure, idx, filter_flag=False): 918 | ''' 919 | Iterative embed a batch of queries with same structure using logic embeddings 920 | queries: a flattened batch of queries 921 | ''' 922 | all_relation_flag = True 923 | if filter_flag == False: 924 | if query_structure in query_name_dict and query_name_dict[query_structure] == 'ip': 925 | filter_flag = True 926 | for ele in query_structure[-1]: # whether the current query tree has merged to one branch and only need to do relation traversal, e.g., path queries or conjunctive queries after the intersection 927 | if ele not in ['r', 'n']: 928 | all_relation_flag = False 929 | break 930 | if all_relation_flag: 931 | if query_structure[0] == 'e': 932 | if self.model_mode == 'baseline': 933 | embedding = torch.index_select(self.entity_embedding, dim=0, index=queries[:, idx]) 934 | elif self.model_mode == 'temp': 935 | embedding = torch.index_select(self.entity_embedding, dim=0, index=queries[:, idx]) 936 | ent_type_id = torch.index_select(self.entity2type, dim=0, index=queries[:, idx]) 937 | entity_neighbor_type_embedding = torch.index_select(self.type_embedding, dim=0, index=ent_type_id.view(-1)).view(ent_type_id.shape[0], ent_type_id.shape[1], -1) 938 | embedding = self.ent_neighbor_type_agg(embedding, entity_neighbor_type_embedding) 939 | 940 | idx += 1 941 | else: 942 | embedding, idx = self.embed_query_logic(queries, query_structure[0], idx, filter_flag) 943 | for i in range(len(query_structure[-1])): 944 | if query_structure[-1][i] == 'n': 945 | assert (queries[:, idx] == -2).all() 946 | if self.bounded: 947 | lower_embedding, upper_embedding = torch.chunk(embedding, 2, dim=-1) 948 | embedding = torch.cat([1 - upper_embedding, 1 - lower_embedding], dim=-1) 949 | else: 950 | embedding = 1 - embedding 951 | else: 952 | if self.model_mode == 'baseline': 953 | r_embedding = torch.index_select(self.relation_embedding, dim=0, index=queries[:, idx]) 954 | embedding = self.projection_net(embedding, r_embedding) 955 | elif self.model_mode == 'temp': 956 | r_embedding = torch.index_select(self.relation_embedding, dim=0, index=queries[:, idx]) 957 | rel_type_id = torch.index_select(self.relation2type, dim=0, index=queries[:, idx]) 958 | relation_neighbor_type_embedding = torch.index_select(self.type_embedding, dim=0, index=rel_type_id.view(-1)).view(rel_type_id.shape[0], rel_type_id.shape[1], -1) 959 | relation_neighbor_type_embedding = torch.transpose(relation_neighbor_type_embedding, 0, 1) 960 | rel_type_cent = self.relation_center_net(relation_neighbor_type_embedding) 961 | 962 | embedding = embedding.unsqueeze(1) 963 | r_embedding = r_embedding.unsqueeze(1) 964 | rel_type_cent = rel_type_cent.unsqueeze(1) 965 | ent_rel = self.match([embedding, r_embedding]) 966 | ent_rel_type = self.match([embedding, rel_type_cent]) 967 | rel_ent = self.match([r_embedding, embedding]) 968 | rel_rel_type = self.match([r_embedding, rel_type_cent]) 969 | 970 | embedding = self.rel_neighbor_type_agg([ent_rel.squeeze(1), ent_rel_type.squeeze(1)]) 971 | r_embedding = self.rel_neighbor_type_agg([rel_ent.squeeze(1), rel_rel_type.squeeze(1)]) 972 | embedding = self.projection_net(embedding, r_embedding) 973 | 974 | idx += 1 975 | else: 976 | embedding_list = [] 977 | for i in range(len(query_structure)): 978 | embedding, idx = self.embed_query_logic(queries, query_structure[i], idx, filter_flag) 979 | embedding_list.append(embedding) 980 | embedding = self.center_net(torch.stack(embedding_list)) 981 | 982 | return embedding, idx 983 | 984 | def transform_union_query(self, queries, query_structure): 985 | ''' 986 | transform 2u queries to two 1p queries 987 | transform up queries to two 2p queries 988 | ''' 989 | if self.query_name_dict[query_structure] == '2u-DNF': 990 | queries = queries[:, :-1] # remove union -1 991 | elif self.query_name_dict[query_structure] == 'up-DNF': 992 | queries = torch.cat([torch.cat([queries[:, :2], queries[:, 5:6]], dim=1), torch.cat([queries[:, 2:4], queries[:, 5:6]], dim=1)], dim=1) 993 | queries = torch.reshape(queries, [queries.shape[0]*2, -1]) 994 | return queries 995 | 996 | def transform_union_structure(self, query_structure): 997 | if self.query_name_dict[query_structure] == '2u-DNF': 998 | return ('e', ('r',)) 999 | elif self.query_name_dict[query_structure] == 'up-DNF': 1000 | return ('e', ('r', 'r')) 1001 | 1002 | def cal_logit_vec(self, entity_embedding, query_embedding): 1003 | distance = entity_embedding - query_embedding 1004 | logit = self.gamma - torch.norm(distance, p=1, dim=-1) 1005 | return logit 1006 | 1007 | def forward_vec(self, positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict): 1008 | all_center_embeddings, all_idxs = [], [] 1009 | all_union_center_embeddings, all_union_idxs = [], [] 1010 | for query_structure in batch_queries_dict: 1011 | if 'u' in self.query_name_dict[query_structure]: 1012 | center_embedding, _ = self.embed_query_vec(self.transform_union_query(batch_queries_dict[query_structure], 1013 | query_structure), 1014 | self.transform_union_structure(query_structure), 0) 1015 | all_union_center_embeddings.append(center_embedding) 1016 | all_union_idxs.extend(batch_idxs_dict[query_structure]) 1017 | else: 1018 | center_embedding, _ = self.embed_query_vec(batch_queries_dict[query_structure], query_structure, 0) 1019 | all_center_embeddings.append(center_embedding) 1020 | all_idxs.extend(batch_idxs_dict[query_structure]) 1021 | 1022 | if len(all_center_embeddings) > 0: 1023 | all_center_embeddings = torch.cat(all_center_embeddings, dim=0).unsqueeze(1) 1024 | if len(all_union_center_embeddings) > 0: 1025 | all_union_center_embeddings = torch.cat(all_union_center_embeddings, dim=0).unsqueeze(1) 1026 | all_union_center_embeddings = all_union_center_embeddings.view(all_union_center_embeddings.shape[0]//2, 2, 1, -1) 1027 | 1028 | if type(subsampling_weight) != type(None): 1029 | subsampling_weight = subsampling_weight[all_idxs+all_union_idxs] 1030 | 1031 | if type(positive_sample) != type(None): 1032 | if len(all_center_embeddings) > 0: 1033 | positive_sample_regular = positive_sample[all_idxs] 1034 | positive_embedding = torch.index_select(self.entity_embedding, dim=0, index=positive_sample_regular).unsqueeze(1) 1035 | positive_logit = self.cal_logit_vec(positive_embedding, all_center_embeddings) 1036 | else: 1037 | positive_logit = torch.Tensor([]).to(self.entity_embedding.device) 1038 | 1039 | if len(all_union_center_embeddings) > 0: 1040 | positive_sample_union = positive_sample[all_union_idxs] 1041 | positive_embedding = torch.index_select(self.entity_embedding, dim=0, index=positive_sample_union).unsqueeze(1).unsqueeze(1) 1042 | positive_union_logit = self.cal_logit_vec(positive_embedding, all_union_center_embeddings) 1043 | positive_union_logit = torch.max(positive_union_logit, dim=1)[0] 1044 | else: 1045 | positive_union_logit = torch.Tensor([]).to(self.entity_embedding.device) 1046 | positive_logit = torch.cat([positive_logit, positive_union_logit], dim=0) 1047 | else: 1048 | positive_logit = None 1049 | 1050 | if type(negative_sample) != type(None): 1051 | if len(all_center_embeddings) > 0: 1052 | negative_sample_regular = negative_sample[all_idxs] 1053 | batch_size, negative_size = negative_sample_regular.shape 1054 | negative_embedding = torch.index_select(self.entity_embedding, dim=0, index=negative_sample_regular.view(-1)).view(batch_size, negative_size, -1) 1055 | negative_logit = self.cal_logit_vec(negative_embedding, all_center_embeddings) 1056 | else: 1057 | negative_logit = torch.Tensor([]).to(self.entity_embedding.device) 1058 | 1059 | if len(all_union_center_embeddings) > 0: 1060 | negative_sample_union = negative_sample[all_union_idxs] 1061 | batch_size, negative_size = negative_sample_union.shape 1062 | negative_embedding = torch.index_select(self.entity_embedding, dim=0, index=negative_sample_union.view(-1)).view(batch_size, 1, negative_size, -1) 1063 | negative_union_logit = self.cal_logit_vec(negative_embedding, all_union_center_embeddings) 1064 | negative_union_logit = torch.max(negative_union_logit, dim=1)[0] 1065 | else: 1066 | negative_union_logit = torch.Tensor([]).to(self.entity_embedding.device) 1067 | negative_logit = torch.cat([negative_logit, negative_union_logit], dim=0) 1068 | else: 1069 | negative_logit = None 1070 | 1071 | return positive_logit, negative_logit, subsampling_weight, all_idxs+all_union_idxs, None 1072 | 1073 | def cal_logit_box(self, entity_embedding, query_center_embedding, query_offset_embedding): 1074 | delta = (entity_embedding - query_center_embedding).abs() 1075 | distance_out = F.relu(delta - query_offset_embedding) 1076 | distance_in = torch.min(delta, query_offset_embedding) 1077 | logit = self.gamma - torch.norm(distance_out, p=1, dim=-1) - self.cen * torch.norm(distance_in, p=1, dim=-1) 1078 | return logit 1079 | 1080 | def forward_box(self, positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict): 1081 | all_center_embeddings, all_offset_embeddings, all_idxs = [], [], [] 1082 | all_union_center_embeddings, all_union_offset_embeddings, all_union_idxs = [], [], [] 1083 | for query_structure in batch_queries_dict: 1084 | if 'u' in self.query_name_dict[query_structure]: 1085 | center_embedding, offset_embedding, _ = \ 1086 | self.embed_query_box(self.transform_union_query(batch_queries_dict[query_structure], 1087 | query_structure), 1088 | self.transform_union_structure(query_structure), 1089 | 0) 1090 | all_union_center_embeddings.append(center_embedding) 1091 | all_union_offset_embeddings.append(offset_embedding) 1092 | all_union_idxs.extend(batch_idxs_dict[query_structure]) 1093 | else: 1094 | center_embedding, offset_embedding, _ = self.embed_query_box(batch_queries_dict[query_structure], 1095 | query_structure, 1096 | 0) 1097 | all_center_embeddings.append(center_embedding) 1098 | all_offset_embeddings.append(offset_embedding) 1099 | all_idxs.extend(batch_idxs_dict[query_structure]) 1100 | 1101 | if len(all_center_embeddings) > 0 and len(all_offset_embeddings) > 0: 1102 | all_center_embeddings = torch.cat(all_center_embeddings, dim=0).unsqueeze(1) 1103 | all_offset_embeddings = torch.cat(all_offset_embeddings, dim=0).unsqueeze(1) 1104 | if len(all_union_center_embeddings) > 0 and len(all_union_offset_embeddings) > 0: 1105 | all_union_center_embeddings = torch.cat(all_union_center_embeddings, dim=0).unsqueeze(1) 1106 | all_union_offset_embeddings = torch.cat(all_union_offset_embeddings, dim=0).unsqueeze(1) 1107 | all_union_center_embeddings = all_union_center_embeddings.view(all_union_center_embeddings.shape[0]//2, 2, 1, -1) 1108 | all_union_offset_embeddings = all_union_offset_embeddings.view(all_union_offset_embeddings.shape[0]//2, 2, 1, -1) 1109 | 1110 | if type(subsampling_weight) != type(None): 1111 | subsampling_weight = subsampling_weight[all_idxs+all_union_idxs] 1112 | 1113 | if type(positive_sample) != type(None): 1114 | if len(all_center_embeddings) > 0: 1115 | positive_sample_regular = positive_sample[all_idxs] 1116 | positive_embedding = torch.index_select(self.entity_embedding, dim=0, index=positive_sample_regular).unsqueeze(1) 1117 | positive_logit = self.cal_logit_box(positive_embedding, all_center_embeddings, all_offset_embeddings) 1118 | else: 1119 | positive_logit = torch.Tensor([]).to(self.entity_embedding.device) 1120 | 1121 | if len(all_union_center_embeddings) > 0: 1122 | positive_sample_union = positive_sample[all_union_idxs] 1123 | positive_embedding = torch.index_select(self.entity_embedding, dim=0, index=positive_sample_union).unsqueeze(1).unsqueeze(1) 1124 | positive_union_logit = self.cal_logit_box(positive_embedding, all_union_center_embeddings, all_union_offset_embeddings) 1125 | positive_union_logit = torch.max(positive_union_logit, dim=1)[0] 1126 | else: 1127 | positive_union_logit = torch.Tensor([]).to(self.entity_embedding.device) 1128 | positive_logit = torch.cat([positive_logit, positive_union_logit], dim=0) 1129 | else: 1130 | positive_logit = None 1131 | 1132 | if type(negative_sample) != type(None): 1133 | if len(all_center_embeddings) > 0: 1134 | negative_sample_regular = negative_sample[all_idxs] 1135 | batch_size, negative_size = negative_sample_regular.shape 1136 | negative_embedding = torch.index_select(self.entity_embedding, dim=0, index=negative_sample_regular.view(-1)).view(batch_size, negative_size, -1) 1137 | negative_logit = self.cal_logit_box(negative_embedding, all_center_embeddings, all_offset_embeddings) 1138 | else: 1139 | negative_logit = torch.Tensor([]).to(self.entity_embedding.device) 1140 | 1141 | if len(all_union_center_embeddings) > 0: 1142 | negative_sample_union = negative_sample[all_union_idxs] 1143 | batch_size, negative_size = negative_sample_union.shape 1144 | negative_embedding = torch.index_select(self.entity_embedding, dim=0, index=negative_sample_union.view(-1)).view(batch_size, 1, negative_size, -1) 1145 | negative_union_logit = self.cal_logit_box(negative_embedding, all_union_center_embeddings, all_union_offset_embeddings) 1146 | negative_union_logit = torch.max(negative_union_logit, dim=1)[0] 1147 | else: 1148 | negative_union_logit = torch.Tensor([]).to(self.entity_embedding.device) 1149 | negative_logit = torch.cat([negative_logit, negative_union_logit], dim=0) 1150 | else: 1151 | negative_logit = None 1152 | 1153 | return positive_logit, negative_logit, subsampling_weight, all_idxs+all_union_idxs, None 1154 | 1155 | def cal_logit_beta(self, entity_embedding, query_dist): 1156 | alpha_embedding, beta_embedding = torch.chunk(entity_embedding, 2, dim=-1) 1157 | entity_dist = torch.distributions.beta.Beta(alpha_embedding, beta_embedding) 1158 | logit = self.gamma - torch.norm(torch.distributions.kl.kl_divergence(entity_dist, query_dist), p=1, dim=-1) 1159 | return logit 1160 | 1161 | def forward_beta(self, positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict): 1162 | all_idxs, all_alpha_embeddings, all_beta_embeddings = [], [], [] 1163 | all_union_idxs, all_union_alpha_embeddings, all_union_beta_embeddings = [], [], [] 1164 | for query_structure in batch_queries_dict: 1165 | if 'u' in self.query_name_dict[query_structure] and 'DNF' in self.query_name_dict[query_structure]: 1166 | alpha_embedding, beta_embedding, _ = \ 1167 | self.embed_query_beta(self.transform_union_query(batch_queries_dict[query_structure], 1168 | query_structure), 1169 | self.transform_union_structure(query_structure), 1170 | 0) 1171 | all_union_idxs.extend(batch_idxs_dict[query_structure]) 1172 | all_union_alpha_embeddings.append(alpha_embedding) 1173 | all_union_beta_embeddings.append(beta_embedding) 1174 | else: 1175 | alpha_embedding, beta_embedding, _ = self.embed_query_beta(batch_queries_dict[query_structure], 1176 | query_structure, 1177 | 0) 1178 | all_idxs.extend(batch_idxs_dict[query_structure]) 1179 | all_alpha_embeddings.append(alpha_embedding) 1180 | all_beta_embeddings.append(beta_embedding) 1181 | 1182 | if len(all_alpha_embeddings) > 0: 1183 | all_alpha_embeddings = torch.cat(all_alpha_embeddings, dim=0).unsqueeze(1) 1184 | all_beta_embeddings = torch.cat(all_beta_embeddings, dim=0).unsqueeze(1) 1185 | all_dists = torch.distributions.beta.Beta(all_alpha_embeddings, all_beta_embeddings) 1186 | if len(all_union_alpha_embeddings) > 0: 1187 | all_union_alpha_embeddings = torch.cat(all_union_alpha_embeddings, dim=0).unsqueeze(1) 1188 | all_union_beta_embeddings = torch.cat(all_union_beta_embeddings, dim=0).unsqueeze(1) 1189 | all_union_alpha_embeddings = all_union_alpha_embeddings.view(all_union_alpha_embeddings.shape[0] // 2, 2, 1, 1190 | -1) 1191 | all_union_beta_embeddings = all_union_beta_embeddings.view(all_union_beta_embeddings.shape[0] // 2, 2, 1, 1192 | -1) 1193 | all_union_dists = torch.distributions.beta.Beta(all_union_alpha_embeddings, all_union_beta_embeddings) 1194 | 1195 | if type(subsampling_weight) != type(None): 1196 | subsampling_weight = subsampling_weight[all_idxs + all_union_idxs] 1197 | 1198 | if type(positive_sample) != type(None): 1199 | if len(all_alpha_embeddings) > 0: 1200 | positive_sample_regular = positive_sample[ 1201 | all_idxs] # positive samples for non-union queries in this batch 1202 | positive_embedding = self.entity_regularizer( 1203 | torch.index_select(self.entity_embedding, dim=0, index=positive_sample_regular).unsqueeze(1)) 1204 | positive_logit = self.cal_logit_beta(positive_embedding, all_dists) 1205 | else: 1206 | positive_logit = torch.Tensor([]).to(self.entity_embedding.device) 1207 | 1208 | if len(all_union_alpha_embeddings) > 0: 1209 | positive_sample_union = positive_sample[ 1210 | all_union_idxs] # positive samples for union queries in this batch 1211 | positive_embedding = self.entity_regularizer( 1212 | torch.index_select(self.entity_embedding, dim=0, index=positive_sample_union).unsqueeze( 1213 | 1).unsqueeze(1)) 1214 | positive_union_logit = self.cal_logit_beta(positive_embedding, all_union_dists) 1215 | positive_union_logit = torch.max(positive_union_logit, dim=1)[0] 1216 | else: 1217 | positive_union_logit = torch.Tensor([]).to(self.entity_embedding.device) 1218 | positive_logit = torch.cat([positive_logit, positive_union_logit], dim=0) 1219 | else: 1220 | positive_logit = None 1221 | 1222 | if type(negative_sample) != type(None): 1223 | if len(all_alpha_embeddings) > 0: 1224 | negative_sample_regular = negative_sample[all_idxs] 1225 | batch_size, negative_size = negative_sample_regular.shape 1226 | negative_embedding = self.entity_regularizer( 1227 | torch.index_select(self.entity_embedding, dim=0, index=negative_sample_regular.view(-1)).view( 1228 | batch_size, negative_size, -1)) 1229 | negative_logit = self.cal_logit_beta(negative_embedding, all_dists) 1230 | else: 1231 | negative_logit = torch.Tensor([]).to(self.entity_embedding.device) 1232 | 1233 | if len(all_union_alpha_embeddings) > 0: 1234 | negative_sample_union = negative_sample[all_union_idxs] 1235 | batch_size, negative_size = negative_sample_union.shape 1236 | negative_embedding = self.entity_regularizer( 1237 | torch.index_select(self.entity_embedding, dim=0, index=negative_sample_union.view(-1)).view( 1238 | batch_size, 1, negative_size, -1)) 1239 | negative_union_logit = self.cal_logit_beta(negative_embedding, all_union_dists) 1240 | negative_union_logit = torch.max(negative_union_logit, dim=1)[0] 1241 | else: 1242 | negative_union_logit = torch.Tensor([]).to(self.entity_embedding.device) 1243 | negative_logit = torch.cat([negative_logit, negative_union_logit], dim=0) 1244 | else: 1245 | negative_logit = None 1246 | 1247 | return positive_logit, negative_logit, subsampling_weight, all_idxs + all_union_idxs, None 1248 | 1249 | def cal_logit_cone(self, entity_embedding, query_axis_embedding, query_arg_embedding): 1250 | delta1 = entity_embedding - (query_axis_embedding - query_arg_embedding) 1251 | delta2 = entity_embedding - (query_axis_embedding + query_arg_embedding) 1252 | 1253 | distance2axis = torch.abs(torch.sin((entity_embedding - query_axis_embedding) / 2)) 1254 | distance_base = torch.abs(torch.sin(query_arg_embedding / 2)) 1255 | 1256 | indicator_in = distance2axis < distance_base 1257 | distance_out = torch.min(torch.abs(torch.sin(delta1 / 2)), torch.abs(torch.sin(delta2 / 2))) 1258 | distance_out[indicator_in] = 0. 1259 | 1260 | distance_in = torch.min(distance2axis, distance_base) 1261 | 1262 | distance = torch.norm(distance_out, p=1, dim=-1) + self.cen * torch.norm(distance_in, p=1, dim=-1) 1263 | logit = self.gamma - distance * self.modulus 1264 | 1265 | return logit 1266 | 1267 | def forward_cone(self, positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict): 1268 | all_idxs, all_axis_embeddings, all_arg_embeddings = [], [], [] 1269 | all_union_idxs, all_union_axis_embeddings, all_union_arg_embeddings = [], [], [] 1270 | for query_structure in batch_queries_dict: 1271 | if 'u' in self.query_name_dict[query_structure] and 'DNF' in self.query_name_dict[query_structure]: 1272 | axis_embedding, arg_embedding, _ = \ 1273 | self.embed_query_cone(self.transform_union_query(batch_queries_dict[query_structure], query_structure), self.transform_union_structure(query_structure), 0) 1274 | all_union_idxs.extend(batch_idxs_dict[query_structure]) 1275 | all_union_axis_embeddings.append(axis_embedding) 1276 | all_union_arg_embeddings.append(arg_embedding) 1277 | else: 1278 | axis_embedding, arg_embedding, _ = self.embed_query_cone(batch_queries_dict[query_structure], query_structure, 0) 1279 | all_idxs.extend(batch_idxs_dict[query_structure]) 1280 | all_axis_embeddings.append(axis_embedding) 1281 | all_arg_embeddings.append(arg_embedding) 1282 | 1283 | if len(all_axis_embeddings) > 0: 1284 | all_axis_embeddings = torch.cat(all_axis_embeddings, dim=0).unsqueeze(1) 1285 | all_arg_embeddings = torch.cat(all_arg_embeddings, dim=0).unsqueeze(1) 1286 | if len(all_union_axis_embeddings) > 0: 1287 | all_union_axis_embeddings = torch.cat(all_union_axis_embeddings, dim=0).unsqueeze(1) 1288 | all_union_arg_embeddings = torch.cat(all_union_arg_embeddings, dim=0).unsqueeze(1) 1289 | all_union_axis_embeddings = all_union_axis_embeddings.view( 1290 | all_union_axis_embeddings.shape[0] // 2, 2, 1, -1) 1291 | all_union_arg_embeddings = all_union_arg_embeddings.view( 1292 | all_union_arg_embeddings.shape[0] // 2, 2, 1, -1) 1293 | if type(subsampling_weight) != type(None): 1294 | subsampling_weight = subsampling_weight[all_idxs + all_union_idxs] 1295 | 1296 | if type(positive_sample) != type(None): 1297 | if len(all_axis_embeddings) > 0: 1298 | # positive samples for non-union queries in this batch 1299 | positive_sample_regular = positive_sample[all_idxs] 1300 | positive_embedding = torch.index_select(self.entity_embedding, dim=0, index=positive_sample_regular).unsqueeze(1) 1301 | 1302 | positive_embedding = self.angle_scale(positive_embedding, self.axis_scale) 1303 | positive_embedding = convert_to_axis(positive_embedding) 1304 | 1305 | positive_logit = self.cal_logit_cone(positive_embedding, all_axis_embeddings, all_arg_embeddings) 1306 | else: 1307 | positive_logit = torch.Tensor([]).to(self.entity_embedding.device) 1308 | 1309 | 1310 | if len(all_union_axis_embeddings) > 0: 1311 | # positive samples for union queries in this batch 1312 | positive_sample_union = positive_sample[all_union_idxs] 1313 | positive_embedding = torch.index_select(self.entity_embedding, dim=0, index=positive_sample_union).unsqueeze(1).unsqueeze(1) 1314 | 1315 | positive_embedding = self.angle_scale(positive_embedding, self.axis_scale) 1316 | positive_embedding = convert_to_axis(positive_embedding) 1317 | 1318 | positive_union_logit = self.cal_logit_cone(positive_embedding, all_union_axis_embeddings, all_union_arg_embeddings) 1319 | 1320 | positive_union_logit = torch.max(positive_union_logit, dim=1)[0] 1321 | else: 1322 | positive_union_logit = torch.Tensor([]).to(self.entity_embedding.device) 1323 | positive_logit = torch.cat([positive_logit, positive_union_logit], dim=0) 1324 | else: 1325 | positive_logit = None 1326 | 1327 | if type(negative_sample) != type(None): 1328 | if len(all_axis_embeddings) > 0: 1329 | negative_sample_regular = negative_sample[all_idxs] 1330 | batch_size, negative_size = negative_sample_regular.shape 1331 | negative_embedding = torch.index_select(self.entity_embedding, dim=0, index=negative_sample_regular.view(-1)).view(batch_size, negative_size, -1) 1332 | negative_embedding = self.angle_scale(negative_embedding, self.axis_scale) 1333 | negative_embedding = convert_to_axis(negative_embedding) 1334 | 1335 | negative_logit = self.cal_logit_cone(negative_embedding, all_axis_embeddings, all_arg_embeddings) 1336 | else: 1337 | negative_logit = torch.Tensor([]).to(self.entity_embedding.device) 1338 | 1339 | if len(all_union_axis_embeddings) > 0: 1340 | negative_sample_union = negative_sample[all_union_idxs] 1341 | batch_size, negative_size = negative_sample_union.shape 1342 | negative_embedding = torch.index_select(self.entity_embedding, dim=0, index=negative_sample_union.view(-1)).view(batch_size, 1, negative_size, -1) 1343 | negative_embedding = self.angle_scale(negative_embedding, self.axis_scale) 1344 | negative_embedding = convert_to_axis(negative_embedding) 1345 | 1346 | negative_union_logit = self.cal_logit_cone(negative_embedding, all_union_axis_embeddings, all_union_arg_embeddings) 1347 | negative_union_logit = torch.max(negative_union_logit, dim=1)[0] 1348 | else: 1349 | negative_union_logit = torch.Tensor([]).to(self.entity_embedding.device) 1350 | negative_logit = torch.cat([negative_logit, negative_union_logit], dim=0) 1351 | else: 1352 | negative_logit = None 1353 | 1354 | return positive_logit, negative_logit, subsampling_weight, all_idxs + all_union_idxs, None 1355 | 1356 | def cal_logit_logic(self, entity_embedding, query_embedding): 1357 | if self.bounded: 1358 | lower_embedding, upper_embedding = torch.chunk(entity_embedding, 2, dim=-1) 1359 | query_lower_embedding, query_upper_embedding = torch.chunk(query_embedding, 2, dim=-1) 1360 | 1361 | lower_dist = torch.norm(lower_embedding - query_lower_embedding, p=1, dim=-1) 1362 | upper_dist = torch.norm(query_upper_embedding - upper_embedding, p=1, dim=-1) 1363 | 1364 | logit = self.gamma - (lower_dist + upper_dist) / 2 / lower_embedding.shape[-1] 1365 | else: 1366 | logit = self.gamma - torch.norm(entity_embedding - query_embedding, p=1, dim=-1) / query_embedding.shape[-1] 1367 | 1368 | logit *= 100 1369 | 1370 | return logit 1371 | 1372 | def forward_logic(self, positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict): 1373 | all_entropy = None 1374 | all_idxs, all_embeddings = [], [] 1375 | all_union_idxs, all_union_embeddings = [], [] 1376 | for query_structure in batch_queries_dict: 1377 | if 'u' in self.query_name_dict[query_structure] and 'DNF' in self.query_name_dict[query_structure]: 1378 | embedding, _ = \ 1379 | self.embed_query_logic(self.transform_union_query(batch_queries_dict[query_structure], 1380 | query_structure), 1381 | self.transform_union_structure(query_structure), 1382 | 0) 1383 | all_union_idxs.extend(batch_idxs_dict[query_structure]) 1384 | all_union_embeddings.append(embedding) 1385 | else: 1386 | embedding, _ = self.embed_query_logic(batch_queries_dict[query_structure], 1387 | query_structure, 1388 | 0) 1389 | all_idxs.extend(batch_idxs_dict[query_structure]) 1390 | all_embeddings.append(embedding) 1391 | 1392 | if len(all_embeddings) > 0: 1393 | all_embeddings = torch.cat(all_embeddings, dim=0).unsqueeze(1) 1394 | 1395 | if positive_sample is None and self.bounded: # test step - measure entropy 1396 | lower, upper = torch.chunk(all_embeddings, 2, dim=-1) 1397 | truth_interval = upper - lower 1398 | distribution = torch.distributions.uniform.Uniform(lower, upper + eps) 1399 | all_entropy = (distribution.entropy(), truth_interval) 1400 | 1401 | if len(all_union_embeddings) > 0: 1402 | all_union_embeddings = torch.cat(all_union_embeddings, dim=0).unsqueeze(1) 1403 | all_union_embeddings = all_union_embeddings.view(all_union_embeddings.shape[0] // 2, 2, 1, -1) 1404 | 1405 | if type(subsampling_weight) != type(None): 1406 | subsampling_weight = subsampling_weight[all_idxs + all_union_idxs] 1407 | 1408 | if type(positive_sample) != type(None): 1409 | if len(all_embeddings) > 0: 1410 | positive_sample_regular = positive_sample[all_idxs] # positive samples for non-union queries in this batch 1411 | positive_embedding = torch.index_select(self.entity_embedding, dim=0, index=positive_sample_regular).unsqueeze(1) 1412 | positive_logit = self.cal_logit_logic(positive_embedding, all_embeddings) 1413 | else: 1414 | positive_logit = torch.Tensor([]).to(self.entity_embedding.device) 1415 | 1416 | if len(all_union_embeddings) > 0: 1417 | positive_sample_union = positive_sample[all_union_idxs] # positive samples for union queries in this batch 1418 | positive_embedding = torch.index_select(self.entity_embedding, dim=0, index=positive_sample_union).unsqueeze(1).unsqueeze(1) 1419 | positive_union_logit = self.cal_logit_logic(positive_embedding, all_union_embeddings) 1420 | positive_union_logit = torch.max(positive_union_logit, dim=1)[0] 1421 | else: 1422 | positive_union_logit = torch.Tensor([]).to(self.entity_embedding.device) 1423 | positive_logit = torch.cat([positive_logit, positive_union_logit], dim=0) 1424 | else: 1425 | positive_logit = None 1426 | 1427 | if type(negative_sample) != type(None): 1428 | if len(all_embeddings) > 0: 1429 | negative_sample_regular = negative_sample[all_idxs] 1430 | batch_size, negative_size = negative_sample_regular.shape 1431 | negative_embedding = torch.index_select(self.entity_embedding, dim=0, 1432 | index=negative_sample_regular.view(-1)).view(batch_size, 1433 | negative_size, -1) 1434 | negative_logit = self.cal_logit_logic(negative_embedding, all_embeddings) 1435 | else: 1436 | negative_logit = torch.Tensor([]).to(self.entity_embedding.device) 1437 | 1438 | if len(all_union_embeddings) > 0: 1439 | negative_sample_union = negative_sample[all_union_idxs] 1440 | batch_size, negative_size = negative_sample_union.shape 1441 | negative_embedding = torch.index_select(self.entity_embedding, dim=0, 1442 | index=negative_sample_union.view(-1)).view(batch_size, 1, 1443 | negative_size, -1) 1444 | negative_union_logit = self.cal_logit_logic(negative_embedding, all_union_embeddings) 1445 | negative_union_logit = torch.max(negative_union_logit, dim=1)[0] 1446 | else: 1447 | negative_union_logit = torch.Tensor([]).to(self.entity_embedding.device) 1448 | negative_logit = torch.cat([negative_logit, negative_union_logit], dim=0) 1449 | else: 1450 | negative_logit = None 1451 | 1452 | return positive_logit, negative_logit, subsampling_weight, all_idxs + all_union_idxs, all_entropy 1453 | 1454 | @staticmethod 1455 | def train_step(model, optimizer, train_iterator, args, step): 1456 | model.train() 1457 | optimizer.zero_grad() 1458 | 1459 | positive_sample, negative_sample, subsampling_weight, batch_queries, query_structures = next(train_iterator) 1460 | batch_queries_dict = collections.defaultdict(list) 1461 | batch_idxs_dict = collections.defaultdict(list) 1462 | for i, query in enumerate(batch_queries): # group queries with same structure 1463 | batch_queries_dict[query_structures[i]].append(query) 1464 | batch_idxs_dict[query_structures[i]].append(i) 1465 | for query_structure in batch_queries_dict: 1466 | if args.cuda: 1467 | batch_queries_dict[query_structure] = torch.LongTensor(batch_queries_dict[query_structure]).cuda() 1468 | else: 1469 | batch_queries_dict[query_structure] = torch.LongTensor(batch_queries_dict[query_structure]) 1470 | if args.cuda: 1471 | positive_sample = positive_sample.cuda() 1472 | negative_sample = negative_sample.cuda() 1473 | subsampling_weight = subsampling_weight.cuda() 1474 | 1475 | positive_logit, negative_logit, subsampling_weight, _, _ = model(positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict) 1476 | 1477 | negative_score = F.logsigmoid(-negative_logit).mean(dim=1) 1478 | positive_score = F.logsigmoid(positive_logit).squeeze(dim=1) 1479 | positive_sample_loss = - (subsampling_weight * positive_score).sum() 1480 | negative_sample_loss = - (subsampling_weight * negative_score).sum() 1481 | positive_sample_loss /= subsampling_weight.sum() 1482 | negative_sample_loss /= subsampling_weight.sum() 1483 | 1484 | loss = (positive_sample_loss + negative_sample_loss)/2 1485 | loss.backward() 1486 | optimizer.step() 1487 | log = { 1488 | 'positive_sample_loss': positive_sample_loss.item(), 1489 | 'negative_sample_loss': negative_sample_loss.item(), 1490 | 'loss': loss.item(), 1491 | } 1492 | return log 1493 | 1494 | @staticmethod 1495 | def test_step(model, easy_answers, hard_answers, args, test_dataloader, query_name_dict, save_result=False, save_str="", save_empty=False): 1496 | model.eval() 1497 | 1498 | step = 0 1499 | total_steps = len(test_dataloader) 1500 | logs = collections.defaultdict(list) 1501 | 1502 | with torch.no_grad(): 1503 | for negative_sample, queries, queries_unflatten, query_structures in tqdm(test_dataloader, disable=not args.print_on_screen): 1504 | batch_queries_dict = collections.defaultdict(list) 1505 | batch_idxs_dict = collections.defaultdict(list) 1506 | for i, query in enumerate(queries): 1507 | batch_queries_dict[query_structures[i]].append(query) 1508 | batch_idxs_dict[query_structures[i]].append(i) 1509 | for query_structure in batch_queries_dict: 1510 | if args.cuda: 1511 | batch_queries_dict[query_structure] = torch.LongTensor(batch_queries_dict[query_structure]).cuda() 1512 | else: 1513 | batch_queries_dict[query_structure] = torch.LongTensor(batch_queries_dict[query_structure]) 1514 | if args.cuda: 1515 | negative_sample = negative_sample.cuda() 1516 | 1517 | _, negative_logit, _, idxs, _ = model(None, negative_sample, None, batch_queries_dict, batch_idxs_dict) 1518 | queries_unflatten = [queries_unflatten[i] for i in idxs] 1519 | query_structures = [query_structures[i] for i in idxs] 1520 | argsort = torch.argsort(negative_logit, dim=1, descending=True) 1521 | ranking = argsort.clone().to(torch.float) 1522 | if len(argsort) == args.test_batch_size: # if it is the same shape with test_batch_size, we can reuse batch_entity_range without creating a new one 1523 | ranking = ranking.scatter_(1, argsort, model.batch_entity_range) # achieve the ranking of all entities 1524 | else: # otherwise, create a new torch Tensor for batch_entity_range 1525 | if args.cuda: 1526 | ranking = ranking.scatter_(1, 1527 | argsort, 1528 | torch.arange(model.nentity).to(torch.float).repeat(argsort.shape[0], 1529 | 1).cuda() 1530 | ) # achieve the ranking of all entities 1531 | else: 1532 | ranking = ranking.scatter_(1, 1533 | argsort, 1534 | torch.arange(model.nentity).to(torch.float).repeat(argsort.shape[0], 1535 | 1) 1536 | ) # achieve the ranking of all entities 1537 | for idx, (i, query, query_structure) in enumerate(zip(argsort[:, 0], queries_unflatten, query_structures)): 1538 | hard_answer = hard_answers[query] 1539 | easy_answer = easy_answers[query] 1540 | num_hard = len(hard_answer) 1541 | num_easy = len(easy_answer) 1542 | assert len(hard_answer.intersection(easy_answer)) == 0 1543 | cur_ranking = ranking[idx, list(easy_answer) + list(hard_answer)] 1544 | cur_ranking, indices = torch.sort(cur_ranking) 1545 | masks = indices >= num_easy 1546 | if args.cuda: 1547 | answer_list = torch.arange(num_hard + num_easy).to(torch.float).cuda() 1548 | else: 1549 | answer_list = torch.arange(num_hard + num_easy).to(torch.float) 1550 | cur_ranking = cur_ranking - answer_list + 1 # filtered setting 1551 | cur_ranking = cur_ranking[masks] # only take indices that belong to the hard answers 1552 | 1553 | mrr = torch.mean(1./cur_ranking).item() 1554 | h1 = torch.mean((cur_ranking <= 1).to(torch.float)).item() 1555 | h3 = torch.mean((cur_ranking <= 3).to(torch.float)).item() 1556 | h10 = torch.mean((cur_ranking <= 10).to(torch.float)).item() 1557 | 1558 | logs[query_structure].append({ 1559 | 'MRR': mrr, 1560 | 'HITS1': h1, 1561 | 'HITS3': h3, 1562 | 'HITS10': h10, 1563 | 'num_hard_answer': num_hard, 1564 | }) 1565 | 1566 | if step % args.test_log_steps == 0: 1567 | logging.info('Evaluating the model... (%d/%d)' % (step, total_steps)) 1568 | 1569 | step += 1 1570 | 1571 | metrics = collections.defaultdict(lambda: collections.defaultdict(int)) 1572 | for query_structure in logs: 1573 | for metric in logs[query_structure][0].keys(): 1574 | if metric in ['num_hard_answer']: 1575 | continue 1576 | metrics[query_structure][metric] = sum([log[metric] for log in logs[query_structure]])/len(logs[query_structure]) 1577 | metrics[query_structure]['num_queries'] = len(logs[query_structure]) 1578 | 1579 | return metrics --------------------------------------------------------------------------------