├── .gitignore ├── README.md ├── data ├── FB15K237 │ ├── FB15K237.pickle │ ├── README.txt │ ├── entities.dict │ ├── relations.dict │ ├── test.txt │ ├── train.txt │ └── valid.txt ├── NELL-995 │ ├── test.txt │ ├── train.txt │ └── valid.txt ├── WN18RR │ ├── test.txt │ ├── train.txt │ └── valid.txt ├── WN18RR_v1 │ ├── 1 │ │ ├── grail_neg_test_0_predictions.txt │ │ ├── grail_test_predictions.txt │ │ ├── neg_test_0.txt │ │ ├── relation2id.json │ │ ├── subgraphs_en_True_neg_1_hop_3 │ │ │ ├── data.mdb │ │ │ └── lock.mdb │ │ └── test_subgraphs_grail_wn_v1_0_en_True │ │ │ ├── data.mdb │ │ │ └── lock.mdb │ ├── 2 │ │ ├── relation2id.json │ │ └── subgraphs_en_True_neg_1_hop_3 │ │ │ ├── data.mdb │ │ │ └── lock.mdb │ ├── test.txt │ ├── train.txt │ └── valid.txt ├── WN18RR_v1_ind │ ├── 1 │ │ ├── grail_neg_test_0_predictions.txt │ │ ├── grail_ranking_head_predictions.txt │ │ ├── grail_ranking_tail_predictions.txt │ │ ├── grail_test_predictions.txt │ │ ├── neg_test_0.txt │ │ ├── ranking_head.txt │ │ ├── ranking_tail.txt │ │ └── test_subgraphs_grail_wn_v1_0_en_True │ │ │ ├── data.mdb │ │ │ └── lock.mdb │ ├── test.txt │ ├── train.txt │ └── valid.txt ├── WN18RR_v2 │ ├── test.txt │ ├── train.txt │ └── valid.txt ├── WN18RR_v2_ind │ ├── test.txt │ ├── train.txt │ └── valid.txt ├── WN18RR_v3 │ ├── test.txt │ ├── train.txt │ └── valid.txt ├── WN18RR_v3_ind │ ├── test.txt │ ├── train.txt │ └── valid.txt ├── WN18RR_v4 │ ├── test.txt │ ├── train.txt │ └── valid.txt ├── WN18RR_v4_ind │ ├── test.txt │ ├── train.txt │ └── valid.txt ├── fb237_v1 │ ├── test.txt │ ├── train.txt │ └── valid.txt ├── fb237_v1_ind │ ├── test.txt │ ├── train.txt │ └── valid.txt ├── fb237_v2 │ ├── test.txt │ ├── train.txt │ └── valid.txt ├── fb237_v2_ind │ ├── test.txt │ ├── train.txt │ └── valid.txt ├── fb237_v3 │ ├── test.txt │ ├── train.txt │ └── valid.txt ├── fb237_v3_ind │ ├── test.txt │ ├── train.txt │ └── valid.txt ├── fb237_v4 │ ├── test.txt │ ├── train.txt │ └── valid.txt ├── fb237_v4_ind │ ├── test.txt │ ├── 
train.txt │ └── valid.txt ├── nell_v1 │ ├── test.txt │ ├── train.txt │ └── valid.txt ├── nell_v1_ind │ ├── test.txt │ ├── train.txt │ └── valid.txt ├── nell_v2 │ ├── test.txt │ ├── train.txt │ └── valid.txt ├── nell_v2_ind │ ├── test.txt │ ├── train.txt │ └── valid.txt ├── nell_v3 │ ├── test.txt │ ├── train.txt │ └── valid.txt ├── nell_v3_ind │ ├── test.txt │ ├── train.txt │ └── valid.txt ├── nell_v4 │ ├── test.txt │ ├── train.txt │ └── valid.txt └── nell_v4_ind │ ├── test.txt │ ├── train.txt │ └── valid.txt ├── ensembling ├── blend.py ├── compute_auc.py ├── compute_rank_metrics.py ├── get_ensemble_predictions.sh ├── get_kge_ensemble.sh └── score_triplets_kge.py ├── managers ├── evaluator.py └── trainer.py ├── model └── dgl │ ├── __init__.py │ ├── aggregators.py │ ├── graph_classifier.py │ ├── layers.py │ ├── layers_new.py │ ├── layers_ori.py │ └── rgcn_model.py ├── relational_path ├── path_process.py ├── path_sampler.py └── readme.txt ├── requirements.txt ├── subgraph_extraction ├── datasets.py └── datasets_path.py ├── test_auc.py ├── train.py └── utils ├── clean_data.py ├── data_utils.py ├── dgl_utils.py ├── graph_utils.py ├── initialization_utils.py └── prepare_meta_data.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | __pycache__/ 3 | tmp.txt 4 | experiments/ 5 | 6 | #Saved and downloaded data files 7 | *.nt.gz 8 | *.npz 9 | *.pkl 10 | *.ipynb 11 | *.npy 12 | *.pyc 13 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | A LogCo demo. 
2 | Thanks for the framework by https://github.com/kkteru/grail 3 | -------------------------------------------------------------------------------- /data/FB15K237/FB15K237.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyd418/LogCo/7b6b5795cac46a5ff7c781f93d03c741f48ced71/data/FB15K237/FB15K237.pickle -------------------------------------------------------------------------------- /data/FB15K237/README.txt: -------------------------------------------------------------------------------- 1 | FB15K-237 Knowledge Base Completion Dataset 2 | 3 | This dataset contains knowledge base relation triples and textual mentions of Freebase entity pairs, as used in the work published in [1] and [2]. 4 | The knowledge base triples are a subset of the FB15K set [3], originally derived from Freebase. The textual mentions are derived from 200 million sentences from the ClueWeb12 [5] corpus coupled with Freebase entity mention annotations [4]. 5 | 6 | 7 | FILE FORMAT DETAILS 8 | 9 | The files train.txt, valid.txt, and test.text contain the training, development, and test set knowledge base triples used in both [1] and [2]. 10 | The file text_cvsc.txt contains the textual triples used in [2] and the file text_emnlp.txt contains the textual triples used in [1]. 11 | 12 | The knowledge base triples contain lines like this: 13 | 14 | /m/0grwj /people/person/profession /m/05sxg2 15 | 16 | The format is: 17 | 18 | mid1 relation mid2 19 | 20 | The separator is a tab character; the mids are Freebase ids of entities, and the relation is a single or a two-hop relation from Freebase, where an intermediate complex value type entity has been collapsed out. 
21 | 22 | The textual mentions files have lines like this: 23 | 24 | /m/02qkt [XXX]:<-nn>:fact:<-pobj>:in:<-prep>:game:<-nsubj>:'s::pivot::[YYY] /m/05sb1 3 25 | 26 | This indicates the mids of two Freebase entities, together with a fully lexicalized dependency path between the entities. The last element in the tuple is the number of occurrences of the specified entity pair with the given dependency path in sentences from ClueWeb12. 27 | The dependency paths are specified as sequences of words (like the word "fact" above) and labeled dependency links (like above). The direction of traversal of a dependency arc is indicated by whether there is a - sign in front of the arc label "e.g." <-nsubj> vs . 28 | 29 | 30 | REFERENCES 31 | 32 | [1] Kristina Toutanova, Danqi Chen, Patrick Pantel, Hoifung Poon, Pallavi Choudhury, and Michael Gamon. Representing text for joint embedding of text and knowledge bases. In Proceedings of EMNLP 2015. 33 | [2] Kristina Toutanova and Danqi Chen. Observed versus latent features for knowledge base and text inference. In Proceedings of the 3rd Workshop on Continuous Vector Space Models and Their Compositionality 2015. 34 | [3] Antoine Bordes, Nicolas Usunier, Alberto Garcia Duran, Jason Weston, and Oksana Yakhnenko. Translating embeddings for modeling multirelational data. In Advances in Neural Information Processing Systems (NIPS) 2013. 35 | [4] Evgeniy Gabrilovich, Michael Ringgaard, and Amarnag Subramanya. FACC1: Freebase annotation of ClueWeb corpora, Version 1 (release date 2013-06-26, format version 1, correction level 0). http://lemurproject.org/clueweb12/FACC1/ 36 | [5] http://lemurproject.org/clueweb12/ 37 | 38 | 39 | CONTACT 40 | 41 | Please contact Kristina Toutanova kristout@microsoft.com if you have questions about the dataset. 
42 | -------------------------------------------------------------------------------- /data/WN18RR_v1/1/relation2id.json: -------------------------------------------------------------------------------- 1 | {"_hypernym": 0, "_derivationally_related_form": 1, "_also_see": 2, "_synset_domain_topic_of": 3, "_has_part": 4, "_verb_group": 5, "_member_meronym": 6, "_similar_to": 7, "_instance_hypernym": 8} -------------------------------------------------------------------------------- /data/WN18RR_v1/1/subgraphs_en_True_neg_1_hop_3/data.mdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyd418/LogCo/7b6b5795cac46a5ff7c781f93d03c741f48ced71/data/WN18RR_v1/1/subgraphs_en_True_neg_1_hop_3/data.mdb -------------------------------------------------------------------------------- /data/WN18RR_v1/1/subgraphs_en_True_neg_1_hop_3/lock.mdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyd418/LogCo/7b6b5795cac46a5ff7c781f93d03c741f48ced71/data/WN18RR_v1/1/subgraphs_en_True_neg_1_hop_3/lock.mdb -------------------------------------------------------------------------------- /data/WN18RR_v1/1/test_subgraphs_grail_wn_v1_0_en_True/data.mdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyd418/LogCo/7b6b5795cac46a5ff7c781f93d03c741f48ced71/data/WN18RR_v1/1/test_subgraphs_grail_wn_v1_0_en_True/data.mdb -------------------------------------------------------------------------------- /data/WN18RR_v1/1/test_subgraphs_grail_wn_v1_0_en_True/lock.mdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyd418/LogCo/7b6b5795cac46a5ff7c781f93d03c741f48ced71/data/WN18RR_v1/1/test_subgraphs_grail_wn_v1_0_en_True/lock.mdb -------------------------------------------------------------------------------- 
/data/WN18RR_v1/2/relation2id.json: -------------------------------------------------------------------------------- 1 | {"_hypernym": 0, "_derivationally_related_form": 1, "_also_see": 2, "_synset_domain_topic_of": 3, "_has_part": 4, "_verb_group": 5, "_member_meronym": 6, "_similar_to": 7, "_instance_hypernym": 8} -------------------------------------------------------------------------------- /data/WN18RR_v1/2/subgraphs_en_True_neg_1_hop_3/data.mdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyd418/LogCo/7b6b5795cac46a5ff7c781f93d03c741f48ced71/data/WN18RR_v1/2/subgraphs_en_True_neg_1_hop_3/data.mdb -------------------------------------------------------------------------------- /data/WN18RR_v1/2/subgraphs_en_True_neg_1_hop_3/lock.mdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyd418/LogCo/7b6b5795cac46a5ff7c781f93d03c741f48ced71/data/WN18RR_v1/2/subgraphs_en_True_neg_1_hop_3/lock.mdb -------------------------------------------------------------------------------- /data/WN18RR_v1_ind/1/grail_neg_test_0_predictions.txt: -------------------------------------------------------------------------------- 1 | 00445169 _similar_to 04181228 5.445107936859131 2 | 10291110 _derivationally_related_form 01410363 -10.292617797851562 3 | 01666717 _derivationally_related_form 01087197 -10.292617797851562 4 | 01149494 _also_see 06799897 -1.6430387496948242 5 | 00233335 _derivationally_related_form 01093587 -10.292617797851562 6 | 10341660 _derivationally_related_form 04713118 -10.292617797851562 7 | 01283208 _derivationally_related_form 01340439 -10.292617797851562 8 | 00088481 _derivationally_related_form 03128583 -10.292617797851562 9 | 04565375 _derivationally_related_form 01662771 -10.292617797851562 10 | 01021128 _derivationally_related_form 02694933 -10.292617797851562 11 | 02512305 _derivationally_related_form 02023992 
-10.292617797851562 12 | 00456740 _derivationally_related_form 02820210 -10.292617797851562 13 | 01292885 _derivationally_related_form 05748285 -10.292617797851562 14 | 04802776 _derivationally_related_form 04750164 -10.292617797851562 15 | 03728811 _derivationally_related_form 03075768 -10.292617797851562 16 | 13903387 _derivationally_related_form 01780941 -10.292617797851562 17 | 05849789 _derivationally_related_form 07515560 -10.292618751525879 18 | 03173524 _derivationally_related_form 01276361 37.611061096191406 19 | 00847340 _derivationally_related_form 00770141 -10.292618751525879 20 | 02509287 _derivationally_related_form 01276361 -10.292618751525879 21 | 00527572 _derivationally_related_form 06000644 -10.292618751525879 22 | 03961939 _derivationally_related_form 01410363 -10.292618751525879 23 | 00253761 _derivationally_related_form 01590171 -10.292618751525879 24 | 00915722 _derivationally_related_form 00921790 -10.292618751525879 25 | 00737188 _hypernym 07355887 -0.2847031354904175 26 | 02728440 _derivationally_related_form 01151097 -10.292618751525879 27 | 01159776 _derivationally_related_form 01612053 -10.292618751525879 28 | 01052853 _hypernym 03605915 -0.2847031354904175 29 | 03264542 _hypernym 09769929 -0.2847031354904175 30 | 07976936 _derivationally_related_form 01135529 -10.292618751525879 31 | 01779165 _derivationally_related_form 07355887 -10.292618751525879 32 | 01779165 _derivationally_related_form 09941383 -10.292618751525879 33 | 04608567 _derivationally_related_form 09972157 -10.292617797851562 34 | 00922438 _hypernym 04639113 -0.2847031354904175 35 | 10078806 _derivationally_related_form 14110411 -10.292617797851562 36 | 00064095 _derivationally_related_form 06526291 -10.292617797851562 37 | 07006712 _derivationally_related_form 01285440 -10.292617797851562 38 | 03385557 _derivationally_related_form 01662771 -10.292617797851562 39 | 00353782 _derivationally_related_form 05748285 -10.292617797851562 40 | 02703952 
_derivationally_related_form 01026262 -10.292617797851562 41 | 10371741 _derivationally_related_form 02785648 -10.292617797851562 42 | 00476819 _derivationally_related_form 13913566 -10.292617797851562 43 | 04361641 _hypernym 01240210 -0.2847031354904175 44 | 05085572 _derivationally_related_form 07369604 -10.292617797851562 45 | 13860793 _derivationally_related_form 02441022 -10.292617797851562 46 | 00082308 _derivationally_related_form 06004685 -10.292617797851562 47 | 00429060 _derivationally_related_form 03330947 -10.292617797851562 48 | 04613158 _derivationally_related_form 05085572 -10.292617797851562 49 | 00751887 _derivationally_related_form 00595935 -10.292617797851562 50 | 00919513 _derivationally_related_form 00462092 -10.292617797851562 51 | 02538765 _derivationally_related_form 03878963 -10.292617797851562 52 | 01059400 _also_see 03106110 -1.6430387496948242 53 | 10291110 _hypernym 04712735 -0.2847031354904175 54 | 13999663 _derivationally_related_form 01689379 -10.292617797851562 55 | 01089137 _derivationally_related_form 01586850 -10.292617797851562 56 | 00188466 _derivationally_related_form 00351638 -10.292617797851562 57 | 00520257 _derivationally_related_form 00893955 -10.292617797851562 58 | 10213034 _derivationally_related_form 10080337 -10.292617797851562 59 | 01203715 _hypernym 01779165 -0.2847031354904175 60 | 01667449 _also_see 02451113 -1.6430387496948242 61 | 00119873 _derivationally_related_form 00253761 -10.292617797851562 62 | 10078806 _derivationally_related_form 08630985 -10.292617797851562 63 | 00169651 _derivationally_related_form 00047317 -10.292617797851562 64 | 10399491 _derivationally_related_form 00650353 -10.292617797851562 65 | 15224293 _derivationally_related_form 01093587 -10.292617797851562 66 | 01687569 _derivationally_related_form 01660386 -10.292617797851562 67 | 02447001 _derivationally_related_form 09476521 -10.292617797851562 68 | 01052853 _derivationally_related_form 04930307 -10.292617797851562 69 | 01819554 
_derivationally_related_form 08512259 -10.292617797851562 70 | 08645963 _hypernym 09952163 -0.2847031354904175 71 | 02506555 _also_see 01321002 -1.6430387496948242 72 | 02700104 _derivationally_related_form 05749619 -10.292617797851562 73 | 00083334 _derivationally_related_form 02612368 -10.292617797851562 74 | 01364008 _also_see 01000214 -1.6430387496948242 75 | 01606205 _derivationally_related_form 04033995 -10.292617797851562 76 | 07366289 _derivationally_related_form 00017222 -10.292617797851562 77 | 01667304 _derivationally_related_form 10164233 -10.292617797851562 78 | 10160412 _hypernym 01296462 -0.2847031354904175 79 | 02661252 _derivationally_related_form 04623612 -10.292617797851562 80 | 00252019 _derivationally_related_form 00320852 -10.292617797851562 81 | 00898804 _derivationally_related_form 00140393 -10.292618751525879 82 | 02840361 _derivationally_related_form 02660147 -10.292618751525879 83 | 10012815 _derivationally_related_form 00445169 -10.292618751525879 84 | 01643464 _derivationally_related_form 00100044 -10.292618751525879 85 | 01693881 _derivationally_related_form 00152887 -10.292618751525879 86 | 00751887 _derivationally_related_form 06817782 -10.292618751525879 87 | 00482893 _derivationally_related_form 02539334 -10.292618751525879 88 | 00456740 _derivationally_related_form 01027263 15.823774337768555 89 | 00233335 _derivationally_related_form 00187526 -10.292618751525879 90 | 07515560 _derivationally_related_form 02702830 -10.292618751525879 91 | 05750657 _derivationally_related_form 00709625 13.538935661315918 92 | 03496892 _hypernym 10566072 -0.2847031354904175 93 | 01690294 _derivationally_related_form 14442530 -10.292618751525879 94 | 00918820 _derivationally_related_form 00410247 -10.292618751525879 95 | 01531375 _also_see 08677628 -1.6430387496948242 96 | 00299580 _derivationally_related_form 04321238 -10.292618751525879 97 | 00233335 _derivationally_related_form 04433185 -10.292618751525879 98 | 01742886 _also_see 00044149 
-1.6430387496948242 99 | 13970236 _hypernym 08512259 -0.2847031354904175 100 | 00353782 _also_see 01922763 -1.6430387496948242 101 | 01624568 _derivationally_related_form 01467370 -10.292618751525879 102 | 02806907 _derivationally_related_form 13489037 -10.292618751525879 103 | 00476819 _derivationally_related_form 09941964 -10.292618751525879 104 | 13971561 _derivationally_related_form 01205827 16.482406616210938 105 | 06526291 _derivationally_related_form 00791227 -10.292618751525879 106 | 00651991 _derivationally_related_form 09758781 -10.292618751525879 107 | 08677628 _hypernym 07253637 -0.2847031354904175 108 | 01582645 _derivationally_related_form 00761713 -10.292618751525879 109 | 00709625 _derivationally_related_form 00696518 -10.292618751525879 110 | 07254057 _derivationally_related_form 01085474 -10.292618751525879 111 | 00859325 _derivationally_related_form 03051540 -10.292618751525879 112 | 09941571 _hypernym 01588493 -0.2847031354904175 113 | 03573282 _derivationally_related_form 13928668 -10.292617797851562 114 | 10300303 _hypernym 00027807 -0.2847031354904175 115 | 00413876 _derivationally_related_form 01190884 -10.292617797851562 116 | 01735308 _derivationally_related_form 01922763 -10.292617797851562 117 | 08552138 _derivationally_related_form 02806907 -10.292617797851562 118 | 00462092 _derivationally_related_form 00714944 -10.292617797851562 119 | 13903387 _derivationally_related_form 00355252 -10.292617797851562 120 | 09779790 _derivationally_related_form 02389592 -10.292617797851562 121 | 01467370 _derivationally_related_form 14110411 -10.292617797851562 122 | 10054657 _derivationally_related_form 03178782 -10.292617797851562 123 | 08622950 _derivationally_related_form 10300303 -10.292617797851562 124 | 07337390 _derivationally_related_form 00038849 -10.292617797851562 125 | 00236289 _derivationally_related_form 05257737 -10.292617797851562 126 | 03420559 _verb_group 01684337 -12.012219429016113 127 | 05075602 _derivationally_related_form 
00759551 -10.292617797851562 128 | 09952163 _derivationally_related_form 01722980 -10.292617797851562 129 | 09941571 _derivationally_related_form 07369604 -10.292617797851562 130 | 02700104 _derivationally_related_form 04463273 -10.292617797851562 131 | 01029852 _derivationally_related_form 13928668 -10.292617797851562 132 | 03879854 _hypernym 05093890 -0.2847031354904175 133 | 00236289 _hypernym 15046900 -0.2847031354904175 134 | 00039021 _derivationally_related_form 02539334 -10.292617797851562 135 | 09779790 _derivationally_related_form 10529231 -10.292617797851562 136 | 01782218 _hypernym 07006119 -0.2847031354904175 137 | 03051540 _derivationally_related_form 01286913 -10.292617797851562 138 | 03792334 _derivationally_related_form 01159964 -10.292617797851562 139 | 00661213 _derivationally_related_form 04085873 -10.292617797851562 140 | 13969243 _derivationally_related_form 02553697 -10.292617797851562 141 | 09812338 _derivationally_related_form 04612840 -10.292617797851562 142 | 08612786 _hypernym 03330947 -0.2847031354904175 143 | 13971561 _derivationally_related_form 02671279 -10.292617797851562 144 | 08513163 _derivationally_related_form 01779165 -10.292617797851562 145 | 03146846 _derivationally_related_form 00893955 -10.292617797851562 146 | 02451113 _derivationally_related_form 10389398 -10.292617797851562 147 | 00169651 _hypernym 08620763 -0.2847031354904175 148 | 03670849 _has_part 06022291 0.17999762296676636 149 | 10132641 _derivationally_related_form 00482893 -10.292617797851562 150 | 05696020 _derivationally_related_form 05844105 -10.292617797851562 151 | 06526291 _derivationally_related_form 04321238 -10.292617797851562 152 | 06773976 _derivationally_related_form 01293389 -10.292617797851562 153 | 13969243 _derivationally_related_form 09918248 -10.292617797851562 154 | 09918248 _hypernym 03574816 -0.2847031354904175 155 | 05641959 _derivationally_related_form 00355547 -10.292617797851562 156 | 02710673 _derivationally_related_form 00119524 
-10.292617797851562 157 | 01684337 _derivationally_related_form 04051825 -10.292617797851562 158 | 00709625 _derivationally_related_form 13903079 -10.292617797851562 159 | 02512305 _derivationally_related_form 00187526 -10.292617797851562 160 | 03777283 _derivationally_related_form 03051540 -10.292617797851562 161 | 00759551 _hypernym 03315644 -0.2847031354904175 162 | 00696518 _hypernym 04640927 10.4762544631958 163 | 08398036 _derivationally_related_form 01781180 -10.292618751525879 164 | 00119873 _derivationally_related_form 01687569 -10.292618751525879 165 | 10525134 _derivationally_related_form 10640620 -10.292618751525879 166 | 14442530 _hypernym 00187526 -0.2847031354904175 167 | 00650016 _derivationally_related_form 01590171 -10.292618751525879 168 | 00696518 _also_see 10213034 -1.6430387496948242 169 | 01697027 _derivationally_related_form 08512259 -10.292618751525879 170 | 01301410 _derivationally_related_form 10161363 -10.292618751525879 171 | 10388924 _derivationally_related_form 09609232 -1.2063889503479004 172 | 00482893 _derivationally_related_form 01697027 -10.292618751525879 173 | 00234725 _derivationally_related_form 07355887 -10.292618751525879 174 | 08677628 _derivationally_related_form 01203715 -10.292618751525879 175 | 06526291 _also_see 01123148 -1.6430387496948242 176 | 05162455 _hypernym 00579712 -0.2847031354904175 177 | 01753596 _verb_group 00044797 -12.012222290039062 178 | 01295275 _derivationally_related_form 00791227 -10.292619705200195 179 | 00456740 _hypernym 00040804 -0.28470659255981445 180 | 04802629 _derivationally_related_form 01128071 -10.292619705200195 181 | 03533486 _derivationally_related_form 01026262 -10.292619705200195 182 | 07066659 _derivationally_related_form 00921790 -10.292619705200195 183 | 00354884 _derivationally_related_form 01027263 -10.292619705200195 184 | 03237639 _derivationally_related_form 00921738 -10.292619705200195 185 | 03738241 _derivationally_related_form 02666239 -10.292619705200195 186 | 06709533 
_derivationally_related_form 00527572 -10.292619705200195 187 | 09260907 _derivationally_related_form 01498713 -10.292619705200195 188 | 00224901 _derivationally_related_form 14445379 -10.292619705200195 189 | -------------------------------------------------------------------------------- /data/WN18RR_v1_ind/1/grail_test_predictions.txt: -------------------------------------------------------------------------------- 1 | 00445169 _similar_to 00444519 82.04751586914062 2 | 02666239 _derivationally_related_form 01410363 120.50161743164062 3 | 03420559 _derivationally_related_form 01087197 156.00933837890625 4 | 01149494 _also_see 01364008 146.71559143066406 5 | 00233335 _derivationally_related_form 05162455 72.59259033203125 6 | 10341660 _derivationally_related_form 02987454 168.15286254882812 7 | 03354613 _derivationally_related_form 01340439 156.00933837890625 8 | 00088481 _derivationally_related_form 02272549 109.56170654296875 9 | 02979662 _derivationally_related_form 01662771 156.00933837890625 10 | 01021128 _derivationally_related_form 05925366 95.9883804321289 11 | 02512305 _derivationally_related_form 01153548 156.00933837890625 12 | 00456740 _derivationally_related_form 07369604 127.81011962890625 13 | 01292885 _derivationally_related_form 07976936 119.663818359375 14 | 01410363 _derivationally_related_form 04750164 25.860916137695312 15 | 00082308 _derivationally_related_form 03075768 156.00933837890625 16 | 07520612 _derivationally_related_form 01780941 128.05587768554688 17 | 05849789 _derivationally_related_form 02630189 156.00933837890625 18 | 08612786 _derivationally_related_form 01276361 121.01397705078125 19 | 00847340 _derivationally_related_form 01428853 156.00933837890625 20 | 02509287 _derivationally_related_form 00808182 124.80843353271484 21 | 00527572 _derivationally_related_form 13491060 -10.292617797851562 22 | 04750164 _derivationally_related_form 01410363 21.537063598632812 23 | 00267349 _derivationally_related_form 01590171 
131.60147094726562 24 | 00915722 _derivationally_related_form 01742726 71.46369934082031 25 | 07423001 _hypernym 07355887 25.619985580444336 26 | 02728440 _derivationally_related_form 00046534 130.76214599609375 27 | 01167146 _derivationally_related_form 01612053 132.64356994628906 28 | 03600977 _hypernym 03605915 -0.2847031354904175 29 | 03264542 _hypernym 08592656 -0.2847031354904175 30 | 02443049 _derivationally_related_form 01135529 57.2424430847168 31 | 01779165 _derivationally_related_form 04143712 156.00933837890625 32 | 01779165 _derivationally_related_form 07520612 26.387861251831055 33 | 01753596 _derivationally_related_form 09972157 158.56478881835938 34 | 00922438 _hypernym 00921738 -0.2847031354904175 35 | 00443384 _derivationally_related_form 14110411 170.698486328125 36 | 00064095 _derivationally_related_form 03879854 156.00933837890625 37 | 00149084 _derivationally_related_form 01285440 130.59417724609375 38 | 03779621 _derivationally_related_form 01662771 22.006458282470703 39 | 00353782 _derivationally_related_form 00429060 129.47593688964844 40 | 04659287 _derivationally_related_form 01026262 156.00933837890625 41 | 10371741 _derivationally_related_form 00752335 156.00933837890625 42 | 01662771 _derivationally_related_form 13913566 113.4879150390625 43 | 01240432 _hypernym 01240210 36.353179931640625 44 | 05085572 _derivationally_related_form 00444519 -10.292617797851562 45 | 09941964 _derivationally_related_form 02441022 32.52286148071289 46 | 00082308 _derivationally_related_form 00354884 155.4448699951172 47 | 00429060 _derivationally_related_form 00359903 156.00933837890625 48 | 00444519 _derivationally_related_form 05085572 -10.292617797851562 49 | 00751887 _derivationally_related_form 09941964 32.05967330932617 50 | 00201923 _derivationally_related_form 00462092 156.00933837890625 51 | 01130607 _derivationally_related_form 03878963 10.159109115600586 52 | 01059400 _also_see 02095311 160.81529235839844 53 | 04713332 _hypernym 04712735 
23.194915771484375 54 | 13999663 _derivationally_related_form 01301410 156.00933837890625 55 | 03391301 _derivationally_related_form 01586850 156.00933837890625 56 | 00290740 _derivationally_related_form 00351638 129.31231689453125 57 | 00833702 _derivationally_related_form 00893955 156.00933837890625 58 | 01340439 _derivationally_related_form 10080337 156.00933837890625 59 | 01780941 _hypernym 01779165 28.33177375793457 60 | 01474513 _also_see 02451113 118.79241180419922 61 | 00119873 _derivationally_related_form 02987454 116.0631332397461 62 | 10078806 _derivationally_related_form 01739814 126.29859924316406 63 | 00795008 _derivationally_related_form 00047317 -10.292617797851562 64 | 10012815 _derivationally_related_form 00650353 12.988314628601074 65 | 15224293 _derivationally_related_form 00233335 22.845178604125977 66 | 01687569 _derivationally_related_form 01159964 -10.292618751525879 67 | 02447001 _derivationally_related_form 10298912 156.00933837890625 68 | 02659763 _derivationally_related_form 04930307 -10.292618751525879 69 | 01819554 _derivationally_related_form 10525134 156.00933837890625 70 | 09800249 _hypernym 09952163 11.225247383117676 71 | 02506555 _also_see 02064745 203.80670166015625 72 | 02700104 _derivationally_related_form 00552841 156.00933837890625 73 | 00083334 _derivationally_related_form 00149084 156.00933837890625 74 | 01364008 _also_see 01368192 170.90228271484375 75 | 01667449 _derivationally_related_form 04033995 156.00933837890625 76 | 07366289 _derivationally_related_form 02661252 -10.292618751525879 77 | 00595146 _derivationally_related_form 10164233 156.00933837890625 78 | 01340439 _hypernym 01296462 29.873443603515625 79 | 02661252 _derivationally_related_form 04802776 156.00933837890625 80 | 02502536 _derivationally_related_form 00320852 156.00933837890625 81 | 00898804 _derivationally_related_form 01697027 18.502182006835938 82 | 03600977 _derivationally_related_form 02660147 156.00933837890625 83 | 13860793 
_derivationally_related_form 00445169 156.00933837890625 84 | 01643464 _derivationally_related_form 10029068 119.0886001586914 85 | 07355887 _derivationally_related_form 00152887 15.64391040802002 86 | 00751887 _derivationally_related_form 09941383 110.06763458251953 87 | 00482893 _derivationally_related_form 00185104 130.6973419189453 88 | 00456740 _derivationally_related_form 05696020 16.8956298828125 89 | 00233335 _derivationally_related_form 15224293 29.14869499206543 90 | 10093908 _derivationally_related_form 02702830 119.43779754638672 91 | 09442838 _derivationally_related_form 00709625 156.00933837890625 92 | 03496892 _hypernym 03322940 -0.2847031354904175 93 | 02646931 _derivationally_related_form 14442530 14.382972717285156 94 | 01834304 _derivationally_related_form 00410247 156.00933837890625 95 | 01531375 _also_see 01508719 203.80670166015625 96 | 00299580 _derivationally_related_form 07369604 28.026962280273438 97 | 00233335 _derivationally_related_form 10525134 -0.8377273082733154 98 | 00046534 _also_see 00044149 52.63560104370117 99 | 08592656 _hypernym 08512259 15.167888641357422 100 | 00087152 _also_see 01922763 203.80670166015625 101 | 02875013 _derivationally_related_form 01467370 156.00933837890625 102 | 02806907 _derivationally_related_form 06806469 156.00933837890625 103 | 02441022 _derivationally_related_form 09941964 35.193695068359375 104 | 00764902 _derivationally_related_form 01205827 156.00933837890625 105 | 14441825 _derivationally_related_form 00791227 146.74520874023438 106 | 00651991 _derivationally_related_form 05748285 128.20938110351562 107 | 07254057 _hypernym 07253637 -0.2847031354904175 108 | 01582645 _derivationally_related_form 03234306 124.86278533935547 109 | 02542280 _derivationally_related_form 00696518 30.13683319091797 110 | 04181228 _derivationally_related_form 01085474 156.00933837890625 111 | 00859325 _derivationally_related_form 04630689 156.00933837890625 112 | 09941571 _hypernym 09943541 -0.2847031354904175 113 | 
03573282 _derivationally_related_form 00187526 25.5428524017334 114 | 13860793 _hypernym 00027807 -0.2847031354904175 115 | 00413876 _derivationally_related_form 01051331 156.00933837890625 116 | 07515560 _derivationally_related_form 01922763 -10.292617797851562 117 | 08552138 _derivationally_related_form 02512150 -10.292617797851562 118 | 00462092 _derivationally_related_form 01070892 156.00933837890625 119 | 00290740 _derivationally_related_form 00355252 119.58138275146484 120 | 09779790 _derivationally_related_form 00245457 108.3304443359375 121 | 01467370 _derivationally_related_form 08512736 115.13777160644531 122 | 01753596 _derivationally_related_form 03178782 22.4537296295166 123 | 01428853 _derivationally_related_form 10300303 156.00933837890625 124 | 07337390 _derivationally_related_form 01876907 156.00933837890625 125 | 00236289 _derivationally_related_form 04181228 156.00933837890625 126 | 01551871 _verb_group 01684337 39.13064193725586 127 | 00764902 _derivationally_related_form 00759551 121.2281494140625 128 | 09952163 _derivationally_related_form 01765392 55.56953811645508 129 | 09941571 _derivationally_related_form 00590626 107.76318359375 130 | 02700104 _derivationally_related_form 04713118 102.20590209960938 131 | 01029852 _derivationally_related_form 07202579 156.00933837890625 132 | 05117660 _hypernym 05093890 -0.2847031354904175 133 | 00236289 _hypernym 00233335 35.39875411987305 134 | 10388440 _derivationally_related_form 02539334 137.35916137695312 135 | 09779790 _derivationally_related_form 01104406 156.00933837890625 136 | 01782218 _hypernym 01780202 21.35649871826172 137 | 03051540 _derivationally_related_form 00050652 142.4058380126953 138 | 03792334 _derivationally_related_form 01660640 156.00933837890625 139 | 00661213 _derivationally_related_form 05748786 130.63243103027344 140 | 01146039 _derivationally_related_form 02553697 -10.292618751525879 141 | 09812338 _derivationally_related_form 02991122 32.9810905456543 142 | 08612786 
_hypernym 08512259 -0.2847031354904175 143 | 13971561 _derivationally_related_form 00764902 9.248350143432617 144 | 07519253 _derivationally_related_form 01779165 136.30514526367188 145 | 00100044 _derivationally_related_form 00893955 19.923555374145508 146 | 02204692 _derivationally_related_form 10389398 141.67481994628906 147 | 00169651 _hypernym 00170844 -0.2847031354904175 148 | 03670849 _has_part 02845576 0.17999762296676636 149 | 07177437 _derivationally_related_form 00482893 156.00933837890625 150 | 05696020 _derivationally_related_form 00456740 18.010271072387695 151 | 06526291 _derivationally_related_form 10402417 121.0009994506836 152 | 06773976 _derivationally_related_form 01647867 156.00933837890625 153 | 13969243 _derivationally_related_form 02700104 156.00933837890625 154 | 03852280 _hypernym 03574816 -0.2847031354904175 155 | 05641959 _derivationally_related_form 00597385 130.8251495361328 156 | 04748836 _derivationally_related_form 00119524 156.00933837890625 157 | 01684337 _derivationally_related_form 04157320 143.03916931152344 158 | 00709625 _derivationally_related_form 00928077 156.00933837890625 159 | 00321956 _derivationally_related_form 00187526 141.5396270751953 160 | 00047745 _derivationally_related_form 03051540 24.626670837402344 161 | 04085873 _hypernym 03315644 -0.2847031354904175 162 | 04641153 _hypernym 04640927 -0.2847031354904175 163 | 07254057 _derivationally_related_form 01781180 91.38477325439453 164 | 05844105 _derivationally_related_form 01687569 1.1977955102920532 165 | 10525134 _derivationally_related_form 01301051 -10.292617797851562 166 | 14442530 _hypernym 14441825 15.663484573364258 167 | 00650016 _derivationally_related_form 10012815 129.16856384277344 168 | 00696518 _also_see 02564986 203.80670166015625 169 | 01697027 _derivationally_related_form 00898804 14.487100601196289 170 | 01301410 _derivationally_related_form 13998781 -10.292617797851562 171 | 10388924 _derivationally_related_form 00809465 126.84770202636719 172 
| 03779621 _derivationally_related_form 01697027 156.00933837890625 173 | 00152887 _derivationally_related_form 07355887 12.46866226196289 174 | 08677628 _derivationally_related_form 02695895 131.54847717285156 175 | 01612053 _also_see 01123148 203.80670166015625 176 | 05162455 _hypernym 05161614 -0.2847031354904175 177 | 00044149 _verb_group 00044797 74.32205963134766 178 | 02539334 _derivationally_related_form 00791227 107.91422271728516 179 | 00040962 _hypernym 00040804 -0.28470659255981445 180 | 04051825 _derivationally_related_form 01128071 156.00936889648438 181 | 04905188 _derivationally_related_form 01026262 156.00936889648438 182 | 01320009 _derivationally_related_form 00921790 -10.292619705200195 183 | 00354884 _derivationally_related_form 01815185 156.00936889648438 184 | 03721797 _derivationally_related_form 00921738 156.00936889648438 185 | 02064745 _derivationally_related_form 02666239 124.5009765625 186 | 13491060 _derivationally_related_form 00527572 -10.292619705200195 187 | 02928413 _derivationally_related_form 01498713 156.00936889648438 188 | 00224901 _derivationally_related_form 09476521 94.54827880859375 189 | -------------------------------------------------------------------------------- /data/WN18RR_v1_ind/1/neg_test_0.txt: -------------------------------------------------------------------------------- 1 | 00445169 _similar_to 04181228 2 | 10291110 _derivationally_related_form 01410363 3 | 01666717 _derivationally_related_form 01087197 4 | 01149494 _also_see 06799897 5 | 00233335 _derivationally_related_form 01093587 6 | 10341660 _derivationally_related_form 04713118 7 | 01283208 _derivationally_related_form 01340439 8 | 00088481 _derivationally_related_form 03128583 9 | 04565375 _derivationally_related_form 01662771 10 | 01021128 _derivationally_related_form 02694933 11 | 02512305 _derivationally_related_form 02023992 12 | 00456740 _derivationally_related_form 02820210 13 | 01292885 _derivationally_related_form 05748285 14 | 04802776 
_derivationally_related_form 04750164 15 | 03728811 _derivationally_related_form 03075768 16 | 13903387 _derivationally_related_form 01780941 17 | 05849789 _derivationally_related_form 07515560 18 | 03173524 _derivationally_related_form 01276361 19 | 00847340 _derivationally_related_form 00770141 20 | 02509287 _derivationally_related_form 01276361 21 | 00527572 _derivationally_related_form 06000644 22 | 03961939 _derivationally_related_form 01410363 23 | 00253761 _derivationally_related_form 01590171 24 | 00915722 _derivationally_related_form 00921790 25 | 00737188 _hypernym 07355887 26 | 02728440 _derivationally_related_form 01151097 27 | 01159776 _derivationally_related_form 01612053 28 | 01052853 _hypernym 03605915 29 | 03264542 _hypernym 09769929 30 | 07976936 _derivationally_related_form 01135529 31 | 01779165 _derivationally_related_form 07355887 32 | 01779165 _derivationally_related_form 09941383 33 | 04608567 _derivationally_related_form 09972157 34 | 00922438 _hypernym 04639113 35 | 10078806 _derivationally_related_form 14110411 36 | 00064095 _derivationally_related_form 06526291 37 | 07006712 _derivationally_related_form 01285440 38 | 03385557 _derivationally_related_form 01662771 39 | 00353782 _derivationally_related_form 05748285 40 | 02703952 _derivationally_related_form 01026262 41 | 10371741 _derivationally_related_form 02785648 42 | 00476819 _derivationally_related_form 13913566 43 | 04361641 _hypernym 01240210 44 | 05085572 _derivationally_related_form 07369604 45 | 13860793 _derivationally_related_form 02441022 46 | 00082308 _derivationally_related_form 06004685 47 | 00429060 _derivationally_related_form 03330947 48 | 04613158 _derivationally_related_form 05085572 49 | 00751887 _derivationally_related_form 00595935 50 | 00919513 _derivationally_related_form 00462092 51 | 02538765 _derivationally_related_form 03878963 52 | 01059400 _also_see 03106110 53 | 10291110 _hypernym 04712735 54 | 13999663 _derivationally_related_form 01689379 55 | 01089137 
_derivationally_related_form 01586850 56 | 00188466 _derivationally_related_form 00351638 57 | 00520257 _derivationally_related_form 00893955 58 | 10213034 _derivationally_related_form 10080337 59 | 01203715 _hypernym 01779165 60 | 01667449 _also_see 02451113 61 | 00119873 _derivationally_related_form 00253761 62 | 10078806 _derivationally_related_form 08630985 63 | 00169651 _derivationally_related_form 00047317 64 | 10399491 _derivationally_related_form 00650353 65 | 15224293 _derivationally_related_form 01093587 66 | 01687569 _derivationally_related_form 01660386 67 | 02447001 _derivationally_related_form 09476521 68 | 01052853 _derivationally_related_form 04930307 69 | 01819554 _derivationally_related_form 08512259 70 | 08645963 _hypernym 09952163 71 | 02506555 _also_see 01321002 72 | 02700104 _derivationally_related_form 05749619 73 | 00083334 _derivationally_related_form 02612368 74 | 01364008 _also_see 01000214 75 | 01606205 _derivationally_related_form 04033995 76 | 07366289 _derivationally_related_form 00017222 77 | 01667304 _derivationally_related_form 10164233 78 | 10160412 _hypernym 01296462 79 | 02661252 _derivationally_related_form 04623612 80 | 00252019 _derivationally_related_form 00320852 81 | 00898804 _derivationally_related_form 00140393 82 | 02840361 _derivationally_related_form 02660147 83 | 10012815 _derivationally_related_form 00445169 84 | 01643464 _derivationally_related_form 00100044 85 | 01693881 _derivationally_related_form 00152887 86 | 00751887 _derivationally_related_form 06817782 87 | 00482893 _derivationally_related_form 02539334 88 | 00456740 _derivationally_related_form 01027263 89 | 00233335 _derivationally_related_form 00187526 90 | 07515560 _derivationally_related_form 02702830 91 | 05750657 _derivationally_related_form 00709625 92 | 03496892 _hypernym 10566072 93 | 01690294 _derivationally_related_form 14442530 94 | 00918820 _derivationally_related_form 00410247 95 | 01531375 _also_see 08677628 96 | 00299580 
_derivationally_related_form 04321238 97 | 00233335 _derivationally_related_form 04433185 98 | 01742886 _also_see 00044149 99 | 13970236 _hypernym 08512259 100 | 00353782 _also_see 01922763 101 | 01624568 _derivationally_related_form 01467370 102 | 02806907 _derivationally_related_form 13489037 103 | 00476819 _derivationally_related_form 09941964 104 | 13971561 _derivationally_related_form 01205827 105 | 06526291 _derivationally_related_form 00791227 106 | 00651991 _derivationally_related_form 09758781 107 | 08677628 _hypernym 07253637 108 | 01582645 _derivationally_related_form 00761713 109 | 00709625 _derivationally_related_form 00696518 110 | 07254057 _derivationally_related_form 01085474 111 | 00859325 _derivationally_related_form 03051540 112 | 09941571 _hypernym 01588493 113 | 03573282 _derivationally_related_form 13928668 114 | 10300303 _hypernym 00027807 115 | 00413876 _derivationally_related_form 01190884 116 | 01735308 _derivationally_related_form 01922763 117 | 08552138 _derivationally_related_form 02806907 118 | 00462092 _derivationally_related_form 00714944 119 | 13903387 _derivationally_related_form 00355252 120 | 09779790 _derivationally_related_form 02389592 121 | 01467370 _derivationally_related_form 14110411 122 | 10054657 _derivationally_related_form 03178782 123 | 08622950 _derivationally_related_form 10300303 124 | 07337390 _derivationally_related_form 00038849 125 | 00236289 _derivationally_related_form 05257737 126 | 03420559 _verb_group 01684337 127 | 05075602 _derivationally_related_form 00759551 128 | 09952163 _derivationally_related_form 01722980 129 | 09941571 _derivationally_related_form 07369604 130 | 02700104 _derivationally_related_form 04463273 131 | 01029852 _derivationally_related_form 13928668 132 | 03879854 _hypernym 05093890 133 | 00236289 _hypernym 15046900 134 | 00039021 _derivationally_related_form 02539334 135 | 09779790 _derivationally_related_form 10529231 136 | 01782218 _hypernym 07006119 137 | 03051540 
_derivationally_related_form 01286913 138 | 03792334 _derivationally_related_form 01159964 139 | 00661213 _derivationally_related_form 04085873 140 | 13969243 _derivationally_related_form 02553697 141 | 09812338 _derivationally_related_form 04612840 142 | 08612786 _hypernym 03330947 143 | 13971561 _derivationally_related_form 02671279 144 | 08513163 _derivationally_related_form 01779165 145 | 03146846 _derivationally_related_form 00893955 146 | 02451113 _derivationally_related_form 10389398 147 | 00169651 _hypernym 08620763 148 | 03670849 _has_part 06022291 149 | 10132641 _derivationally_related_form 00482893 150 | 05696020 _derivationally_related_form 05844105 151 | 06526291 _derivationally_related_form 04321238 152 | 06773976 _derivationally_related_form 01293389 153 | 13969243 _derivationally_related_form 09918248 154 | 09918248 _hypernym 03574816 155 | 05641959 _derivationally_related_form 00355547 156 | 02710673 _derivationally_related_form 00119524 157 | 01684337 _derivationally_related_form 04051825 158 | 00709625 _derivationally_related_form 13903079 159 | 02512305 _derivationally_related_form 00187526 160 | 03777283 _derivationally_related_form 03051540 161 | 00759551 _hypernym 03315644 162 | 00696518 _hypernym 04640927 163 | 08398036 _derivationally_related_form 01781180 164 | 00119873 _derivationally_related_form 01687569 165 | 10525134 _derivationally_related_form 10640620 166 | 14442530 _hypernym 00187526 167 | 00650016 _derivationally_related_form 01590171 168 | 00696518 _also_see 10213034 169 | 01697027 _derivationally_related_form 08512259 170 | 01301410 _derivationally_related_form 10161363 171 | 10388924 _derivationally_related_form 09609232 172 | 00482893 _derivationally_related_form 01697027 173 | 00234725 _derivationally_related_form 07355887 174 | 08677628 _derivationally_related_form 01203715 175 | 06526291 _also_see 01123148 176 | 05162455 _hypernym 00579712 177 | 01753596 _verb_group 00044797 178 | 01295275 _derivationally_related_form 
00791227 179 | 00456740 _hypernym 00040804 180 | 04802629 _derivationally_related_form 01128071 181 | 03533486 _derivationally_related_form 01026262 182 | 07066659 _derivationally_related_form 00921790 183 | 00354884 _derivationally_related_form 01027263 184 | 03237639 _derivationally_related_form 00921738 185 | 03738241 _derivationally_related_form 02666239 186 | 06709533 _derivationally_related_form 00527572 187 | 09260907 _derivationally_related_form 01498713 188 | 00224901 _derivationally_related_form 14445379 189 | -------------------------------------------------------------------------------- /data/WN18RR_v1_ind/1/test_subgraphs_grail_wn_v1_0_en_True/data.mdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyd418/LogCo/7b6b5795cac46a5ff7c781f93d03c741f48ced71/data/WN18RR_v1_ind/1/test_subgraphs_grail_wn_v1_0_en_True/data.mdb -------------------------------------------------------------------------------- /data/WN18RR_v1_ind/1/test_subgraphs_grail_wn_v1_0_en_True/lock.mdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyd418/LogCo/7b6b5795cac46a5ff7c781f93d03c741f48ced71/data/WN18RR_v1_ind/1/test_subgraphs_grail_wn_v1_0_en_True/lock.mdb -------------------------------------------------------------------------------- /data/WN18RR_v1_ind/test.txt: -------------------------------------------------------------------------------- 1 | 00445169 _similar_to 00444519 2 | 02666239 _derivationally_related_form 01410363 3 | 03420559 _derivationally_related_form 01087197 4 | 01149494 _also_see 01364008 5 | 00233335 _derivationally_related_form 05162455 6 | 10341660 _derivationally_related_form 02987454 7 | 03354613 _derivationally_related_form 01340439 8 | 00088481 _derivationally_related_form 02272549 9 | 02979662 _derivationally_related_form 01662771 10 | 01021128 _derivationally_related_form 05925366 11 | 02512305 
_derivationally_related_form 01153548 12 | 00456740 _derivationally_related_form 07369604 13 | 01292885 _derivationally_related_form 07976936 14 | 01410363 _derivationally_related_form 04750164 15 | 00082308 _derivationally_related_form 03075768 16 | 07520612 _derivationally_related_form 01780941 17 | 05849789 _derivationally_related_form 02630189 18 | 08612786 _derivationally_related_form 01276361 19 | 00847340 _derivationally_related_form 01428853 20 | 02509287 _derivationally_related_form 00808182 21 | 00527572 _derivationally_related_form 13491060 22 | 04750164 _derivationally_related_form 01410363 23 | 00267349 _derivationally_related_form 01590171 24 | 00915722 _derivationally_related_form 01742726 25 | 07423001 _hypernym 07355887 26 | 02728440 _derivationally_related_form 00046534 27 | 01167146 _derivationally_related_form 01612053 28 | 03600977 _hypernym 03605915 29 | 03264542 _hypernym 08592656 30 | 02443049 _derivationally_related_form 01135529 31 | 01779165 _derivationally_related_form 04143712 32 | 01779165 _derivationally_related_form 07520612 33 | 01753596 _derivationally_related_form 09972157 34 | 00922438 _hypernym 00921738 35 | 00443384 _derivationally_related_form 14110411 36 | 00064095 _derivationally_related_form 03879854 37 | 00149084 _derivationally_related_form 01285440 38 | 03779621 _derivationally_related_form 01662771 39 | 00353782 _derivationally_related_form 00429060 40 | 04659287 _derivationally_related_form 01026262 41 | 10371741 _derivationally_related_form 00752335 42 | 01662771 _derivationally_related_form 13913566 43 | 01240432 _hypernym 01240210 44 | 05085572 _derivationally_related_form 00444519 45 | 09941964 _derivationally_related_form 02441022 46 | 00082308 _derivationally_related_form 00354884 47 | 00429060 _derivationally_related_form 00359903 48 | 00444519 _derivationally_related_form 05085572 49 | 00751887 _derivationally_related_form 09941964 50 | 00201923 _derivationally_related_form 00462092 51 | 01130607 
_derivationally_related_form 03878963 52 | 01059400 _also_see 02095311 53 | 04713332 _hypernym 04712735 54 | 13999663 _derivationally_related_form 01301410 55 | 03391301 _derivationally_related_form 01586850 56 | 00290740 _derivationally_related_form 00351638 57 | 00833702 _derivationally_related_form 00893955 58 | 01340439 _derivationally_related_form 10080337 59 | 01780941 _hypernym 01779165 60 | 01474513 _also_see 02451113 61 | 00119873 _derivationally_related_form 02987454 62 | 10078806 _derivationally_related_form 01739814 63 | 00795008 _derivationally_related_form 00047317 64 | 10012815 _derivationally_related_form 00650353 65 | 15224293 _derivationally_related_form 00233335 66 | 01687569 _derivationally_related_form 01159964 67 | 02447001 _derivationally_related_form 10298912 68 | 02659763 _derivationally_related_form 04930307 69 | 01819554 _derivationally_related_form 10525134 70 | 09800249 _hypernym 09952163 71 | 02506555 _also_see 02064745 72 | 02700104 _derivationally_related_form 00552841 73 | 00083334 _derivationally_related_form 00149084 74 | 01364008 _also_see 01368192 75 | 01667449 _derivationally_related_form 04033995 76 | 07366289 _derivationally_related_form 02661252 77 | 00595146 _derivationally_related_form 10164233 78 | 01340439 _hypernym 01296462 79 | 02661252 _derivationally_related_form 04802776 80 | 02502536 _derivationally_related_form 00320852 81 | 00898804 _derivationally_related_form 01697027 82 | 03600977 _derivationally_related_form 02660147 83 | 13860793 _derivationally_related_form 00445169 84 | 01643464 _derivationally_related_form 10029068 85 | 07355887 _derivationally_related_form 00152887 86 | 00751887 _derivationally_related_form 09941383 87 | 00482893 _derivationally_related_form 00185104 88 | 00456740 _derivationally_related_form 05696020 89 | 00233335 _derivationally_related_form 15224293 90 | 10093908 _derivationally_related_form 02702830 91 | 09442838 _derivationally_related_form 00709625 92 | 03496892 _hypernym 03322940 
93 | 02646931 _derivationally_related_form 14442530 94 | 01834304 _derivationally_related_form 00410247 95 | 01531375 _also_see 01508719 96 | 00299580 _derivationally_related_form 07369604 97 | 00233335 _derivationally_related_form 10525134 98 | 00046534 _also_see 00044149 99 | 08592656 _hypernym 08512259 100 | 00087152 _also_see 01922763 101 | 02875013 _derivationally_related_form 01467370 102 | 02806907 _derivationally_related_form 06806469 103 | 02441022 _derivationally_related_form 09941964 104 | 00764902 _derivationally_related_form 01205827 105 | 14441825 _derivationally_related_form 00791227 106 | 00651991 _derivationally_related_form 05748285 107 | 07254057 _hypernym 07253637 108 | 01582645 _derivationally_related_form 03234306 109 | 02542280 _derivationally_related_form 00696518 110 | 04181228 _derivationally_related_form 01085474 111 | 00859325 _derivationally_related_form 04630689 112 | 09941571 _hypernym 09943541 113 | 03573282 _derivationally_related_form 00187526 114 | 13860793 _hypernym 00027807 115 | 00413876 _derivationally_related_form 01051331 116 | 07515560 _derivationally_related_form 01922763 117 | 08552138 _derivationally_related_form 02512150 118 | 00462092 _derivationally_related_form 01070892 119 | 00290740 _derivationally_related_form 00355252 120 | 09779790 _derivationally_related_form 00245457 121 | 01467370 _derivationally_related_form 08512736 122 | 01753596 _derivationally_related_form 03178782 123 | 01428853 _derivationally_related_form 10300303 124 | 07337390 _derivationally_related_form 01876907 125 | 00236289 _derivationally_related_form 04181228 126 | 01551871 _verb_group 01684337 127 | 00764902 _derivationally_related_form 00759551 128 | 09952163 _derivationally_related_form 01765392 129 | 09941571 _derivationally_related_form 00590626 130 | 02700104 _derivationally_related_form 04713118 131 | 01029852 _derivationally_related_form 07202579 132 | 05117660 _hypernym 05093890 133 | 00236289 _hypernym 00233335 134 | 10388440 
_derivationally_related_form 02539334 135 | 09779790 _derivationally_related_form 01104406 136 | 01782218 _hypernym 01780202 137 | 03051540 _derivationally_related_form 00050652 138 | 03792334 _derivationally_related_form 01660640 139 | 00661213 _derivationally_related_form 05748786 140 | 01146039 _derivationally_related_form 02553697 141 | 09812338 _derivationally_related_form 02991122 142 | 08612786 _hypernym 08512259 143 | 13971561 _derivationally_related_form 00764902 144 | 07519253 _derivationally_related_form 01779165 145 | 00100044 _derivationally_related_form 00893955 146 | 02204692 _derivationally_related_form 10389398 147 | 00169651 _hypernym 00170844 148 | 03670849 _has_part 02845576 149 | 07177437 _derivationally_related_form 00482893 150 | 05696020 _derivationally_related_form 00456740 151 | 06526291 _derivationally_related_form 10402417 152 | 06773976 _derivationally_related_form 01647867 153 | 13969243 _derivationally_related_form 02700104 154 | 03852280 _hypernym 03574816 155 | 05641959 _derivationally_related_form 00597385 156 | 04748836 _derivationally_related_form 00119524 157 | 01684337 _derivationally_related_form 04157320 158 | 00709625 _derivationally_related_form 00928077 159 | 00321956 _derivationally_related_form 00187526 160 | 00047745 _derivationally_related_form 03051540 161 | 04085873 _hypernym 03315644 162 | 04641153 _hypernym 04640927 163 | 07254057 _derivationally_related_form 01781180 164 | 05844105 _derivationally_related_form 01687569 165 | 10525134 _derivationally_related_form 01301051 166 | 14442530 _hypernym 14441825 167 | 00650016 _derivationally_related_form 10012815 168 | 00696518 _also_see 02564986 169 | 01697027 _derivationally_related_form 00898804 170 | 01301410 _derivationally_related_form 13998781 171 | 10388924 _derivationally_related_form 00809465 172 | 03779621 _derivationally_related_form 01697027 173 | 00152887 _derivationally_related_form 07355887 174 | 08677628 _derivationally_related_form 02695895 175 | 
01612053 _also_see 01123148 176 | 05162455 _hypernym 05161614 177 | 00044149 _verb_group 00044797 178 | 02539334 _derivationally_related_form 00791227 179 | 00040962 _hypernym 00040804 180 | 04051825 _derivationally_related_form 01128071 181 | 04905188 _derivationally_related_form 01026262 182 | 01320009 _derivationally_related_form 00921790 183 | 00354884 _derivationally_related_form 01815185 184 | 03721797 _derivationally_related_form 00921738 185 | 02064745 _derivationally_related_form 02666239 186 | 13491060 _derivationally_related_form 00527572 187 | 02928413 _derivationally_related_form 01498713 188 | 00224901 _derivationally_related_form 09476521 189 | -------------------------------------------------------------------------------- /data/WN18RR_v1_ind/valid.txt: -------------------------------------------------------------------------------- 1 | 09953178 _hypernym 09931640 2 | 01027263 _derivationally_related_form 00299580 3 | 03728811 _derivationally_related_form 01292885 4 | 04928903 _derivationally_related_form 01687569 5 | 01301051 _derivationally_related_form 10525134 6 | 09273291 _derivationally_related_form 02711114 7 | 13969700 _derivationally_related_form 01765392 8 | 00590626 _derivationally_related_form 09780828 9 | 13903079 _derivationally_related_form 01466978 10 | 00050652 _hypernym 00046534 11 | 10029068 _derivationally_related_form 00935940 12 | 10676877 _derivationally_related_form 02443049 13 | 02420232 _derivationally_related_form 10078806 14 | 01256157 _derivationally_related_form 10566072 15 | 00151689 _derivationally_related_form 05111835 16 | 04770911 _derivationally_related_form 01876907 17 | 00262703 _derivationally_related_form 03745285 18 | 00708017 _similar_to 00709625 19 | 02695895 _hypernym 02694933 20 | 03051540 _derivationally_related_form 00047745 21 | 01291069 _derivationally_related_form 00145218 22 | 13905792 _derivationally_related_form 01276361 23 | 03878963 _derivationally_related_form 01130607 24 | 00898804 
_derivationally_related_form 01743784 25 | 10448983 _derivationally_related_form 00752335 26 | 00082308 _derivationally_related_form 14445379 27 | 00921790 _derivationally_related_form 01320009 28 | 03792048 _derivationally_related_form 01660640 29 | 10525134 _derivationally_related_form 01301410 30 | 01711749 _derivationally_related_form 08664443 31 | 10525134 _derivationally_related_form 00233335 32 | 01693881 _derivationally_related_form 03104594 33 | 02928413 _hypernym 03600977 34 | 03104594 _derivationally_related_form 01693881 35 | 10529231 _derivationally_related_form 02204692 36 | 10689564 _derivationally_related_form 04160372 37 | 13454318 _derivationally_related_form 01742726 38 | 03265479 _hypernym 02875013 39 | 05844105 _derivationally_related_form 10155849 40 | 03932670 _hypernym 03932203 41 | 00083809 _synset_domain_topic_of 00612160 42 | 00151689 _derivationally_related_form 13458571 43 | 01069190 _verb_group 01069391 44 | 01612053 _derivationally_related_form 02542795 45 | 04644512 _derivationally_related_form 02564986 46 | 01647867 _derivationally_related_form 13970236 47 | 00321956 _derivationally_related_form 01580467 48 | 03257343 _derivationally_related_form 01735308 49 | 01410905 _derivationally_related_form 04750164 50 | 00919513 _derivationally_related_form 01567275 51 | 00043683 _derivationally_related_form 02728440 52 | 00764902 _derivationally_related_form 01026262 53 | 04748836 _derivationally_related_form 00651991 54 | 10160412 _derivationally_related_form 00482473 55 | 01739814 _derivationally_related_form 00916464 56 | 13998781 _derivationally_related_form 01301410 57 | 01363613 _also_see 01148283 58 | 03282060 _has_part 04085873 59 | 03792048 _derivationally_related_form 01660386 60 | 00730301 _derivationally_related_form 08512736 61 | 13085864 _derivationally_related_form 01741446 62 | 06998748 _derivationally_related_form 09812338 63 | 00933566 _derivationally_related_form 05117660 64 | 07369604 _derivationally_related_form 
00299580 65 | 05902327 _derivationally_related_form 01743784 66 | 07369604 _derivationally_related_form 00300537 67 | 01151110 _hypernym 01987160 68 | 01135529 _derivationally_related_form 02443049 69 | 00696882 _derivationally_related_form 00082714 70 | 00233335 _derivationally_related_form 05846355 71 | 09941964 _derivationally_related_form 00751887 72 | 01743784 _derivationally_related_form 05902327 73 | 00084230 _derivationally_related_form 00612160 74 | 01020936 _hypernym 01019524 75 | 10529231 _derivationally_related_form 02203362 76 | 03777283 _derivationally_related_form 01697406 77 | 01167146 _derivationally_related_form 02542795 78 | 02657219 _derivationally_related_form 03728811 79 | 03327234 _derivationally_related_form 01588134 80 | 01020936 _derivationally_related_form 01742886 81 | 10668450 _hypernym 10525134 82 | 00796047 _derivationally_related_form 06893885 83 | 04613158 _derivationally_related_form 01492052 84 | 04905842 _derivationally_related_form 02388145 85 | 00765213 _derivationally_related_form 09800249 86 | 04463273 _hypernym 03234306 87 | 04930307 _derivationally_related_form 02659763 88 | 01662771 _derivationally_related_form 03779370 89 | 13998576 _derivationally_related_form 02711114 90 | 01813884 _derivationally_related_form 07527352 91 | 03386011 _derivationally_related_form 01606205 92 | 01052853 _derivationally_related_form 01613239 93 | 14442530 _derivationally_related_form 02646931 94 | 01640550 _derivationally_related_form 09972157 95 | 00267349 _hypernym 00266806 96 | 01711749 _derivationally_related_form 08677628 97 | 10155849 _derivationally_related_form 05844105 98 | 01159964 _derivationally_related_form 01687569 99 | 01662771 _derivationally_related_form 00909899 100 | 01780941 _derivationally_related_form 01222666 101 | 13913566 _hypernym 13860793 102 | 00761713 _derivationally_related_form 10351874 103 | 00909363 _also_see 01149494 104 | 01711749 _derivationally_related_form 05075602 105 | 02899439 _hypernym 03673971 106 
| 07369604 _derivationally_related_form 00482893 107 | 01248191 _derivationally_related_form 02502536 108 | 05902327 _derivationally_related_form 01683582 109 | 10317007 _hypernym 10582746 110 | 00764902 _derivationally_related_form 07151122 111 | 03322099 _derivationally_related_form 02420232 112 | 02388145 _derivationally_related_form 04905842 113 | 00482473 _hypernym 00296178 114 | 07527352 _derivationally_related_form 01363613 115 | 01765392 _derivationally_related_form 01151407 116 | 00730499 _derivationally_related_form 08592656 117 | 01051331 _derivationally_related_form 01496630 118 | 14441825 _derivationally_related_form 02646931 119 | 00047317 _derivationally_related_form 00795008 120 | 02512305 _derivationally_related_form 10012815 121 | 07537068 _hypernym 07532440 122 | 10093908 _derivationally_related_form 00300537 123 | 05937112 _derivationally_related_form 02723733 124 | 04433185 _derivationally_related_form 01285440 125 | 00651991 _derivationally_related_form 07270179 126 | 02389346 _derivationally_related_form 00145218 127 | 01819554 _derivationally_related_form 01222477 128 | 00409211 _derivationally_related_form 02443849 129 | 00083809 _derivationally_related_form 00671351 130 | 02671279 _derivationally_related_form 13321495 131 | 01224744 _derivationally_related_form 10378412 132 | 01291069 _hypernym 01354673 133 | 10378780 _hypernym 09882007 134 | 09476521 _derivationally_related_form 00290740 135 | 03285912 _derivationally_related_form 02711114 136 | 00150287 _derivationally_related_form 09957614 137 | 05765415 _derivationally_related_form 02806907 138 | 00915722 _hypernym 00913705 139 | 01190884 _hypernym 01187810 140 | 13427078 _derivationally_related_form 00150287 141 | 06791372 _derivationally_related_form 02296984 142 | 01922763 _also_see 01740892 143 | 00119074 _derivationally_related_form 04748836 144 | 07066659 _derivationally_related_form 10155849 145 | 02003725 _derivationally_related_form 00236592 146 | 00916464 _has_part 00921790 
147 | 01765392 _derivationally_related_form 07515790 148 | 01222477 _derivationally_related_form 01819554 149 | 00364479 _also_see 01368192 150 | 01684337 _verb_group 01551871 151 | 01148283 _also_see 00999817 152 | 01340439 _derivationally_related_form 00147595 153 | 01922763 _derivationally_related_form 07515560 154 | 00815644 _derivationally_related_form 01150559 155 | 00462092 _derivationally_related_form 04361641 156 | 01876907 _derivationally_related_form 00348571 157 | 07366627 _hypernym 07366289 158 | 05846355 _derivationally_related_form 00235368 159 | 09956578 _derivationally_related_form 00462092 160 | 00650016 _derivationally_related_form 05748054 161 | 03091374 _derivationally_related_form 01354673 162 | 00751887 _hypernym 02539334 163 | 02840361 _hypernym 03496892 164 | 00300537 _derivationally_related_form 04930307 165 | 08592656 _derivationally_related_form 00730499 166 | 01624568 _derivationally_related_form 13913566 167 | 04630689 _derivationally_related_form 00859153 168 | 02991122 _derivationally_related_form 02743547 169 | 01148283 _also_see 00362467 170 | 06003682 _hypernym 06000644 171 | 05198036 _derivationally_related_form 10388440 172 | 02372326 _derivationally_related_form 00040152 173 | 00795008 _derivationally_related_form 02659763 174 | 10645611 _hypernym 10676877 175 | 07527352 _derivationally_related_form 01813884 176 | 03779370 _derivationally_related_form 01697027 177 | 13489037 _derivationally_related_form 00245457 178 | 01051331 _derivationally_related_form 02333689 179 | 02991122 _derivationally_related_form 09812338 180 | 10689564 _hypernym 10120816 181 | 05844105 _derivationally_related_form 01666894 182 | 02539359 _derivationally_related_form 00047745 183 | 03932203 _derivationally_related_form 01656788 184 | 00696518 _also_see 01612053 185 | 01492052 _derivationally_related_form 04612840 186 | -------------------------------------------------------------------------------- /data/nell_v1_ind/test.txt: 
-------------------------------------------------------------------------------- 1 | concept:televisionstation:wqec concept:subpartof concept:company:pbs 2 | concept:televisionstation:kusm concept:agentbelongstoorganization concept:company:pbs 3 | concept:televisionstation:kqsd_tv concept:agentcollaborateswithagent concept:company:pbs 4 | concept:company:pbs concept:agentcontrols concept:televisionstation:kufm_tv 5 | concept:company:pbs concept:acquired concept:company:kues 6 | concept:company:pbs concept:agentcontrols concept:televisionstation:kyne_tv 7 | concept:company:pbs concept:agentcontrols concept:televisionstation:kcos_tv 8 | concept:televisionstation:wxxi_tv concept:agentbelongstoorganization concept:company:pbs 9 | concept:televisionstation:wvta concept:subpartof concept:company:pbs 10 | concept:televisionstation:wunp_tv concept:agentcollaborateswithagent concept:company:pbs 11 | concept:televisionstation:koac_tv concept:agentbelongstoorganization concept:company:pbs 12 | concept:televisionstation:wsec concept:subpartof concept:company:pbs 13 | concept:televisionstation:wvta concept:agentbelongstoorganization concept:company:pbs 14 | concept:televisionstation:kepb_tv concept:agentbelongstoorganization concept:company:pbs 15 | concept:televisionstation:wedn concept:subpartof concept:company:pbs 16 | concept:televisionstation:wunf_tv concept:agentbelongstoorganization concept:company:pbs 17 | concept:televisionstation:ktwu concept:agentbelongstoorganization concept:company:pbs 18 | concept:televisionstation:kood concept:televisionstationaffiliatedwith concept:televisionnetwork:pbs 19 | concept:televisionstation:wvut concept:subpartof concept:company:pbs 20 | concept:televisionstation:kfts concept:subpartof concept:company:pbs 21 | concept:televisionstation:wunm_tv concept:agentcollaborateswithagent concept:company:pbs 22 | concept:televisionstation:whla_tv concept:televisionstationaffiliatedwith concept:televisionnetwork:pbs 23 | 
concept:televisionstation:whtj concept:agentbelongstoorganization concept:company:pbs 24 | concept:company:pbs concept:agentcontrols concept:televisionstation:wmed_tv 25 | concept:televisionstation:kpsd_tv concept:subpartof concept:company:pbs 26 | concept:televisionstation:wkyu_tv concept:agentcollaborateswithagent concept:company:pbs 27 | concept:televisionstation:wkon concept:agentcollaborateswithagent concept:company:pbs 28 | concept:televisionstation:kuid_tv concept:subpartof concept:company:pbs 29 | concept:televisionstation:ktsd_tv concept:agentcollaborateswithagent concept:company:pbs 30 | concept:televisionstation:wgiq concept:agentcollaborateswithagent concept:company:pbs 31 | concept:televisionstation:wlef_tv concept:subpartof concept:company:pbs 32 | concept:televisionstation:whro_tv concept:subpartof concept:company:pbs 33 | concept:televisionstation:wung_tv concept:subpartoforganization concept:televisionnetwork:pbs 34 | concept:televisionstation:wkoh concept:agentbelongstoorganization concept:company:pbs 35 | concept:televisionstation:wetp_tv concept:agentbelongstoorganization concept:company:pbs 36 | concept:televisionstation:wfiq concept:agentcollaborateswithagent concept:company:pbs 37 | concept:televisionstation:kopb_tv concept:subpartof concept:company:pbs 38 | concept:televisionstation:wfiq concept:agentbelongstoorganization concept:company:pbs 39 | concept:televisionstation:wmpt concept:subpartof concept:company:pbs 40 | concept:televisionstation:kbyu_tv concept:agentcollaborateswithagent concept:company:pbs 41 | concept:televisionstation:wmeb_tv concept:subpartof concept:company:pbs 42 | concept:company:pbs concept:agentcontrols concept:televisionstation:krma_tv 43 | concept:televisionstation:wntv concept:agentbelongstoorganization concept:company:pbs 44 | concept:televisionstation:ktsc_tv concept:subpartoforganization concept:televisionnetwork:pbs 45 | concept:televisionstation:wmae_tv concept:televisionstationaffiliatedwith 
concept:televisionnetwork:pbs 46 | concept:televisionstation:wbcc_tv concept:agentbelongstoorganization concept:company:pbs 47 | concept:televisionstation:kcet concept:televisionstationaffiliatedwith concept:televisionnetwork:pbs 48 | concept:televisionstation:kwse concept:agentbelongstoorganization concept:company:pbs 49 | concept:televisionstation:ktci_tv concept:agentbelongstoorganization concept:company:pbs 50 | concept:televisionstation:wmae_tv concept:agentcollaborateswithagent concept:company:pbs 51 | concept:televisionstation:wunj_tv concept:televisionstationaffiliatedwith concept:televisionnetwork:pbs 52 | concept:televisionstation:kufm_tv concept:agentbelongstoorganization concept:company:pbs 53 | concept:televisionstation:kufm_tv concept:subpartoforganization concept:televisionnetwork:pbs 54 | concept:televisionstation:wpbt concept:subpartof concept:company:pbs 55 | concept:televisionstation:whmc concept:subpartoforganization concept:televisionnetwork:pbs 56 | concept:company:pbs concept:agentcontrols concept:televisionstation:kcts_tv 57 | concept:televisionstation:whtj concept:subpartoforganization concept:televisionnetwork:pbs 58 | concept:company:pbs concept:agentcontrols concept:televisionstation:wlpb_tv 59 | concept:televisionstation:wkno_tv concept:agentbelongstoorganization concept:company:pbs 60 | concept:televisionstation:ktsd_tv concept:agentbelongstoorganization concept:company:pbs 61 | concept:televisionstation:wviz_tv concept:agentbelongstoorganization concept:company:pbs 62 | concept:company:pbs concept:agentcontrols concept:televisionstation:wgvu_tv 63 | concept:company:pbs concept:agentcontrols concept:televisionstation:wbiq_tv 64 | concept:company:pbs concept:agentcontrols concept:televisionstation:woub_tv 65 | concept:televisionstation:wenh concept:televisionstationaffiliatedwith concept:televisionnetwork:pbs 66 | concept:televisionstation:kbyu_tv concept:subpartof concept:company:pbs 67 | concept:televisionstation:wunm_tv 
concept:agentbelongstoorganization concept:company:pbs 68 | concept:televisionstation:klru concept:televisionstationaffiliatedwith concept:televisionnetwork:pbs 69 | concept:televisionstation:wgby_tv concept:subpartof concept:company:pbs 70 | concept:company:pbs concept:agentcontrols concept:televisionstation:wdse_tv 71 | concept:company:pbs concept:agentcontrols concept:televisionstation:wmpt 72 | concept:televisionstation:kacv_tv concept:subpartof concept:company:pbs 73 | concept:televisionstation:kaet concept:subpartof concept:company:pbs 74 | concept:televisionstation:ketc concept:agentcollaborateswithagent concept:company:pbs 75 | concept:televisionstation:kbhe_tv concept:subpartof concept:company:pbs 76 | concept:televisionstation:kwse concept:televisionstationaffiliatedwith concept:televisionnetwork:pbs 77 | concept:televisionstation:krwg_tv concept:subpartof concept:company:pbs 78 | concept:televisionstation:weta_tv concept:subpartof concept:company:pbs 79 | concept:company:pbs concept:agentcontrols concept:televisionstation:kisu_tv 80 | concept:televisionstation:wfwa concept:subpartof concept:company:pbs 81 | concept:televisionstation:weta_tv concept:televisionstationaffiliatedwith concept:televisionnetwork:pbs 82 | concept:televisionstation:kcwc_tv concept:subpartof concept:company:pbs 83 | concept:televisionstation:ktne_tv concept:televisionstationaffiliatedwith concept:televisionnetwork:pbs 84 | concept:company:pbs concept:agentcontrols concept:televisionstation:wmsy_tv 85 | concept:company:pbs concept:agentcontrols concept:televisionstation:kopb_tv 86 | concept:televisionstation:wund_tv concept:subpartof concept:company:pbs 87 | concept:company:pbs concept:agentcontrols concept:televisionstation:kamu_tv 88 | concept:company:pbs concept:agentcontrols concept:televisionstation:wkzt_tv 89 | concept:company:pbs concept:agentcontrols concept:televisionstation:wpby_tv 90 | concept:televisionstation:kedt concept:televisionstationaffiliatedwith 
concept:televisionnetwork:pbs 91 | concept:televisionstation:ketg concept:subpartof concept:company:pbs 92 | concept:televisionstation:whro_tv concept:agentcollaborateswithagent concept:company:pbs 93 | concept:televisionstation:wneo concept:agentcollaborateswithagent concept:company:pbs 94 | concept:televisionstation:wpbt concept:subpartoforganization concept:televisionnetwork:pbs 95 | concept:televisionstation:kcwc_tv concept:agentbelongstoorganization concept:company:pbs 96 | concept:televisionstation:kopb_tv concept:televisionstationaffiliatedwith concept:televisionnetwork:pbs 97 | concept:televisionstation:wlrn_tv concept:agentcollaborateswithagent concept:company:pbs 98 | concept:televisionstation:weta_tv concept:agentcollaborateswithagent concept:company:pbs 99 | concept:televisionstation:kera_tv concept:agentbelongstoorganization concept:company:pbs 100 | concept:company:pbs concept:agentcontrols concept:televisionstation:wkha 101 | -------------------------------------------------------------------------------- /data/nell_v1_ind/valid.txt: -------------------------------------------------------------------------------- 1 | concept:televisionstation:wune_tv concept:agentbelongstoorganization concept:company:pbs 2 | concept:televisionstation:wsbn_tv concept:subpartof concept:company:pbs 3 | concept:televisionstation:wnin_tv concept:subpartof concept:company:pbs 4 | concept:televisionstation:wgte_tv concept:subpartoforganization concept:televisionnetwork:pbs 5 | concept:televisionstation:wedn concept:subpartoforganization concept:televisionnetwork:pbs 6 | concept:televisionstation:wmed_tv concept:subpartoforganization concept:televisionnetwork:pbs 7 | concept:televisionstation:wfiq concept:subpartoforganization concept:televisionnetwork:pbs 8 | concept:televisionstation:wbra_tv concept:agentbelongstoorganization concept:company:pbs 9 | concept:televisionstation:wkpi concept:subpartoforganization concept:televisionnetwork:pbs 10 | 
concept:televisionstation:wmsy_tv concept:subpartof concept:company:pbs 11 | concept:televisionstation:krne_tv concept:agentcollaborateswithagent concept:company:pbs 12 | concept:televisionstation:kteh concept:agentbelongstoorganization concept:company:pbs 13 | concept:televisionstation:wkpc_tv concept:televisionstationaffiliatedwith concept:televisionnetwork:pbs 14 | concept:televisionstation:wyes_tv concept:subpartof concept:company:pbs 15 | concept:televisionstation:wnin_tv concept:agentbelongstoorganization concept:company:pbs 16 | concept:televisionstation:wkon concept:agentbelongstoorganization concept:company:pbs 17 | concept:company:pbs concept:agentcontrols concept:televisionstation:kmbh_tv 18 | concept:company:pbs concept:agentcontrols concept:televisionstation:wcfe_tv 19 | concept:televisionstation:wntv concept:televisionstationaffiliatedwith concept:televisionnetwork:pbs 20 | concept:televisionstation:wmec concept:agentbelongstoorganization concept:company:pbs 21 | concept:company:pbs concept:agentcontrols concept:televisionstation:kixe_tv 22 | concept:televisionstation:kusd_tv concept:agentbelongstoorganization concept:company:pbs 23 | concept:company:pbs concept:agentcontrols concept:televisionstation:wund_tv 24 | concept:televisionstation:wlvt_tv concept:agentbelongstoorganization concept:company:pbs 25 | concept:televisionstation:kcsd_tv concept:subpartof concept:company:pbs 26 | concept:televisionstation:wmea_tv concept:subpartof concept:company:pbs 27 | concept:televisionstation:witf_tv concept:agentcollaborateswithagent concept:company:pbs 28 | concept:televisionstation:kued concept:agentbelongstoorganization concept:company:pbs 29 | concept:televisionstation:wnpb_tv concept:subpartoforganization concept:televisionnetwork:pbs 30 | concept:televisionstation:wunf_tv concept:subpartof concept:company:pbs 31 | concept:televisionstation:kcdt_tv concept:subpartof concept:company:pbs 32 | concept:company:pbs concept:agentcontrols 
concept:televisionstation:krwg_tv 33 | concept:televisionstation:wlef_tv concept:agentcollaborateswithagent concept:company:pbs 34 | concept:televisionstation:kozk concept:subpartoforganization concept:televisionnetwork:pbs 35 | concept:televisionstation:wkyu_tv concept:subpartoforganization concept:televisionnetwork:pbs 36 | concept:televisionstation:ketc concept:televisionstationaffiliatedwith concept:televisionnetwork:pbs 37 | concept:televisionstation:wbiq_tv concept:subpartof concept:company:pbs 38 | concept:televisionstation:kcos_tv concept:subpartof concept:company:pbs 39 | concept:company:pbs concept:agentcontrols concept:televisionstation:kwet 40 | concept:company:pbs concept:agentcontrols concept:televisionstation:ktci_tv 41 | concept:company:pbs concept:agentcontrols concept:televisionstation:wntv 42 | concept:televisionstation:wkha concept:agentcollaborateswithagent concept:company:pbs 43 | concept:company:pbs concept:agentcontrols concept:televisionstation:wnsc_tv 44 | concept:televisionstation:wtiu_tv concept:agentbelongstoorganization concept:company:pbs 45 | concept:televisionstation:wgbx_tv concept:subpartoforganization concept:televisionnetwork:pbs 46 | concept:televisionstation:wkmu concept:subpartof concept:company:pbs 47 | concept:televisionstation:wyes_tv concept:agentbelongstoorganization concept:company:pbs 48 | concept:televisionstation:wkpi concept:subpartof concept:company:pbs 49 | concept:televisionstation:wunl_tv concept:televisionstationaffiliatedwith concept:televisionnetwork:pbs 50 | concept:company:pbs concept:agentcontrols concept:televisionstation:wpto 51 | concept:televisionstation:kcts_tv concept:televisionstationaffiliatedwith concept:televisionnetwork:pbs 52 | concept:televisionstation:wouc_tv concept:agentbelongstoorganization concept:company:pbs 53 | concept:televisionstation:ksps_tv concept:subpartof concept:company:pbs 54 | concept:televisionstation:kbyu_tv concept:televisionstationaffiliatedwith 
concept:televisionnetwork:pbs 55 | concept:televisionstation:wung_tv concept:agentbelongstoorganization concept:company:pbs 56 | concept:televisionstation:wnjs concept:agentcollaborateswithagent concept:company:pbs 57 | concept:televisionstation:wha__tv concept:subpartof concept:company:pbs 58 | concept:televisionstation:wunu concept:agentbelongstoorganization concept:company:pbs 59 | concept:televisionstation:wkso_tv concept:agentcollaborateswithagent concept:company:pbs 60 | concept:televisionstation:klrn_tv concept:televisionstationaffiliatedwith concept:televisionnetwork:pbs 61 | concept:televisionstation:krwg_tv concept:agentcollaborateswithagent concept:company:pbs 62 | concept:televisionstation:wqec concept:agentcollaborateswithagent concept:company:pbs 63 | concept:company:pbs concept:agentcontrols concept:televisionstation:kbyu_tv 64 | concept:televisionstation:wptd concept:agentbelongstoorganization concept:company:pbs 65 | concept:company:pbs concept:agentcontrols concept:televisionstation:wvpy 66 | concept:company:pbs concept:agentcontrols concept:televisionstation:wvut 67 | concept:televisionstation:wgby_tv concept:subpartoforganization concept:televisionnetwork:pbs 68 | concept:televisionstation:wwpb concept:subpartoforganization concept:televisionnetwork:pbs 69 | concept:televisionstation:wund_tv concept:agentcollaborateswithagent concept:company:pbs 70 | concept:company:pbs concept:agentcontrols concept:televisionstation:wmvs_tv 71 | concept:televisionstation:kepb_tv concept:agentcollaborateswithagent concept:company:pbs 72 | concept:company:pbs concept:agentcontrols concept:televisionstation:kaft 73 | concept:televisionstation:wnjt concept:agentbelongstoorganization concept:company:pbs 74 | concept:company:pbs concept:agentcontrols concept:televisionstation:wswp_tv 75 | concept:televisionstation:kqsd_tv concept:televisionstationaffiliatedwith concept:televisionnetwork:pbs 76 | concept:televisionstation:kmos_tv concept:agentbelongstoorganization 
concept:company:pbs 77 | concept:televisionstation:wpbo_tv concept:subpartof concept:company:pbs 78 | concept:televisionstation:kltm_tv concept:subpartoforganization concept:televisionnetwork:pbs 79 | concept:company:pbs concept:agentcontrols concept:televisionstation:kpne_tv 80 | concept:televisionstation:wmau_tv concept:subpartof concept:company:pbs 81 | concept:televisionstation:kued concept:subpartoforganization concept:televisionnetwork:pbs 82 | concept:company:pbs concept:agentcontrols concept:televisionstation:wkpi 83 | concept:company:pbs concept:agentcontrols concept:televisionstation:wune_tv 84 | concept:televisionstation:kltm_tv concept:agentbelongstoorganization concept:company:pbs 85 | concept:televisionstation:wmaw_tv concept:agentcollaborateswithagent concept:company:pbs 86 | concept:televisionstation:wunj_tv concept:agentbelongstoorganization concept:company:pbs 87 | concept:televisionstation:kuht concept:agentbelongstoorganization concept:company:pbs 88 | concept:televisionstation:wunf_tv concept:subpartoforganization concept:televisionnetwork:pbs 89 | concept:televisionstation:wvpy concept:subpartof concept:company:pbs 90 | concept:televisionstation:wmsy_tv concept:agentcollaborateswithagent concept:company:pbs 91 | concept:televisionstation:kwet concept:subpartoforganization concept:televisionnetwork:pbs 92 | concept:televisionstation:wmau_tv concept:subpartoforganization concept:televisionnetwork:pbs 93 | concept:televisionstation:wund_tv concept:subpartoforganization concept:televisionnetwork:pbs 94 | concept:televisionstation:wmaw_tv concept:televisionstationaffiliatedwith concept:televisionnetwork:pbs 95 | concept:televisionstation:wnjn concept:agentcollaborateswithagent concept:company:pbs 96 | concept:televisionstation:wmah_tv concept:subpartof concept:company:pbs 97 | concept:company:pbs concept:agentcontrols concept:televisionstation:wmpb_tv 98 | concept:televisionstation:ksmq_tv concept:subpartof concept:company:pbs 99 | 
import argparse
import os


import torch
import torch.nn as nn
import torch.optim as optim


def read_scores(path):
    '''
    Return the score (last whitespace-separated token) of every non-empty line.

    Uses splitlines() so the final record is kept even when the file lacks a
    trailing newline; the previous split-on-newline slicing silently dropped it.
    '''
    with open(path) as f:
        return [float(line.split()[-1]) for line in f.read().splitlines() if line.strip()]


def get_triplets(path):
    '''
    Return the triplet part (all tokens except the trailing score) of every
    non-empty line.
    '''
    with open(path) as f:
        return [line.split()[:-1] for line in f.read().splitlines() if line.strip()]


def train(params):
    '''
    Train and save a linear layer model.

    Reads the positive/negative validation predictions of both ensemble
    members, fits a 2->1 linear blending layer with a margin ranking loss and
    saves it under ../experiments.
    '''
    ens_model_1_pos_scores_path = os.path.join('../data/{}/{}_valid_predictions.txt'.format(params.dataset, params.ensemble_model_1))
    ens_model_1_neg_scores_path = os.path.join('../data/{}/{}_neg_valid_0_predictions.txt'.format(params.dataset, params.ensemble_model_1))
    ens_model_2_pos_scores_path = os.path.join('../data/{}/{}_valid_predictions.txt'.format(params.dataset, params.ensemble_model_2))
    ens_model_2_neg_scores_path = os.path.join('../data/{}/{}_neg_valid_0_predictions.txt'.format(params.dataset, params.ensemble_model_2))

    # Both models must have scored exactly the same triplets in the same order.
    assert get_triplets(ens_model_1_pos_scores_path) == get_triplets(ens_model_2_pos_scores_path)
    assert get_triplets(ens_model_1_neg_scores_path) == get_triplets(ens_model_2_neg_scores_path)

    pos_scores = torch.Tensor(list(zip(read_scores(ens_model_1_pos_scores_path), read_scores(ens_model_2_pos_scores_path))))
    neg_scores = torch.Tensor(list(zip(read_scores(ens_model_1_neg_scores_path), read_scores(ens_model_2_neg_scores_path))))

    model = nn.Linear(in_features=2, out_features=1)
    criterion = nn.MarginRankingLoss(10, reduction='sum')
    optimizer = optim.Adam(model.parameters(), lr=0.1, weight_decay=5e-4)

    for e in range(params.num_epochs):
        pos_out = model(pos_scores).view(-1)
        neg_out = model(neg_scores).view(-1)
        # Average the negatives paired with each positive (assumes the negative
        # file holds a whole multiple of len(pos_out) rows — TODO confirm).
        neg_mean = neg_out.view(len(pos_out), -1).mean(dim=1)

        # Inputs and target are all 1-D of length len(pos_out).  The original
        # compared a (P, 1) input against a (P,) one, which broadcast into a
        # (P, P) pairwise loss instead of the intended per-triplet margin.
        loss = criterion(pos_out, neg_mean, torch.ones_like(pos_out))
        print('Loss at epoch {} : {}'.format(e, loss))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Make sure the output directory exists before saving.
    os.makedirs('../experiments', exist_ok=True)
    torch.save(model, os.path.join('../experiments', f'{params.ensemble_model_1}_{params.ensemble_model_2}_{params.dataset}_ensemble.pth'))


def score_triplets(params):
    '''
    Load the saved model and save scores of given set of triplets.
    '''
    print('Loading model..')
    model = torch.load(os.path.join('../experiments', f'{params.ensemble_model_1}_{params.ensemble_model_2}_{params.dataset}_ensemble.pth'))
    print('Model loaded successfully!')

    ens_model_1_scores_path = os.path.join('../data/{}/{}_{}_predictions.txt'.format(params.dataset, params.ensemble_model_1, params.file_to_score))
    ens_model_2_scores_path = os.path.join('../data/{}/{}_{}_predictions.txt'.format(params.dataset, params.ensemble_model_2, params.file_to_score))

    scores = torch.Tensor(list(zip(read_scores(ens_model_1_scores_path), read_scores(ens_model_2_scores_path))))
    # Pure inference: no autograd graph needed.
    model.eval()
    with torch.no_grad():
        ens_scores = model(scores)

    ens_model_1_triplets = get_triplets(ens_model_1_scores_path)
    ens_model_2_triplets = get_triplets(ens_model_2_scores_path)

    assert ens_model_1_triplets == ens_model_2_triplets

    file_path = os.path.join('../', 'data/{}/{}_with_{}_{}_predictions.txt'.format(params.dataset, params.ensemble_model_1, params.ensemble_model_2, params.file_to_score))
    with open(file_path, "w") as f:
        for ([s, r, o], score) in zip(ens_model_1_triplets, ens_scores):
            f.write('\t'.join([s, r, o, str(score.item())]) + '\n')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Model blender script')

    parser.add_argument('--dataset', '-d', type=str, default='Toy')
    parser.add_argument('--ensemble_model_1', '-em1', default='grail', type=str)
    parser.add_argument('--ensemble_model_2', '-em2', default='TransE', type=str)
    parser.add_argument('--do_train', action='store_true')
    parser.add_argument("--num_epochs", "-ne", type=int, default=500,
                        help="Number of training iterations")
    parser.add_argument('--do_scoring', action='store_true')
    parser.add_argument('--file_to_score', '-f', default='valid', type=str)

    params = parser.parse_args()

    if params.do_train:
        train(params)
    elif params.do_scoring:
        score_triplets(params)
def get_ranks(scores, num_neg=50):
    '''
    Given scores of head/tail substituted triplets, return ranks of each triplet.

    ``scores`` is a flat list made of consecutive groups of ``num_neg``
    entries, where index 0 of every group is the true triplet (GraIL's
    evaluation layout).  ``num_neg`` used to be hard-coded to 50; it is now a
    parameter with that same default, so existing call sites are unaffected.

    With method='min' on the ascending ranks, tied scores all receive the
    worst (largest) rank from the top, i.e. ties are resolved conservatively.
    '''
    ranks = []
    for i in range(len(scores) // num_neg):
        group = scores[num_neg * i: num_neg * (i + 1)]
        # rankdata ranks ascending; convert to the 1-based descending rank of
        # the true triplet (group index 0).
        ranks.append(num_neg - rankdata(group, method='min')[0] + 1)
    return ranks
30 | with open('../data/{}/{}_ranking_head_predictions.txt'.format(params.dataset, params.model)) as f: 31 | head_scores = [float(line.split()[-1]) for line in f.read().split('\n')[:-1]] 32 | with open('../data/{}/{}_ranking_tail_predictions.txt'.format(params.dataset, params.model)) as f: 33 | tail_scores = [float(line.split()[-1]) for line in f.read().split('\n')[:-1]] 34 | 35 | # compute both ranks from the prediction scores 36 | head_ranks = get_ranks(head_scores) 37 | tail_ranks = get_ranks(tail_scores) 38 | 39 | ranks = head_ranks + tail_ranks 40 | 41 | isHit1List = [x for x in ranks if x <= 1] 42 | isHit5List = [x for x in ranks if x <= 5] 43 | isHit10List = [x for x in ranks if x <= 10] 44 | hits_1 = len(isHit1List) / len(ranks) 45 | hits_5 = len(isHit5List) / len(ranks) 46 | hits_10 = len(isHit10List) / len(ranks) 47 | 48 | mrr = np.mean(1 / np.array(ranks)) 49 | 50 | with open('../data/{}/{}_ranking_metrics.txt'.format(params.dataset, params.model), "w") as f: 51 | f.write(f'MRR | Hits@1 | Hits@5 | Hits@10 : {mrr} | {hits_1} | {hits_5} | {hits_10}\n') 52 | -------------------------------------------------------------------------------- /ensembling/get_ensemble_predictions.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # This script assumes GraIL predection scores on the validation and test set are already saved. 4 | # It also assumes that scored head/tail replaced triplets are also stored. 5 | # If any of those is not present, run the corresponding script from the following setup commands. 
6 | ##################### SET UP ##################### 7 | # python test_auc.py -d WN18RR -e saved_grail_exp_name --hop 3 -t valid 8 | # python test_auc.py -d WN18RR -e saved_grail_exp_name --hop 3 -t test 9 | 10 | # python test_auc.py -d NELL-995 -e saved_grail_exp_name --hop 2 -t valid 11 | # python test.py -d NELL-995 -e saved_grail_exp_name --hop 2 -t test 12 | 13 | # python test_auc.py -d FB15K237 -e saved_grail_exp_name --hop 1 -t valid 14 | # python test_auc.py -d FB15K237 -e saved_grail_exp_name --hop 1 -t test 15 | 16 | # python test_ranking.py -d WN18RR -e saved_grail_exp_name --hop 3 17 | 18 | # python test_ranking.py -d NELL-995 -e saved_grail_exp_name --hop 2 19 | 20 | # python test_ranking.py -d FB15K237 -e saved_grail_exp_name --hop 1 21 | ################################################## 22 | 23 | 24 | # Arguments 25 | # Dataset 26 | DATASET=$1 27 | # KGE model to be used in ensemble 28 | KGE_MODEL=$2 29 | KGE_SAVED_MODEL_PATH="../experiments/kge_baselines/${KGE_MODEL}_${DATASET}" 30 | 31 | # score pos validation triplets with KGE model 32 | python score_triplets_kge.py -d $DATASET --model $KGE_MODEL -f valid -init $KGE_SAVED_MODEL_PATH 33 | # score neg validation triplets with KGE model 34 | python score_triplets_kge.py -d $DATASET --model $KGE_MODEL -f neg_valid_0 -init $KGE_SAVED_MODEL_PATH 35 | 36 | # train the ensemble model 37 | python blend.py -d $DATASET -em2 $KGE_MODEL --do_train -ne 500 38 | 39 | # Score the test pos and neg triplets with KGE model 40 | python score_triplets_kge.py -d $DATASET --model $KGE_MODEL -f test -init $KGE_SAVED_MODEL_PATH 41 | python score_triplets_kge.py -d $DATASET --model $KGE_MODEL -f neg_test_0 -init $KGE_SAVED_MODEL_PATH 42 | # Score the test pos and neg triplets with ensemble model 43 | python blend.py -d $DATASET -em2 $KGE_MODEL --do_scoring -f test 44 | python blend.py -d $DATASET -em2 $KGE_MODEL --do_scoring -f neg_test_0 45 | # Compute auc with the ensemble model scored pos and neg test files 46 | 
python compute_auc.py -d $DATASET -m grail_with_${KGE_MODEL} 47 | # Compute auc with the KGE model model scored pos and neg test files 48 | python compute_auc.py -d $DATASET -m $KGE_MODEL 49 | 50 | # Score head/tail replaced samples with KGE model 51 | python score_triplets_kge.py -d $DATASET --model $KGE_MODEL -f ranking_head -init $KGE_SAVED_MODEL_PATH 52 | python score_triplets_kge.py -d $DATASET --model $KGE_MODEL -f ranking_tail -init $KGE_SAVED_MODEL_PATH 53 | # Score head/tail replaced samples with ensemble model 54 | python blend.py -d $DATASET -em2 $KGE_MODEL --do_scoring -f ranking_head 55 | python blend.py -d $DATASET -em2 $KGE_MODEL --do_scoring -f ranking_tail 56 | # Compute ranking metrics for ensemble model with the scored head/tail replaced samples 57 | python compute_rank_metrics.py -d $DATASET -m grail_with_${KGE_MODEL} 58 | # Compute ranking metrics for KGE model with the scored head/tail replaced samples 59 | python compute_rank_metrics.py -d $DATASET -m $KGE_MODEL -------------------------------------------------------------------------------- /ensembling/get_kge_ensemble.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # This script assumes that head/tail replaced negative triplets are already stored while evaluating GraIL. 4 | # This assumptionn is made in order to make fair evaluations of all the methods on the same negative samples. 5 | # If any of those is not present, run the corresponding script from the following setup commands. These will 6 | # evaluate GraIL and savee thee negative samples along the way. 
7 | ##################### SET UP ##################### 8 | # python test_auc.py -d WN18RR -e saved_grail_exp_name --hop 3 -t valid 9 | # python test_auc.py -d WN18RR -e saved_grail_exp_name --hop 3 -t test 10 | 11 | # python test_auc.py -d NELL-995 -e saved_grail_exp_name --hop 2 -t valid 12 | # python test_auc.py -d NELL-995 -e saved_grail_exp_name --hop 2 -t test 13 | 14 | # python test_auc.py -d FB15K237 -e saved_grail_exp_name --hop 1 -t valid 15 | # python test_auc.py -d FB15K237 -e saved_grail_exp_name --hop 1 -t test 16 | 17 | # python test_ranking.py -d WN18RR -e saved_grail_exp_name --hop 3 18 | 19 | # python test_ranking.py -d NELL-995 -e saved_grail_exp_name --hop 2 20 | 21 | # python test_ranking.py -d FB15K237 -e saved_grail_exp_name --hop 1 22 | ################################################## 23 | 24 | 25 | # Arguments 26 | # Dataset 27 | DATASET=$1 28 | # KGE model to be used in ensemble 29 | KGE_MODEL_1=$2 30 | KGE_SAVED_MODEL_PATH_1="../experiments/kge_baselines/${KGE_MODEL_1}_${DATASET}" 31 | 32 | KGE_MODEL_2=$3 33 | KGE_SAVED_MODEL_PATH_2="../experiments/kge_baselines/${KGE_MODEL_2}_${DATASET}" 34 | 35 | # score pos validation triplets with KGE model 36 | python score_triplets_kge.py -d $DATASET --model $KGE_MODEL_1 -f valid -init $KGE_SAVED_MODEL_PATH_1 37 | # score neg validation triplets with KGE model 38 | python score_triplets_kge.py -d $DATASET --model $KGE_MODEL_1 -f neg_valid_0 -init $KGE_SAVED_MODEL_PATH_1 39 | 40 | # score pos validation triplets with KGE model 41 | python score_triplets_kge.py -d $DATASET --model $KGE_MODEL_2 -f valid -init $KGE_SAVED_MODEL_PATH_2 42 | # score neg validation triplets with KGE model 43 | python score_triplets_kge.py -d $DATASET --model $KGE_MODEL_2 -f neg_valid_0 -init $KGE_SAVED_MODEL_PATH_2 44 | 45 | # train the ensemble model 46 | python blend.py -d $DATASET -em1 $KGE_MODEL_1 -em2 $KGE_MODEL_2 --do_train -ne 500 47 | 48 | # Score the test pos and neg triplets with KGE model 49 | python 
score_triplets_kge.py -d $DATASET --model $KGE_MODEL_1 -f test -init $KGE_SAVED_MODEL_PATH_1 50 | python score_triplets_kge.py -d $DATASET --model $KGE_MODEL_1 -f neg_test_0 -init $KGE_SAVED_MODEL_PATH_1 51 | # Score the test pos and neg triplets with KGE model 52 | python score_triplets_kge.py -d $DATASET --model $KGE_MODEL_2 -f test -init $KGE_SAVED_MODEL_PATH_2 53 | python score_triplets_kge.py -d $DATASET --model $KGE_MODEL_2 -f neg_test_0 -init $KGE_SAVED_MODEL_PATH_2 54 | 55 | 56 | # Score the test pos and neg triplets with ensemble model 57 | python blend.py -d $DATASET -em1 $KGE_MODEL_1 -em2 $KGE_MODEL_2 --do_scoring -f test 58 | python blend.py -d $DATASET -em1 $KGE_MODEL_1 -em2 $KGE_MODEL_2 --do_scoring -f neg_test_0 59 | # Compute auc with the ensemble model scored pos and neg test files 60 | python compute_auc.py -d $DATASET -m ${KGE_MODEL_1}_with_${KGE_MODEL_2} 61 | 62 | # Score head/tail replaced samples with KGE model 63 | python score_triplets_kge.py -d $DATASET --model $KGE_MODEL_1 -f ranking_head -init $KGE_SAVED_MODEL_PATH_1 64 | python score_triplets_kge.py -d $DATASET --model $KGE_MODEL_1 -f ranking_tail -init $KGE_SAVED_MODEL_PATH_1 65 | # Score head/tail replaced samples with KGE model 66 | python score_triplets_kge.py -d $DATASET --model $KGE_MODEL_2 -f ranking_head -init $KGE_SAVED_MODEL_PATH_2 67 | python score_triplets_kge.py -d $DATASET --model $KGE_MODEL_2 -f ranking_tail -init $KGE_SAVED_MODEL_PATH_2 68 | 69 | 70 | # Score head/tail replaced samples with ensemble model 71 | python blend.py -d $DATASET -em1 $KGE_MODEL_1 -em2 $KGE_MODEL_2 --do_scoring -f ranking_head 72 | python blend.py -d $DATASET -em1 $KGE_MODEL_1 -em2 $KGE_MODEL_2 --do_scoring -f ranking_tail 73 | # Compute ranking metrics for ensemble model with the scored head/tail replaced samples 74 | python compute_rank_metrics.py -d $DATASET -m ${KGE_MODEL_1}_with_${KGE_MODEL_2} -------------------------------------------------------------------------------- 
def parse_args(args=None):
    """Parse command-line options for scoring triplets with a saved KGE model.

    Args:
        args: optional list of argument strings; ``None`` falls back to
            ``sys.argv``.

    Returns:
        argparse.Namespace with ``cuda``, ``dataset``, ``model``,
        ``file_to_score`` and ``init_checkpoint``.
    """
    cli = argparse.ArgumentParser(
        description='Training and Testing Knowledge Graph Embedding Models',
        usage='train.py [] [-h | --help]'
    )

    cli.add_argument('--cuda', action='store_true', help='use GPU')

    cli.add_argument('--dataset', '-d', type=str, default='Toy')
    cli.add_argument('--model', '-m', default='TransE', type=str)
    cli.add_argument('--file_to_score', '-f', default='test', type=str)
    cli.add_argument('--init_checkpoint', '-init', default=None, type=str)

    return cli.parse_args(args)


def override_config(args):
    '''
    Override model and data configuration
    '''
    # The checkpoint directory carries the config the model was trained with;
    # those settings win over the CLI (except an explicitly given dataset).
    config_path = os.path.join(args.init_checkpoint, 'config.json')
    with open(config_path, 'r') as fjson:
        saved = json.load(fjson)

    args.countries = saved['countries']
    if args.dataset is None:
        args.dataset = saved['dataset']
    for key in ('model', 'double_entity_embedding', 'double_relation_embedding',
                'hidden_dim', 'test_batch_size', 'gamma'):
        setattr(args, key, saved[key])
def read_triple(file_path, entity2id, relation2id):
    '''
    Read triples and map them into ids.
    '''
    # Each line is "<head>\t<relation>\t<tail>"; unknown names raise KeyError.
    with open(file_path) as fin:
        return [
            (entity2id[h], relation2id[r], entity2id[t])
            for h, r, t in (line.strip().split('\t') for line in fin)
        ]


def main(args):
    """Score every triplet in data/<dataset>/<file_to_score>.txt with a saved
    KGE checkpoint and write ``<model>_<file_to_score>_predictions.txt``
    beside the input file.
    """
    if args.init_checkpoint:
        override_config(args)
    elif args.dataset is None:
        raise ValueError('one of init_checkpoint/dataset must be choosed.')

    # Write logs to checkpoint and console
    set_logger(args)

    # Repository root: this script lives one directory below it.
    main_dir = os.path.join(os.path.relpath(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), '.')

    # id dictionaries are "<id>\t<name>" per line.
    entity2id = dict()
    with open(os.path.join(main_dir, 'data/{}/entities.dict'.format(args.dataset))) as fin:
        for line in fin:
            eid, entity = line.strip().split('\t')
            entity2id[entity] = int(eid)

    relation2id = dict()
    with open(os.path.join(main_dir, 'data/{}/relations.dict'.format(args.dataset))) as fin:
        for line in fin:
            rid, relation = line.strip().split('\t')
            relation2id[relation] = int(rid)

    triples_path = os.path.join(main_dir, 'data/{}/{}.txt'.format(args.dataset, args.file_to_score))
    test_triples = read_triple(triples_path, entity2id, relation2id)

    nentity = len(entity2id)
    nrelation = len(relation2id)
    args.nentity = nentity
    args.nrelation = nrelation

    logging.info('Model: %s' % args.model)
    logging.info('Data Path: %s' % args.dataset)
    logging.info('#entity: %d' % nentity)
    logging.info('#relation: %d' % nrelation)

    kge_model = KGEModel(
        model_name=args.model,
        nentity=nentity,
        nrelation=nrelation,
        hidden_dim=args.hidden_dim,
        gamma=args.gamma,
        double_entity_embedding=args.double_entity_embedding,
        double_relation_embedding=args.double_relation_embedding
    )

    logging.info('Model Parameter Configuration:')
    for name, param in kge_model.named_parameters():
        logging.info('Parameter %s: %s, require_grad = %s' % (name, str(param.size()), str(param.requires_grad)))

    if args.cuda:
        kge_model = kge_model.cuda()

    # Restore model from checkpoint directory
    logging.info('Loading checkpoint %s...' % args.init_checkpoint)
    checkpoint = torch.load(os.path.join(args.init_checkpoint, 'checkpoint'))
    kge_model.load_state_dict(checkpoint['model_state_dict'])

    logging.info('Scoring the triplets in {}.txt file'.format(args.file_to_score))
    scores = kge_model.score_triplets(kge_model, test_triples, args)

    # Re-read the raw file so the output keeps the original string ids.
    with open(triples_path) as f:
        raw_triplets = [line.split() for line in f.read().split('\n')[:-1]]
    out_path = os.path.join(main_dir, 'data/{}/{}_{}_predictions.txt'.format(args.dataset, args.model, args.file_to_score))
    with open(out_path, "w") as f:
        for (s, r, o), score in zip(raw_triplets, scores):
            f.write('\t'.join([s, r, o, str(score)]) + '\n')


if __name__ == '__main__':
    main(parse_args())
class Evaluator():
    """Evaluates a trained GraphClassifier on a held-out split.

    Computes AUC and AUC-PR over the positive/negative triplet scores and can
    optionally dump per-triplet prediction files next to the data.
    """

    def __init__(self, params, graph_classifier, data):
        # params: experiment namespace (batch size, device, collate fn, ...)
        # graph_classifier: the scoring model to evaluate
        self.params = params
        self.graph_classifier = graph_classifier
        self.data = data  # data of the evaluation (valid) split

    def eval(self, save=False):
        """Score all pos/neg batches in `self.data`; return {'auc', 'auc_pr'}.

        When `save` is True, also writes `grail_<split>_predictions.txt` and
        `grail_neg_<split>_<p>_predictions.txt` under the dataset directory.
        """
        pos_scores = []
        pos_labels = []
        neg_scores = []
        neg_labels = []
        dataloader = DataLoader(self.data, batch_size=self.params.batch_size, shuffle=False, num_workers=self.params.num_workers, collate_fn=self.params.collate_fn)

        self.graph_classifier.eval()
        with torch.no_grad():
            for b_idx, batch in enumerate(dataloader):


                data_pos, targets_pos, data_neg, targets_neg = self.params.move_batch_to_device(batch, self.params.device)

                rels_emb = self.graph_classifier.rel_emb
                # Sample relational paths for the positive batch and embed
                # them on CPU (rels_emb.cpu() moves the embedding submodule).
                pos_paths, neg_paths, target_rels = path_generate(data_pos, self.params.max_path_len,self.params.num_rels)
                pos_paths_emb_batch, s_p_pos = path_emb_generate_batch_ori(pos_paths, rels_emb.cpu(), target_rels.cpu())
                neg_paths_emb_batch, s_p_neg = path_emb_generate_batch_ori(neg_paths, rels_emb.cpu(), target_rels.cpu())
                # Move the classifier back to the target device after the CPU
                # path embedding above pulled its rel_emb submodule off-device.
                self.graph_classifier.to(device=self.params.device)

                # print([self.data.id2relation[r.item()] for r in data_pos[1]])
                # pdb.set_trace()
                score_pos = self.graph_classifier(data_pos, s_p_pos.to(device=self.params.device))
                score_neg = self.graph_classifier(data_neg, s_p_neg.to(device=self.params.device))

                # preds += torch.argmax(logits.detach().cpu(), dim=1).tolist()
                pos_scores += score_pos.squeeze(1).detach().cpu().tolist()
                neg_scores += score_neg.squeeze(1).detach().cpu().tolist()
                pos_labels += targets_pos.tolist()
                neg_labels += targets_neg.tolist()

        # acc = metrics.accuracy_score(labels, preds)
        auc = metrics.roc_auc_score(pos_labels + neg_labels, pos_scores + neg_scores)  # ROC-AUC of the binary pos-vs-neg classification
        auc_pr = metrics.average_precision_score(pos_labels + neg_labels, pos_scores + neg_scores)

        if save:
            # Dump raw scores for the positive triplets of this split.
            pos_test_triplets_path = os.path.join(self.params.main_dir, 'data/{}/{}.txt'.format(self.params.dataset, self.data.file_name))
            with open(pos_test_triplets_path) as f:
                pos_triplets = [line.split() for line in f.read().split('\n')[:-1]]
            pos_file_path = os.path.join(self.params.main_dir, 'data/{}/grail_{}_predictions.txt'.format(self.params.dataset, self.data.file_name))
            with open(pos_file_path, "w") as f:
                for ([s, r, o], score) in zip(pos_triplets, pos_scores):
                    f.write('\t'.join([s, r, o, str(score)]) + '\n')

            # And for the sampled negative triplets of this split.
            neg_test_triplets_path = os.path.join(self.params.main_dir, 'data/{}/neg_{}_0.txt'.format(self.params.dataset, self.data.file_name))
            with open(neg_test_triplets_path) as f:
                neg_triplets = [line.split() for line in f.read().split('\n')[:-1]]
            neg_file_path = os.path.join(self.params.main_dir, 'data/{}/grail_neg_{}_{}_predictions.txt'.format(self.params.dataset, self.data.file_name, self.params.constrained_neg_prob))
            with open(neg_file_path, "w") as f:
                for ([s, r, o], score) in zip(neg_triplets, neg_scores):
                    f.write('\t'.join([s, r, o, str(score)]) + '\n')

        return {'auc': auc, 'auc_pr': auc_pr}
self.valid_evaluator = valid_evaluator 19 | self.params = params 20 | self.train_data = train 21 | 22 | self.updates_counter = 0 23 | 24 | model_params = list(self.graph_classifier.parameters()) 25 | logging.info('Total number of parameters: %d' % sum(map(lambda x: x.numel(), model_params))) 26 | 27 | if params.optimizer == "SGD": 28 | self.optimizer = optim.SGD(model_params, lr=params.lr, momentum=params.momentum, weight_decay=self.params.l2) 29 | if params.optimizer == "Adam": 30 | self.optimizer = optim.Adam(model_params, lr=params.lr, weight_decay=self.params.l2) 31 | 32 | self.criterion = nn.MarginRankingLoss(self.params.margin, reduction='sum') # 返回标量 33 | 34 | self.reset_training_state() 35 | 36 | def reset_training_state(self): 37 | self.best_metric = 0 38 | self.last_metric = 0 39 | self.not_improved_count = 0 40 | 41 | def train_epoch(self, epoch): 42 | paths_epoch = 0 43 | paths_epoch_list = [] 44 | total_loss = 0 45 | all_preds = [] 46 | all_labels = [] 47 | all_scores = [] 48 | # num_rels = self.params.num_rels 49 | 50 | dataloader = DataLoader(self.train_data, batch_size=self.params.batch_size, shuffle=True, num_workers=self.params.num_workers, collate_fn=self.params.collate_fn) 51 | self.graph_classifier.train() # model.train() 52 | model_params = list(self.graph_classifier.parameters()) # model参数list 53 | for b_idx, batch in enumerate(dataloader): 54 | data_pos, targets_pos, data_neg, targets_neg = self.params.move_batch_to_device(batch, self.params.device) 55 | 56 | # index = torch.LongTensor([0, 1]).to(device=self.params.device) 57 | rels_emb = self.graph_classifier.rel_emb 58 | rels_emb_gpu = rels_emb.weight 59 | 60 | # 获得路径正例,负例,表示等 61 | pos_paths, neg_paths, target_rels = path_generate(data_pos, self.params.max_path_len, self.params.num_rels) 62 | # pos_paths, neg_paths, target_rels = path_generate_mul_neg(data_pos, self.params.max_path_len, self.params.num_rels) 63 | # pos_paths_emb_batch, s_p_pos = path_emb_generate_batch(pos_paths, 
rels_emb.cpu(), target_rels.cpu()) 64 | # neg_paths_emb_batch, s_p_neg = path_emb_generate_batch(neg_paths, rels_emb.cpu(), target_rels.cpu()) 65 | 66 | pos_paths_emb_batch, s_p_pos = path_emb_generate_batch(pos_paths, rels_emb.cpu(), target_rels.cpu(), epoch, self.params.num_epochs, self.params.exp_dir, 0) 67 | neg_paths_emb_batch, s_p_neg = path_emb_generate_batch(neg_paths, rels_emb.cpu(), target_rels.cpu(), epoch, self.params.num_epochs, self.params.exp_dir, 0) 68 | 69 | # paths_nums = get_paths_nums(pos_paths) 70 | 71 | # 两个loss 72 | cross_loss = path_cross_loss(s_p_pos.to(device=self.params.device), rels_emb_gpu.to(device=self.params.device), target_rels) 73 | contrast_loss = path_contrast_loss(s_p_pos.to(device=self.params.device), 74 | s_p_neg.to(device=self.params.device), 75 | rels_emb.to(device=self.params.device), target_rels) 76 | # contrast_loss_2 = path_contrast_loss_2(pos_paths_emb_batch, 77 | # neg_paths_emb_batch, 78 | # rels_emb.to(device=self.params.device), target_rels, self.params.device) 79 | 80 | self.optimizer.zero_grad() # 1) 清空过往梯度 81 | # score_pos = self.graph_classifier(data_pos) 82 | # score_neg = self.graph_classifier(data_neg) 83 | score_pos = self.graph_classifier(data_pos, s_p_pos.to(device=self.params.device)) # GraphClassifier.forward() 84 | score_neg = self.graph_classifier(data_neg, s_p_neg.to(device=self.params.device)) 85 | loss_triple = self.criterion(score_pos, score_neg.view(len(score_pos), -1).mean(dim=1), torch.Tensor([1]).to(device=self.params.device)) # 考虑到多个neg 86 | loss_1 = cross_loss.cpu().tolist() 87 | loss_2 = contrast_loss.cpu().tolist() 88 | loss = loss_triple + self.params.lambda_cross * loss_1 + self.params.lambda_contrast * loss_2 89 | # print(score_pos, score_neg, loss) 90 | loss.backward() # 2) 反向传播,计算当前梯度 91 | self.optimizer.step() # 3) 根据梯度更新网络参数 92 | self.updates_counter += 1 93 | 94 | with torch.no_grad(): # 这一部分不track梯度, 为了使下面的计算图不占用内存 95 | all_scores += score_pos.squeeze().detach().cpu().tolist() + 
score_neg.squeeze().detach().cpu().tolist() # 所有得分函数, list拼接 96 | all_labels += targets_pos.tolist() + targets_neg.tolist() # 所有labels, list拼接 97 | total_loss += loss # 一个epoch的总loss 98 | # paths_epoch_list.append(paths_nums) 99 | 100 | if self.valid_evaluator and self.params.eval_every_iter and self.updates_counter % self.params.eval_every_iter == 0: # 利用valid 验证, 多次训练验证一次 101 | tic = time.time() 102 | result = self.valid_evaluator.eval() 103 | logging.info('\nPerformance:' + str(result) + 'in ' + str(time.time() - tic)) 104 | 105 | if result['auc'] >= self.best_metric: # 最高的auc 106 | self.save_classifier() 107 | self.best_metric = result['auc'] 108 | self.not_improved_count = 0 109 | 110 | else: 111 | self.not_improved_count += 1 112 | if self.not_improved_count > self.params.early_stop: 113 | logging.info(f"Validation performance didn\'t improve for {self.params.early_stop} epochs. Training stops.") 114 | break 115 | self.last_metric = result['auc'] 116 | 117 | # paths_epoch = np.mean(paths_epoch_list) 118 | # logging.info(f"Average paths: {paths_epoch}.") 119 | 120 | auc = metrics.roc_auc_score(all_labels, all_scores) # 机器学习准确率 roc面积 121 | auc_pr = metrics.average_precision_score(all_labels, all_scores) 122 | 123 | weight_norm = sum(map(lambda x: torch.norm(x), model_params)) # 权重参数的范数 124 | 125 | return total_loss, auc, auc_pr, weight_norm 126 | 127 | def train(self): 128 | self.reset_training_state() 129 | 130 | for epoch in range(1, self.params.num_epochs + 1): 131 | time_start = time.time() 132 | loss, auc, auc_pr, weight_norm = self.train_epoch(epoch) 133 | time_elapsed = time.time() - time_start 134 | logging.info(f'Epoch {epoch} with loss: {loss}, training auc: {auc}, training auc_pr: {auc_pr}, best validation AUC: {self.best_metric}, weight_norm: {weight_norm} in {time_elapsed}') 135 | 136 | # if self.valid_evaluator and epoch % self.params.eval_every == 0: 137 | # result = self.valid_evaluator.eval() 138 | # logging.info('\nPerformance:' + str(result)) 
139 | 140 | # if result['auc'] >= self.best_metric: 141 | # self.save_classifier() 142 | # self.best_metric = result['auc'] 143 | # self.not_improved_count = 0 144 | 145 | # else: 146 | # self.not_improved_count += 1 147 | # if self.not_improved_count > self.params.early_stop: 148 | # logging.info(f"Validation performance didn\'t improve for {self.params.early_stop} epochs. Training stops.") 149 | # break 150 | # self.last_metric = result['auc'] 151 | 152 | if epoch % self.params.save_every == 0: 153 | torch.save(self.graph_classifier, os.path.join(self.params.exp_dir, 'graph_classifier_chk.pth')) 154 | 155 | def save_classifier(self): 156 | save_dir = os.path.join(self.params.exp_dir, f"{self.updates_counter}") 157 | if not os.path.exists(save_dir): 158 | os.makedirs(save_dir) 159 | torch.save(self.graph_classifier, os.path.join(save_dir, 160 | 'best_graph_classifier.pth')) # Does it overwrite or fuck with the existing file? 161 | logging.info('Better models found w.r.t accuracy. Saved it!') 162 | -------------------------------------------------------------------------------- /model/dgl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyd418/LogCo/7b6b5795cac46a5ff7c781f93d03c741f48ced71/model/dgl/__init__.py -------------------------------------------------------------------------------- /model/dgl/aggregators.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import torch.nn as nn 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | class Aggregator(nn.Module): 8 | def __init__(self, emb_dim): 9 | super(Aggregator, self).__init__() 10 | 11 | def forward(self, node): 12 | curr_emb = node.mailbox['curr_emb'][:, 0, :] # (B, F) 13 | nei_msg = torch.bmm(node.mailbox['alpha'].transpose(1, 2), node.mailbox['msg']).squeeze(1) # (B, F) 14 | # nei_msg, _ = torch.max(node.mailbox['msg'], 1) # (B, F) 15 | 16 | new_emb = 
class SumAggregator(Aggregator):
    """Combine neighbour message and self-loop embedding by element-wise sum."""

    def __init__(self, emb_dim):
        super(SumAggregator, self).__init__(emb_dim)

    def update_embedding(self, curr_emb, nei_msg):
        # Additive update: new state = aggregated messages + current state.
        return nei_msg + curr_emb


class MLPAggregator(Aggregator):
    """Combine neighbour message and self-loop embedding with a 1-layer MLP."""

    def __init__(self, emb_dim):
        super(MLPAggregator, self).__init__(emb_dim)
        self.linear = nn.Linear(2 * emb_dim, emb_dim)

    def update_embedding(self, curr_emb, nei_msg):
        # Concatenate (message, state), project back to emb_dim, then ReLU.
        joint = torch.cat((nei_msg, curr_emb), 1)
        return F.relu(self.linear(joint))


class GRUAggregator(Aggregator):
    """Use a GRU cell: neighbour message is the input, current emb the state."""

    def __init__(self, emb_dim):
        super(GRUAggregator, self).__init__(emb_dim)
        self.gru = nn.GRUCell(emb_dim, emb_dim)

    def update_embedding(self, curr_emb, nei_msg):
        return self.gru(nei_msg, curr_emb)
self.params.add_pt_emb: 23 | self.fc_layer = nn.Linear(3 * self.params.num_gcn_layers * self.params.emb_dim + 2 * self.params.rel_emb_dim, 1) 24 | else: 25 | self.fc_layer = nn.Linear(3 * self.params.num_gcn_layers * self.params.emb_dim + self.params.rel_emb_dim, 1) 26 | else: 27 | self.fc_layer = nn.Linear(self.params.num_gcn_layers * self.params.emb_dim + self.params.rel_emb_dim, 1) 28 | 29 | def forward(self, data, s_p): 30 | g, rel_labels = data 31 | g.ndata['h'] = self.gnn(g) 32 | 33 | g_out = mean_nodes(g, 'repr') 34 | 35 | head_ids = (g.ndata['id'] == 1).nonzero().squeeze(1) 36 | head_embs = g.ndata['repr'][head_ids] 37 | 38 | tail_ids = (g.ndata['id'] == 2).nonzero().squeeze(1) 39 | tail_embs = g.ndata['repr'][tail_ids] 40 | 41 | if self.params.add_ht_emb: 42 | if self.params.add_pt_emb: 43 | g_rep = torch.cat([g_out.view(-1, self.params.num_gcn_layers * self.params.emb_dim), 44 | head_embs.view(-1, self.params.num_gcn_layers * self.params.emb_dim), 45 | tail_embs.view(-1, self.params.num_gcn_layers * self.params.emb_dim), 46 | self.rel_emb(rel_labels), s_p], dim=1) 47 | else: 48 | g_rep = torch.cat([g_out.view(-1, self.params.num_gcn_layers * self.params.emb_dim), 49 | head_embs.view(-1, self.params.num_gcn_layers * self.params.emb_dim), 50 | tail_embs.view(-1, self.params.num_gcn_layers * self.params.emb_dim), 51 | self.rel_emb(rel_labels)], dim=1) 52 | else: 53 | g_rep = torch.cat([g_out.view(-1, self.params.num_gcn_layers * self.params.emb_dim), self.rel_emb(rel_labels)], dim=1) 54 | 55 | output = self.fc_layer(g_rep) 56 | return output 57 | -------------------------------------------------------------------------------- /model/dgl/layers.py: -------------------------------------------------------------------------------- 1 | """ 2 | File baseed off of dgl tutorial on RGCN 3 | Source: https://github.com/dmlc/dgl/tree/master/examples/pytorch/rgcn 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class 
class RGCNLayer(nn.Module):
    """Base class for one R-GCN message-passing layer.

    Subclasses implement `propagate`; this class applies bias, activation,
    dropout and maintains the per-layer node representation stack in
    `g.ndata['repr']`.
    """

    def __init__(self, inp_dim, out_dim, aggregator, bias=None, activation=None, dropout=0.0, edge_dropout=0.0, is_input_layer=False):
        super(RGCNLayer, self).__init__()
        self.bias = bias
        self.activation = activation
        # BUG FIX: the base class previously never stored is_input_layer even
        # though forward() reads it (only the subclass set it).
        self.is_input_layer = is_input_layer

        if self.bias:
            self.bias = nn.Parameter(torch.Tensor(out_dim))
            # BUG FIX: xavier_uniform_ requires a >=2-D tensor and raised on
            # this 1-D bias; zero-init is the standard choice for biases.
            nn.init.zeros_(self.bias)

        self.aggregator = aggregator

        if dropout:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None

        if edge_dropout:
            self.edge_dropout = nn.Dropout(edge_dropout)
        else:
            self.edge_dropout = Identity()

    # define how propagation is done in subclass
    # BUG FIX: signature now matches the call in forward() and the subclass
    # override (the old `propagate(self, g)` disagreed with both).
    def propagate(self, g, attn_rel_emb=None):
        raise NotImplementedError

    def forward(self, g, attn_rel_emb=None):
        """Run one round of message passing and post-process node states."""
        self.propagate(g, attn_rel_emb)

        # apply bias and activation
        node_repr = g.ndata['h']
        # BUG FIX: `if self.bias:` on a multi-element Parameter raises
        # "Boolean value of Tensor ... is ambiguous"; test the type instead.
        if isinstance(self.bias, nn.Parameter):
            node_repr = node_repr + self.bias
        if self.activation:
            node_repr = self.activation(node_repr)
        if self.dropout:
            node_repr = self.dropout(node_repr)

        g.ndata['h'] = node_repr

        # Stack each layer's output along dim 1 so later consumers can read
        # the concatenated multi-layer representation.
        if self.is_input_layer:
            g.ndata['repr'] = g.ndata['h'].unsqueeze(1)
        else:
            g.ndata['repr'] = torch.cat([g.ndata['repr'], g.ndata['h'].unsqueeze(1)], dim=1)


class RGCNBasisLayer(RGCNLayer):
    """R-GCN layer with basis-decomposed relation weights (Schlichtkrull et
    al.): each relation's weight matrix is a learned mixture (`w_comp`) of
    `num_bases` shared basis matrices (`weight`). Optional edge attention.
    """

    def __init__(self, inp_dim, out_dim, aggregator, attn_rel_emb_dim, num_rels, num_bases=-1, bias=None,
                 activation=None, dropout=0.0, edge_dropout=0.0, is_input_layer=False, has_attn=False):
        super(
            RGCNBasisLayer,
            self).__init__(
            inp_dim,
            out_dim,
            aggregator,
            bias,
            activation,
            dropout=dropout,
            edge_dropout=edge_dropout,
            is_input_layer=is_input_layer)
        self.inp_dim = inp_dim
        self.out_dim = out_dim
        self.attn_rel_emb_dim = attn_rel_emb_dim
        self.num_rels = num_rels
        self.num_bases = num_bases
        self.is_input_layer = is_input_layer
        self.has_attn = has_attn

        # A non-positive or too-large basis count degenerates to one basis
        # per relation (no sharing).
        if self.num_bases <= 0 or self.num_bases > self.num_rels:
            self.num_bases = self.num_rels

        # Shared bases and per-relation mixing coefficients.
        self.weight = nn.Parameter(torch.Tensor(self.num_bases, self.inp_dim, self.out_dim))
        self.w_comp = nn.Parameter(torch.Tensor(self.num_rels, self.num_bases))

        if self.has_attn:
            # Attention over (src, dst, edge type, target relation) embeddings.
            self.A = nn.Linear(2 * self.inp_dim + 2 * self.attn_rel_emb_dim, inp_dim)
            self.B = nn.Linear(inp_dim, 1)

        self.self_loop_weight = nn.Parameter(torch.Tensor(self.inp_dim, self.out_dim))

        nn.init.xavier_uniform_(self.self_loop_weight, gain=nn.init.calculate_gain('relu'))
        nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
        nn.init.xavier_uniform_(self.w_comp, gain=nn.init.calculate_gain('relu'))

    def propagate(self, g, attn_rel_emb=None):
        """Message passing over `g` with basis-composed relation weights."""
        # generate all weights from bases: (num_rels, inp_dim, out_dim)
        weight = self.weight.view(self.num_bases,
                                  self.inp_dim * self.out_dim)
        weight = torch.matmul(self.w_comp, weight).view(
            self.num_rels, self.inp_dim, self.out_dim)

        # Per-edge dropout mask applied multiplicatively to messages.
        g.edata['w'] = self.edge_dropout(torch.ones(g.number_of_edges(), 1).to(weight.device))

        input_ = 'feat' if self.is_input_layer else 'h'

        def msg_func(edges):
            w = weight.index_select(0, edges.data['type'])
            msg = edges.data['w'] * torch.bmm(edges.src[input_].unsqueeze(1), w).squeeze(1)
            curr_emb = torch.mm(edges.dst[input_], self.self_loop_weight)  # (B, F)

            if self.has_attn:
                e = torch.cat([edges.src[input_], edges.dst[input_], attn_rel_emb(edges.data['type']), attn_rel_emb(edges.data['label'])], dim=1)
                a = torch.sigmoid(self.B(F.relu(self.A(e))))
            else:
                a = torch.ones((len(edges), 1)).to(device=w.device)

            return {'curr_emb': curr_emb, 'msg': msg, 'alpha': a}

        g.update_all(msg_func, self.aggregator, None)
class RGCNBasisLayer(RGCNLayer):
    """R-GCN layer with basis-decomposed relation weights and two optional
    attention variants: `has_attn` (attention over source, destination, edge
    type and target-relation embeddings) and `self_attn` (source/type features
    first projected, then combined with the target relation and destination).
    """

    def __init__(self, inp_dim, out_dim, aggregator, attn_rel_emb_dim, num_rels, num_bases=-1, bias=None,
                 activation=None, dropout=0.0, edge_dropout=0.0, is_input_layer=False, has_attn=False, self_attn=False):
        super(
            RGCNBasisLayer,
            self).__init__(
            inp_dim,
            out_dim,
            aggregator,
            bias,
            activation,
            dropout=dropout,
            edge_dropout=edge_dropout,
            is_input_layer=is_input_layer)
        self.inp_dim = inp_dim
        self.out_dim = out_dim
        self.attn_rel_emb_dim = attn_rel_emb_dim
        self.num_rels = num_rels
        self.num_bases = num_bases
        self.is_input_layer = is_input_layer
        self.has_attn = has_attn
        self.self_attn = self_attn

        # Degenerate to one basis per relation when the count is invalid.
        if self.num_bases <= 0 or self.num_bases > self.num_rels:
            self.num_bases = self.num_rels

        # add basis weights
        # self.weight = basis_weights
        # Shared basis matrices and per-relation mixing coefficients.
        self.weight = nn.Parameter(torch.Tensor(self.num_bases, self.inp_dim, self.out_dim))
        self.w_comp = nn.Parameter(torch.Tensor(self.num_rels, self.num_bases))

        if self.self_attn:
            # self_attn variant: A projects (src, edge-type emb); B scores the
            # concatenation of that projection, the label emb and dst state.
            self.A = nn.Linear(self.inp_dim + self.attn_rel_emb_dim, inp_dim)
            self.B = nn.Linear(2 * self.inp_dim + self.attn_rel_emb_dim, 1)
        else:
            if self.has_attn:
                # GraIL-style attention over (src, dst, type emb, label emb).
                self.A = nn.Linear(2 * self.inp_dim + 2 * self.attn_rel_emb_dim, inp_dim)
                self.B = nn.Linear(inp_dim, 1)


        self.self_loop_weight = nn.Parameter(torch.Tensor(self.inp_dim, self.out_dim))

        nn.init.xavier_uniform_(self.self_loop_weight, gain=nn.init.calculate_gain('relu'))
        nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
        nn.init.xavier_uniform_(self.w_comp, gain=nn.init.calculate_gain('relu'))

    def propagate(self, g, attn_rel_emb=None):
        """One round of message passing over `g` with basis-composed weights.

        `attn_rel_emb` is the relation embedding table used by the attention
        variants; unused when both `self_attn` and `has_attn` are off.
        """
        # generate all weights from bases -> (num_rels, inp_dim, out_dim)
        weight = self.weight.view(self.num_bases,
                                  self.inp_dim * self.out_dim)
        weight = torch.matmul(self.w_comp, weight).view(
            self.num_rels, self.inp_dim, self.out_dim)

        # Per-edge multiplicative dropout mask on messages.
        g.edata['w'] = self.edge_dropout(torch.ones(g.number_of_edges(), 1).to(weight.device))

        # The input layer reads raw node features; deeper layers read 'h'.
        input_ = 'feat' if self.is_input_layer else 'h'

        def msg_func(edges):
            # Relation-specific transform of the source state, gated by 'w'.
            w = weight.index_select(0, edges.data['type'])
            msg = edges.data['w'] * torch.bmm(edges.src[input_].unsqueeze(1), w).squeeze(1)
            curr_emb = torch.mm(edges.dst[input_], self.self_loop_weight)  # (B, F)

            if self.self_attn:
                e = torch.cat([edges.src[input_], attn_rel_emb(edges.data['type'])], dim=1)
                e_linear = F.relu(self.A(e))
                rel_emb = attn_rel_emb(edges.data['label'])
                s = torch.cat([e_linear, rel_emb, edges.dst[input_]], dim=1)
                a = torch.sigmoid(self.B(s))
            else:
                if self.has_attn:
                    e = torch.cat([edges.src[input_], edges.dst[input_], attn_rel_emb(edges.data['type']),
                                   attn_rel_emb(edges.data['label'])], dim=1)
                    a = torch.sigmoid(self.B(F.relu(self.A(e))))
                else:
                    # No attention: uniform weight 1 per edge.
                    a = torch.ones((len(edges), 1)).to(device=w.device)

            return {'curr_emb': curr_emb, 'msg': msg, 'alpha': a}

        g.update_all(msg_func, self.aggregator, None)
class Identity(nn.Module):
    """A placeholder identity operator that is argument-insensitive.
    (Identity has already been supported by PyTorch 1.2, we will directly
    import torch.nn.Identity in the future)
    """

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        """Return input"""
        return x


class RGCNLayer(nn.Module):
    """Base class for one R-GCN layer; subclasses implement `propagate`.

    Applies bias/activation/dropout after propagation and appends each
    layer's output to the stacked representation in `g.ndata['repr']`.
    """

    def __init__(self, inp_dim, out_dim, aggregator, bias=None, activation=None, dropout=0.0, edge_dropout=0.0, is_input_layer=False):
        super(RGCNLayer, self).__init__()
        self.bias = bias
        self.activation = activation

        # NOTE(review): with bias=True the flag is replaced by a 1-D Parameter;
        # xavier_uniform_ requires >=2-D and `if self.bias:` in forward() is
        # ambiguous for a multi-element tensor — this bias path appears unused
        # (callers pass the default None); confirm before enabling it.
        if self.bias:
            self.bias = nn.Parameter(torch.Tensor(out_dim))
            nn.init.xavier_uniform_(self.bias,
                                    gain=nn.init.calculate_gain('relu'))

        self.aggregator = aggregator

        if dropout:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None

        if edge_dropout:
            self.edge_dropout = nn.Dropout(edge_dropout)
        else:
            self.edge_dropout = Identity()

    # define how propagation is done in subclass
    def propagate(self, g):
        raise NotImplementedError

    def forward(self, g, attn_rel_emb=None):
        # NOTE(review): calls propagate with two args while the base signature
        # above declares one — relies on the subclass override; confirm.
        self.propagate(g, attn_rel_emb)

        # apply bias and activation
        node_repr = g.ndata['h']
        if self.bias:
            node_repr = node_repr + self.bias
        if self.activation:
            node_repr = self.activation(node_repr)
        if self.dropout:
            node_repr = self.dropout(node_repr)

        g.ndata['h'] = node_repr

        # Stack this layer's output along dim 1 for multi-layer readout.
        if self.is_input_layer:
            g.ndata['repr'] = g.ndata['h'].unsqueeze(1)
        else:
            g.ndata['repr'] = torch.cat([g.ndata['repr'], g.ndata['h'].unsqueeze(1)], dim=1)
aggregator, attn_rel_emb_dim, num_rels, num_bases=-1, bias=None, 74 | activation=None, dropout=0.0, edge_dropout=0.0, is_input_layer=False, has_attn=False): 75 | super( 76 | RGCNBasisLayer, 77 | self).__init__( 78 | inp_dim, 79 | out_dim, 80 | aggregator, 81 | bias, 82 | activation, 83 | dropout=dropout, 84 | edge_dropout=edge_dropout, 85 | is_input_layer=is_input_layer) 86 | self.inp_dim = inp_dim 87 | self.out_dim = out_dim 88 | self.attn_rel_emb_dim = attn_rel_emb_dim 89 | self.num_rels = num_rels 90 | self.num_bases = num_bases 91 | self.is_input_layer = is_input_layer 92 | self.has_attn = has_attn 93 | 94 | if self.num_bases <= 0 or self.num_bases > self.num_rels: 95 | self.num_bases = self.num_rels 96 | 97 | # add basis weights 98 | # self.weight = basis_weights 99 | self.weight = nn.Parameter(torch.Tensor(self.num_bases, self.inp_dim, self.out_dim)) 100 | self.w_comp = nn.Parameter(torch.Tensor(self.num_rels, self.num_bases)) 101 | 102 | if self.has_attn: 103 | self.A = nn.Linear(2 * self.inp_dim + 2 * self.attn_rel_emb_dim, inp_dim) 104 | self.B = nn.Linear(inp_dim, 1) 105 | 106 | self.self_loop_weight = nn.Parameter(torch.Tensor(self.inp_dim, self.out_dim)) 107 | 108 | nn.init.xavier_uniform_(self.self_loop_weight, gain=nn.init.calculate_gain('relu')) 109 | nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu')) 110 | nn.init.xavier_uniform_(self.w_comp, gain=nn.init.calculate_gain('relu')) 111 | 112 | def propagate(self, g, attn_rel_emb=None): 113 | # generate all weights from bases 114 | weight = self.weight.view(self.num_bases, 115 | self.inp_dim * self.out_dim) 116 | weight = torch.matmul(self.w_comp, weight).view( 117 | self.num_rels, self.inp_dim, self.out_dim) 118 | 119 | g.edata['w'] = self.edge_dropout(torch.ones(g.number_of_edges(), 1).to(weight.device)) 120 | 121 | input_ = 'feat' if self.is_input_layer else 'h' 122 | 123 | def msg_func(edges): 124 | w = weight.index_select(0, edges.data['type']) 125 | msg = edges.data['w'] * 
torch.bmm(edges.src[input_].unsqueeze(1), w).squeeze(1) 126 | curr_emb = torch.mm(edges.dst[input_], self.self_loop_weight) # (B, F) 127 | 128 | if self.has_attn: 129 | e = torch.cat([edges.src[input_], edges.dst[input_], attn_rel_emb(edges.data['type']), attn_rel_emb(edges.data['label'])], dim=1) 130 | a = torch.sigmoid(self.B(F.relu(self.A(e)))) 131 | else: 132 | a = torch.ones((len(edges), 1)).to(device=w.device) 133 | 134 | return {'curr_emb': curr_emb, 'msg': msg, 'alpha': a} 135 | 136 | g.update_all(msg_func, self.aggregator, None) 137 | -------------------------------------------------------------------------------- /model/dgl/rgcn_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | File based off of dgl tutorial on RGCN 3 | Source: https://github.com/dmlc/dgl/tree/master/examples/pytorch/rgcn 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | # from .layers import RGCNBasisLayer as RGCNLayer 10 | from .layers_new import RGCNBasisLayer as RGCNLayer 11 | 12 | from .aggregators import SumAggregator, MLPAggregator, GRUAggregator 13 | 14 | 15 | class RGCN(nn.Module): 16 | def __init__(self, params): 17 | super(RGCN, self).__init__() 18 | 19 | self.max_label_value = params.max_label_value 20 | self.inp_dim = params.inp_dim 21 | self.emb_dim = params.emb_dim 22 | self.attn_rel_emb_dim = params.attn_rel_emb_dim 23 | self.num_rels = params.num_rels 24 | self.aug_num_rels = params.aug_num_rels 25 | self.num_bases = params.num_bases 26 | self.num_hidden_layers = params.num_gcn_layers 27 | self.dropout = params.dropout 28 | self.edge_dropout = params.edge_dropout 29 | # self.aggregator_type = params.gnn_agg_type 30 | self.has_attn = params.has_attn 31 | self.self_attn = params.self_attn 32 | 33 | self.device = params.device 34 | 35 | if self.has_attn: 36 | self.attn_rel_emb = nn.Embedding(self.num_rels, self.attn_rel_emb_dim, sparse=False) 37 | else: 38 | self.attn_rel_emb = 
None 39 | 40 | # initialize aggregators for input and hidden layers 41 | if params.gnn_agg_type == "sum": 42 | self.aggregator = SumAggregator(self.emb_dim) 43 | elif params.gnn_agg_type == "mlp": 44 | self.aggregator = MLPAggregator(self.emb_dim) 45 | elif params.gnn_agg_type == "gru": 46 | self.aggregator = GRUAggregator(self.emb_dim) 47 | 48 | # initialize basis weights for input and hidden layers 49 | # self.input_basis_weights = nn.Parameter(torch.Tensor(self.num_bases, self.inp_dim, self.emb_dim)) 50 | # self.basis_weights = nn.Parameter(torch.Tensor(self.num_bases, self.emb_dim, self.emb_dim)) 51 | 52 | # create rgcn layers 53 | self.build_model() 54 | 55 | # create initial features 56 | self.features = self.create_features() 57 | 58 | def create_features(self): 59 | features = torch.arange(self.inp_dim).to(device=self.device) 60 | return features 61 | 62 | def build_model(self): 63 | self.layers = nn.ModuleList() 64 | # i2h 65 | i2h = self.build_input_layer() 66 | if i2h is not None: 67 | self.layers.append(i2h) 68 | # h2h 69 | for idx in range(self.num_hidden_layers - 1): 70 | h2h = self.build_hidden_layer(idx) 71 | self.layers.append(h2h) 72 | 73 | def build_input_layer(self): 74 | return RGCNLayer(self.inp_dim, 75 | self.emb_dim, 76 | # self.input_basis_weights, 77 | self.aggregator, 78 | self.attn_rel_emb_dim, 79 | self.aug_num_rels, 80 | self.num_bases, 81 | activation=F.relu, 82 | dropout=self.dropout, 83 | edge_dropout=self.edge_dropout, 84 | is_input_layer=True, 85 | has_attn=self.has_attn, 86 | self_attn=self.self_attn 87 | ) 88 | 89 | def build_hidden_layer(self, idx): 90 | return RGCNLayer(self.emb_dim, 91 | self.emb_dim, 92 | # self.basis_weights, 93 | self.aggregator, 94 | self.attn_rel_emb_dim, 95 | self.aug_num_rels, 96 | self.num_bases, 97 | activation=F.relu, 98 | dropout=self.dropout, 99 | edge_dropout=self.edge_dropout, 100 | has_attn=self.has_attn, 101 | self_attn=self.self_attn 102 | ) 103 | 104 | def forward(self, g): 105 | for layer 
in self.layers: 106 | layer(g, self.attn_rel_emb) 107 | return g.ndata.pop('h') 108 | -------------------------------------------------------------------------------- /relational_path/path_process.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import itertools 3 | import os 4 | import dgl 5 | import torch 6 | import torch.nn as nn 7 | from relational_path.path_sampler import obtain_dglpaths, obtain_batch_dgl_rel_paths, unfold_rel_paths, create_neg_paths 8 | 9 | 10 | def path_generate(data, max_path_len, rels_num): 11 | g, rel_labels = data 12 | 13 | graphs_list = list(dgl.unbatch(g)) 14 | # rel_temp = graphs_list[10].edata['type'][graphs_list[10].edge_id(0, 1)] 15 | root_list = np.zeros(len(graphs_list), dtype=int) 16 | end_list = np.ones(len(graphs_list), dtype=int) 17 | hop_list = np.ones(len(graphs_list), dtype=int) * max_path_len 18 | 19 | # entity paths 20 | param_ent = zip(graphs_list, root_list, end_list, hop_list) 21 | entity_paths = list(itertools.starmap(obtain_dglpaths, list(param_ent))) 22 | 23 | # unfold relation paths 24 | param_rel = zip(entity_paths, graphs_list) 25 | rel_paths = list(itertools.starmap(obtain_batch_dgl_rel_paths, list(param_rel))) 26 | rel_path_batch = list(map(unfold_rel_paths, rel_paths)) 27 | 28 | # generate negative relation paths 29 | rel_list = np.ones(len(graphs_list), dtype=int) * rels_num 30 | param_neg = zip(rel_path_batch, rel_labels, rel_list) 31 | neg_paths_batch = list(itertools.starmap(create_neg_paths, list(param_neg))) 32 | 33 | return rel_path_batch, neg_paths_batch, rel_labels 34 | 35 | def path_generate_mul_neg(data, max_path_len, rels_num): 36 | g, rel_labels = data 37 | 38 | graphs_list = list(dgl.unbatch(g)) 39 | root_list = np.zeros(len(graphs_list), dtype=int) 40 | end_list = np.ones(len(graphs_list), dtype=int) 41 | hop_list = np.ones(len(graphs_list), dtype=int) * max_path_len 42 | 43 | # entity paths 44 | param_ent = zip(graphs_list, 
root_list, end_list, hop_list) 45 | entity_paths = list(itertools.starmap(obtain_dglpaths, list(param_ent))) 46 | 47 | # unfold relation paths 48 | param_rel = zip(entity_paths, graphs_list) 49 | rel_paths = list(itertools.starmap(obtain_batch_dgl_rel_paths, list(param_rel))) 50 | rel_path_batch = list(map(unfold_rel_paths, rel_paths)) 51 | 52 | # generate negative relation paths 53 | rel_list = np.ones(len(graphs_list), dtype=int) * rels_num 54 | param_neg = zip(rel_path_batch, rel_labels, rel_list) 55 | neg_paths_batch = list(itertools.starmap(create_neg_paths, list(param_neg))) 56 | param_neg_2 = zip(rel_path_batch, rel_labels, rel_list) 57 | 58 | neg_paths_batch_2 = list(itertools.starmap(create_neg_paths, list(param_neg_2))) 59 | for i in range(0, len(neg_paths_batch)): 60 | neg_paths_batch[i].extend(neg_paths_batch_2[i]) 61 | 62 | return rel_path_batch, neg_paths_batch, rel_labels 63 | 64 | def get_paths_nums(rel_path_batch): 65 | num_list = list(map(len, rel_path_batch)) 66 | paths_nums = np.mean(num_list) 67 | return paths_nums 68 | 69 | 70 | def path_emb_generate_batch(rel_path_batch, rel_emb, target_labels, epoch, max_epoch, filename, pos): 71 | s_p_batch = [] 72 | paths_emb_batch = [] 73 | # alpha_p = [] 74 | label_id = 0 75 | # batch_size = len(rel_paths_batch) 76 | for paths_in_graph in rel_path_batch: 77 | paths_emb, score_paths = path_emb_generate_rnn(paths_in_graph, rel_emb, target_labels[label_id]) 78 | s_p = torch.matmul(score_paths.squeeze(0), torch.Tensor(paths_emb)) # Tensor([0.1, 0.2, ...]) 79 | s_p_batch.append(s_p.tolist()) # [[0.1, 0.2, ...],[0.1, 0.2, ...],[0.1, 0.2, ...]...] 
80 | paths_emb_batch.append(paths_emb) 81 | if epoch == max_epoch: 82 | if pos == 1: 83 | print_attetion(paths_in_graph, score_paths.squeeze(0), target_labels[label_id], filename) 84 | label_id += 1 85 | s_p_batch = torch.Tensor(s_p_batch) # Tensor([[0.1, 0.2, ...],[0.1, 0.2, ...],[0.1, 0.2, ...]...]) 86 | return paths_emb_batch, s_p_batch 87 | 88 | 89 | def print_attetion(paths_in_graph, score_paths, target_rel, file_name): 90 | data_path = os.path.join(file_name, f"target_{target_rel}/alpha.txt") 91 | dir_path = os.path.join(file_name, f"target_{target_rel}") 92 | if not os.path.exists(dir_path): 93 | os.makedirs(dir_path) 94 | with open(data_path, 'a') as f: 95 | for i in range(len(paths_in_graph)): 96 | f.write(str(paths_in_graph[i]) + str(score_paths[i]) + '\n') 97 | f.writelines('\n') 98 | # return label_id 99 | 100 | 101 | def path_emb_generate_batch_ori(rel_path_batch, rel_emb, target_labels): 102 | s_p_batch = [] 103 | paths_emb_batch = [] 104 | # alpha_p = [] 105 | label_id = 0 106 | # batch_size = len(rel_paths_batch) 107 | for paths_in_graph in rel_path_batch: 108 | paths_emb, score_paths = path_emb_generate(paths_in_graph, rel_emb, target_labels[label_id]) 109 | s_p = torch.matmul(score_paths.squeeze(0), torch.Tensor(paths_emb)) # Tensor([0.1, 0.2, ...]) 110 | s_p_batch.append(s_p.tolist()) # [[0.1, 0.2, ...],[0.1, 0.2, ...],[0.1, 0.2, ...]...] 
111 | paths_emb_batch.append(paths_emb) 112 | label_id += 1 113 | s_p_batch = torch.Tensor(s_p_batch) # Tensor([[0.1, 0.2, ...],[0.1, 0.2, ...],[0.1, 0.2, ...]...]) 114 | return paths_emb_batch, s_p_batch 115 | 116 | 117 | def path_emb_generate_rnn(paths, rel_emb, target_id): # list tensor int 118 | paths_emb = [] # list 119 | score_paths = torch.Tensor([]) # list 120 | 121 | conv1 = nn.Conv1d(in_channels=32, out_channels=32, kernel_size=2) 122 | 123 | for path in paths: # path: list 124 | # index = np.asarray(path) 125 | # index = np.array(path, dtype=int) 126 | # index = torch.LongTensor([0, 1]) 127 | # rnn = nn.LSTM(32, 32, num_layers=2) 128 | # input_rnn = rel_emb(torch.LongTensor(path)) 129 | # output_rnn, (h, c) = rnn(input_rnn.unsqueeze(dim=0)) 130 | if len(path) < 2: 131 | path_emb = rel_emb(torch.LongTensor(path)) 132 | else: 133 | input_rnn = rel_emb(torch.LongTensor(path)) 134 | input_rnn = input_rnn.unsqueeze(dim=0).permute(0, 2, 1) 135 | path_emb = conv1(input_rnn) 136 | path_emb = path_emb.permute(0, 1, 2).squeeze(0) 137 | path_emb = torch.sum(path_emb, dim=1) 138 | # print(path_emb.shape) 139 | path_emb = path_emb.squeeze(0).tolist() 140 | # print(output_rnn.squeeze(0).shape) 141 | # path_emb = torch.sum(output_rnn.squeeze(0), dim=0).tolist() # tensor ->tolist 142 | # path_emb = torch.mean(rel_emb[index], dim=1) 143 | # print(torch.Tensor(path_emb).shape) 144 | paths_emb.append(path_emb) # list: [[0.1, 0.2, ...],[0.1, 0.2, ...],[0.1, 0.2, ...]...] 
145 | # print(path_emb.shape) 146 | 147 | score = torch.dot(torch.Tensor(path_emb), rel_emb(target_id)) 148 | # score_paths.append(score) 149 | score_paths = torch.cat((score_paths, score.view(1, 1)), dim=0) # Tensor([]) 150 | softmax = nn.Softmax(dim=1) 151 | alpha_paths = softmax(score_paths.view(1, len(score_paths))) # Tensor([[0.1, 0.2, ...]]) 152 | return paths_emb, alpha_paths 153 | 154 | 155 | def path_emb_generate(paths, rel_emb, target_id): # list tensor int 156 | paths_emb = [] # list 157 | score_paths = torch.Tensor([]) # list 158 | for path in paths: # path: list 159 | # index = np.asarray(path) 160 | # index = np.array(path, dtype=int) 161 | # index = torch.LongTensor([0, 1]) 162 | # print(rel_emb(torch.LongTensor(path)).shape) 163 | path_emb = torch.sum(rel_emb(torch.LongTensor(path)), dim=0).tolist() # tensor ->tolist 164 | # print(path_emb.shape) 165 | # path_emb = torch.mean(rel_emb(torch.LongTensor(path)), dim=0).tolist() # tensor ->tolist 166 | # path_emb = torch.mean(rel_emb[index], dim=1) 167 | paths_emb.append(path_emb) # list: [[0.1, 0.2, ...],[0.1, 0.2, ...],[0.1, 0.2, ...]...] 168 | 169 | score = torch.dot(torch.Tensor(path_emb), rel_emb(target_id)) 170 | # score_paths.append(score) 171 | score_paths = torch.cat((score_paths, score.view(1, 1)), dim=0) # Tensor([]) 172 | softmax = nn.Softmax(dim=1) 173 | alpha_paths = softmax(score_paths.view(1, len(score_paths))) # Tensor([[0.1, 0.2, ...]]) 174 | return paths_emb, alpha_paths 175 | 176 | 177 | def path_cross_loss(s_p_batch, rel_emb, target_labels): 178 | # index = np.array(target_labels, dtype=int) 179 | # target_embs = torch.Tensor(rel_emb[index]) 180 | # score = torch.sum(torch.mul(torch.Tensor(s_p_batch), target_embs), dim=1) # [1, 2, 3, ...] 
181 | output = torch.matmul(s_p_batch, rel_emb.t()) 182 | criterion = nn.CrossEntropyLoss(reduction='sum') 183 | loss = criterion(output, target_labels) 184 | return loss 185 | 186 | 187 | def path_contrast_loss(s_p_pos, s_p_neg, rel_emb, target_labels): 188 | batch_size = len(target_labels) 189 | pos = torch.mul(s_p_pos, rel_emb(target_labels)) 190 | neg = torch.mul(s_p_neg, rel_emb(target_labels)) 191 | softmax = nn.Softmax(dim=0) 192 | pos_sim = torch.sum(pos, dim=1) # tensor 193 | neg_sim = torch.sum(neg, dim=1) # tensor 194 | 195 | sim = torch.cat((pos_sim.view(1, len(pos_sim)), neg_sim.view(1, len(neg_sim))), dim=0) 196 | output = softmax(sim) 197 | 198 | zero_id = (output[0] == 0).nonzero(as_tuple=False).flatten() 199 | output[0][zero_id] = output[0][zero_id] + 1e-10 200 | temp = output[0] 201 | # numerator = torch.exp(pos_sim) 202 | # denominator = torch.exp(pos_sim) + torch.exp(neg_sim) 203 | res = - torch.log(output[0]) # tensor, triple? 204 | # loss = torch.sum(res) / (2 * batch_size) 205 | loss = torch.sum(res) 206 | return loss 207 | 208 | 209 | def path_contrast_loss_2(pos_paths_emb_batch, neg_paths_emb_batch, rel_emb, target_labels, device): 210 | batch_size = len(target_labels) 211 | label_id = 0 212 | loss_batch = 0 213 | softmax = nn.Softmax(dim=0) 214 | for paths_emb in zip(pos_paths_emb_batch, neg_paths_emb_batch): 215 | pos_sim = torch.matmul(torch.Tensor(paths_emb[0]).to(device=device), rel_emb(target_labels[label_id])) 216 | neg_sim = torch.matmul(torch.Tensor(paths_emb[1]).to(device=device), rel_emb(target_labels[label_id])) 217 | 218 | sim = torch.cat((pos_sim.view(1, len(pos_sim)), neg_sim.view(1, len(neg_sim))), dim=0) 219 | output = softmax(sim) 220 | 221 | zero_id = (output[0] == 0).nonzero(as_tuple=False).flatten() 222 | output[0][zero_id] = output[0][zero_id] + 1e-10 223 | res = - torch.log(output[0]) # tensor, triple? 
224 | # loss = torch.sum(res) / (2 * batch_size) 225 | loss = torch.sum(res) 226 | label_id += 1 227 | loss_batch = loss_batch + loss 228 | 229 | return loss_batch 230 | -------------------------------------------------------------------------------- /relational_path/path_sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.io as sio 3 | import scipy.sparse as ssp 4 | from scipy.sparse import coo_matrix 5 | import sys 6 | import torch 7 | from scipy.special import softmax 8 | from utils.dgl_utils import _bfs_relational 9 | from utils.graph_utils import incidence_matrix, remove_nodes, ssp_to_torch, serialize, deserialize, get_edge_count, diameter, radius 10 | import networkx as nx 11 | import itertools 12 | import copy 13 | 14 | 15 | def get_neighbor_nodes(roots, adj, h=5, max_nodes_per_hop=None): 16 | bfs_generator = _bfs_relational(adj, roots, max_nodes_per_hop) 17 | lvls = list() 18 | for _ in range(h): 19 | try: 20 | lvls.append(next(bfs_generator)) 21 | except StopIteration: 22 | pass 23 | return set().union(*lvls) 24 | 25 | 26 | def find_entity_paths(roots, end, adj, h, max_nodes_per_hop=None, path=[]): 27 | roots = set([roots]) 28 | path = path + list(roots) 29 | if roots == set([end]): 30 | return [path] 31 | bfs_generator = _bfs_relational(adj, roots, max_nodes_per_hop) 32 | nodeList = next(bfs_generator) 33 | paths = list() 34 | for node in nodeList: 35 | try: 36 | # nodeList=next(bfs_generator) 37 | if node not in path: 38 | if h == 0: 39 | break 40 | else: 41 | newpaths = find_entity_paths(node, end, adj, h - 1, max_nodes_per_hop, path) 42 | 43 | for newpath in newpaths: 44 | paths.append(newpath) 45 | except StopIteration: 46 | pass 47 | return paths 48 | 49 | 50 | def find_relations(matrix, head, tail): 51 | if matrix[head, tail] > 0: 52 | print("Existed") 53 | 54 | 55 | def find_relation_paths(paths, root, end, rel, A): 56 | rel_paths = list() 57 | for path in paths: 58 | 
rel_path = list() 59 | for i in range(len(path) - 1): 60 | path_rel_label = [] 61 | for adj in range(len(A)): 62 | # rel_ = adj 63 | exist_rel = A[adj][path[i], path[i + 1]] 64 | if exist_rel > 0: 65 | path_rel_label.append(adj) 66 | pre_path = rel_path 67 | rel_path = pre_path + [list(path_rel_label)] 68 | rel_paths.append(rel_path) 69 | return rel_paths 70 | 71 | 72 | def obtain_rel_paths_old(rel_path): 73 | all_list = list(itertools.product(*rel_path)) 74 | return all_list 75 | 76 | 77 | def find_paths(roots, end, adj, h, max_nodes_per_hop=None, path=[]): 78 | roots=set([roots]) 79 | path = path + list(roots) 80 | if roots == set([end]): 81 | return [path] 82 | bfs_generator = _bfs_relational(adj, roots, max_nodes_per_hop) 83 | nodeList = next(bfs_generator) 84 | paths = list() 85 | for node in nodeList: 86 | try: 87 | # nodeList=next(bfs_generator) 88 | if node not in path: 89 | if h == 0: 90 | break 91 | else: 92 | newpaths = find_paths(node, end, adj, h - 1, max_nodes_per_hop, path) 93 | 94 | for newpath in newpaths: 95 | paths.append(newpath) 96 | except StopIteration: 97 | pass 98 | return paths 99 | 100 | 101 | ########################################################################## 102 | #dgl graph 103 | ########################################################################## 104 | 105 | def obtain_dglpaths(graph, root, end, h, path=[]): 106 | roots = set([root]) 107 | path = path + list(roots) 108 | if roots == set([end]): 109 | return [path] 110 | nodeList = list(graph.successors(root)) 111 | paths = list() 112 | for node in nodeList: 113 | try: 114 | # nodeList=next(bfs_generator) 115 | node = int(node) 116 | if node not in path: 117 | if h == 0: 118 | break 119 | else: 120 | newpaths = obtain_dglpaths(graph, node, end, h - 1, path) 121 | 122 | for newpath in newpaths: 123 | paths.append(newpath) 124 | except StopIteration: 125 | pass 126 | return paths 127 | 128 | 129 | def dgl_relation_exist(dgl_graph, head, tail): 130 | if 
dgl_graph.edge_id(head, tail) is not None: 131 | print("Existed") 132 | 133 | 134 | def obtain_dgl_relation_paths(paths, dgl_graph): 135 | rel_paths = list() 136 | pre_path = [] 137 | for path in paths: 138 | rel_path = list() 139 | for i in range(len(path) - 1): 140 | path_rel_label = [] 141 | if dgl_graph.edge_id(path[i], path[i + 1]) is not None: 142 | rel_label = dgl_graph.edata['type'][dgl_graph.edge_id(path[i], path[i + 1])].tolist() 143 | path_rel_label.append(rel_label) 144 | pre_path = rel_path 145 | rel_path = pre_path + path_rel_label 146 | rel_paths.append(rel_path) 147 | return rel_paths 148 | 149 | 150 | def obtain_batch_dgl_rel_paths(paths, dgl_graph): 151 | rel_paths = list() 152 | pre_path = [] 153 | for path in paths: 154 | rel_path = list() 155 | for i in range(len(path) - 1): 156 | path_rel_label = [] 157 | if dgl_graph.edge_id(path[i], path[i + 1]) is not None: 158 | rel_label = dgl_graph.edata['type'][dgl_graph.edge_ids(path[i], path[i + 1])].tolist() 159 | path_rel_label.append(rel_label) 160 | pre_path = rel_path 161 | rel_path = pre_path + path_rel_label 162 | rel_paths.append(rel_path) 163 | return rel_paths 164 | 165 | 166 | def obtain_rel_path(rel_path): 167 | all_list = list(itertools.product(*rel_path)) 168 | tuple_list=[] 169 | for tuples in all_list: 170 | tuple_list.append(list(tuples)) 171 | return tuple_list 172 | 173 | 174 | def unfold_rel_paths(folded_paths): 175 | unfolded_paths = [] 176 | for rel_paths in folded_paths: 177 | rel_path = obtain_rel_path(rel_paths) 178 | unfolded_paths = unfolded_paths + rel_path 179 | return unfolded_paths 180 | 181 | 182 | ########################################################################## 183 | # neg_rel 184 | ########################################################################## 185 | def create_neg_paths(paths, target_rel, rels): 186 | target_list = [target_rel] 187 | neg_path = [] 188 | old_path = copy.deepcopy(paths) 189 | # set_old_path = set([old_path]) 190 | for item in 
paths: 191 | loop = 0 192 | while True and loop < rels: 193 | rel_list = list(range(rels)) 194 | loop += 1 195 | if item != target_list: 196 | neg_position = np.random.choice(len(item)) 197 | rel_list.remove(item[neg_position]) 198 | item_temp = copy.deepcopy(item) 199 | item_temp[neg_position] = np.random.choice(rel_list) 200 | if item_temp not in old_path: 201 | neg_path.append(item_temp) 202 | break 203 | # item[neg_position] = np.random.choice(rel_list) 204 | # if item not in old_path: 205 | # neg_path.append(item) 206 | # break 207 | else: 208 | neg_list = list(range(rels)) 209 | neg_list.remove(target_rel) 210 | item = [np.random.choice(neg_list)] 211 | neg_path.append(item) 212 | return neg_path 213 | 214 | def create_mul_neg_paths(paths, target_rel, rels): 215 | target_list = [target_rel] 216 | neg_path = [] 217 | old_path = copy.deepcopy(paths) 218 | # set_old_path = set([old_path]) 219 | for item in paths: 220 | loop = 0 221 | while True and loop < rels: 222 | rel_list = list(range(rels)) 223 | loop += 1 224 | if item != target_list: 225 | neg_position = np.random.choice(len(item)) 226 | rel_list.remove(item[neg_position]) 227 | item_temp = copy.deepcopy(item) 228 | item_temp[neg_position] = np.random.choice(rel_list) 229 | if item_temp not in old_path: 230 | neg_path.append(item_temp) 231 | break 232 | # item[neg_position] = np.random.choice(rel_list) 233 | # if item not in old_path: 234 | # neg_path.append(item) 235 | # break 236 | else: 237 | neg_list = list(range(rels)) 238 | neg_list.remove(target_rel) 239 | item = [np.random.choice(neg_list)] 240 | neg_path.append(item) 241 | return neg_path 242 | -------------------------------------------------------------------------------- /relational_path/readme.txt: -------------------------------------------------------------------------------- 1 | The code of relational path extraction and contrastive representation. 
# ==================================================================
# requirements.txt
# ==================================================================
# dgl==0.4.2
# lmdb==0.98
# networkx==2.4
# scikit-learn==0.22.1
# torch==1.4.0
# tqdm==4.43.0

# ==================================================================
# subgraph_extraction/datasets.py
# ==================================================================
from torch.utils.data import Dataset
import os
import logging
import lmdb
import numpy as np
import json
import dgl
from utils.graph_utils import ssp_multigraph_to_dgl
from utils.data_utils import process_files, save_to_file
from relational_path.graph_sampler import *


def generate_subgraph_datasets(params, splits=['train', 'valid'], saved_relation2id=None, max_label_value=None):
    """Sample positive/negative links per split and extract their enclosing subgraphs."""

    testing = 'test' in splits
    adj_list, triplets, entity2id, relation2id, id2entity, id2relation = process_files(params.file_paths, saved_relation2id)

    # plot_rel_dist(adj_list, os.path.join(params.main_dir, f'data/{params.dataset}/rel_dist.png'))

    data_path = os.path.join(params.main_dir, f'data/{params.dataset}/relation2id.json')
    # NOTE(review): isdir() on a .json file path is always False, so the mapping
    # is rewritten on every non-test run -- probably meant os.path.isfile/exists.
    if not os.path.isdir(data_path) and not testing:
        with open(data_path, 'w') as f:
            json.dump(relation2id, f)

    graphs = {}

    for split_name in splits:
        graphs[split_name] = {'triplets': triplets[split_name], 'max_size': params.max_links}

    # Sample train and valid/test links
    for split_name, split in graphs.items():
        logging.info(f"Sampling negative links for {split_name}")
        split['pos'], split['neg'] = sample_neg(adj_list, split['triplets'], params.num_neg_samples_per_link, max_size=split['max_size'], constrained_neg_prob=params.constrained_neg_prob)

    if testing:
        directory = os.path.join(params.main_dir, 'data/{}/'.format(params.dataset))
        save_to_file(directory, f'neg_{params.test_file}_{params.constrained_neg_prob}.txt', graphs['test']['neg'], id2entity, id2relation)

    links2subgraphs(adj_list, graphs, params, max_label_value)


def get_kge_embeddings(dataset, kge_model):
    """Load pretrained KGE entity embeddings and the entity-name -> id mapping."""

    path = './experiments/kge_baselines/{}_{}'.format(kge_model, dataset)
    node_features = np.load(os.path.join(path, 'entity_embedding.npy'))
    with open(os.path.join(path, 'id2entity.json')) as json_file:
        kge_id2entity = json.load(json_file)
    kge_entity2id = {v: int(k) for k, v in kge_id2entity.items()}

    return node_features, kge_entity2id


class SubgraphDataset(Dataset):
    """Extracted, labeled, subgraph dataset -- DGL Only"""

    def __init__(self, db_path, db_name_pos, db_name_neg, raw_data_paths, included_relations=None, add_traspose_rels=False, num_neg_samples_per_link=1, use_kge_embeddings=False, dataset='', kge_model='', file_name=''):
        # open the LMDB environment read-only
        self.main_env = lmdb.open(db_path, readonly=True, max_dbs=3, lock=False, map_size=int(1e9))
        # open the positive / negative sub-databases
        self.db_pos = self.main_env.open_db(db_name_pos.encode())
        self.db_neg = self.main_env.open_db(db_name_neg.encode())
        self.node_features, self.kge_entity2id = get_kge_embeddings(dataset, kge_model) if use_kge_embeddings else (None, None)
        self.num_neg_samples_per_link = num_neg_samples_per_link
        self.file_name = file_name

        ssp_graph, __, __, __, id2entity, id2relation = process_files(raw_data_paths, included_relations)
        self.num_rels = len(ssp_graph)

        # Add transpose matrices to handle both directions of relations.
有向图-> 无向图 70 | if add_traspose_rels: 71 | ssp_graph_t = [adj.T for adj in ssp_graph] 72 | ssp_graph += ssp_graph_t 73 | 74 | # the effective number of relations after adding symmetric adjacency matrices and/or self connections 75 | self.aug_num_rels = len(ssp_graph) 76 | self.graph = ssp_multigraph_to_dgl(ssp_graph) 77 | self.ssp_graph = ssp_graph 78 | self.id2entity = id2entity 79 | self.id2relation = id2relation 80 | 81 | self.max_n_label = np.array([0, 0]) 82 | with self.main_env.begin() as txn: 83 | self.max_n_label[0] = int.from_bytes(txn.get('max_n_label_sub'.encode()), byteorder='little') 84 | self.max_n_label[1] = int.from_bytes(txn.get('max_n_label_obj'.encode()), byteorder='little') 85 | 86 | self.avg_subgraph_size = struct.unpack('f', txn.get('avg_subgraph_size'.encode())) 87 | self.min_subgraph_size = struct.unpack('f', txn.get('min_subgraph_size'.encode())) 88 | self.max_subgraph_size = struct.unpack('f', txn.get('max_subgraph_size'.encode())) 89 | self.std_subgraph_size = struct.unpack('f', txn.get('std_subgraph_size'.encode())) 90 | 91 | self.avg_enc_ratio = struct.unpack('f', txn.get('avg_enc_ratio'.encode())) 92 | self.min_enc_ratio = struct.unpack('f', txn.get('min_enc_ratio'.encode())) 93 | self.max_enc_ratio = struct.unpack('f', txn.get('max_enc_ratio'.encode())) 94 | self.std_enc_ratio = struct.unpack('f', txn.get('std_enc_ratio'.encode())) 95 | 96 | self.avg_num_pruned_nodes = struct.unpack('f', txn.get('avg_num_pruned_nodes'.encode())) 97 | self.min_num_pruned_nodes = struct.unpack('f', txn.get('min_num_pruned_nodes'.encode())) 98 | self.max_num_pruned_nodes = struct.unpack('f', txn.get('max_num_pruned_nodes'.encode())) 99 | self.std_num_pruned_nodes = struct.unpack('f', txn.get('std_num_pruned_nodes'.encode())) 100 | 101 | logging.info(f"Max distance from sub : {self.max_n_label[0]}, Max distance from obj : {self.max_n_label[1]}") 102 | 103 | # logging.info('=====================') 104 | # logging.info(f"Subgraph size stats: \n Avg size 
{self.avg_subgraph_size}, \n Min size {self.min_subgraph_size}, \n Max size {self.max_subgraph_size}, \n Std {self.std_subgraph_size}") 105 | 106 | # logging.info('=====================') 107 | # logging.info(f"Enclosed nodes ratio stats: \n Avg size {self.avg_enc_ratio}, \n Min size {self.min_enc_ratio}, \n Max size {self.max_enc_ratio}, \n Std {self.std_enc_ratio}") 108 | 109 | # logging.info('=====================') 110 | # logging.info(f"# of pruned nodes stats: \n Avg size {self.avg_num_pruned_nodes}, \n Min size {self.min_num_pruned_nodes}, \n Max size {self.max_num_pruned_nodes}, \n Std {self.std_num_pruned_nodes}") 111 | 112 | with self.main_env.begin(db=self.db_pos) as txn: 113 | self.num_graphs_pos = int.from_bytes(txn.get('num_graphs'.encode()), byteorder='little') 114 | with self.main_env.begin(db=self.db_neg) as txn: 115 | self.num_graphs_neg = int.from_bytes(txn.get('num_graphs'.encode()), byteorder='little') 116 | 117 | self.__getitem__(0) 118 | 119 | def __getitem__(self, index): # 当访问不存在的属性时调用,index = 0,batch 时会更新 120 | with self.main_env.begin(db=self.db_pos) as txn: 121 | str_id = '{:08}'.format(index).encode('ascii') 122 | nodes_pos, r_label_pos, g_label_pos, n_labels_pos = deserialize(txn.get(str_id)).values() 123 | subgraph_pos = self._prepare_subgraphs(nodes_pos, r_label_pos, n_labels_pos) 124 | subgraphs_neg = [] 125 | r_labels_neg = [] 126 | g_labels_neg = [] 127 | with self.main_env.begin(db=self.db_neg) as txn: 128 | for i in range(self.num_neg_samples_per_link): 129 | str_id = '{:08}'.format(index + i * (self.num_graphs_pos)).encode('ascii') 130 | nodes_neg, r_label_neg, g_label_neg, n_labels_neg = deserialize(txn.get(str_id)).values() 131 | subgraphs_neg.append(self._prepare_subgraphs(nodes_neg, r_label_neg, n_labels_neg)) 132 | r_labels_neg.append(r_label_neg) 133 | g_labels_neg.append(g_label_neg) 134 | 135 | # 考虑在这里加入paths 136 | # print("data: "+str(index)) 137 | # path = obtain_dglpaths(subgraph_pos, 0, 1, 4) 138 | # all_rel_paths = 
obtain_dgl_relation_paths(path, subgraph_pos) 139 | 140 | return subgraph_pos, g_label_pos, r_label_pos, subgraphs_neg, g_labels_neg, r_labels_neg 141 | 142 | def __len__(self): 143 | return self.num_graphs_pos 144 | 145 | def _prepare_subgraphs(self, nodes, r_label, n_labels): 146 | subgraph = dgl.DGLGraph(self.graph.subgraph(nodes)) # 将nodes构成的子图转换为dgl图 147 | subgraph.edata['type'] = self.graph.edata['type'][self.graph.subgraph(nodes).parent_eid] # 子图中对应边是什么 148 | subgraph.edata['label'] = torch.tensor(r_label * np.ones(subgraph.edata['type'].shape), dtype=torch.long) # r_label 149 | 150 | edges_btw_roots = subgraph.edge_id(0, 1) # roots间的边 151 | rel_link = np.nonzero(subgraph.edata['type'][edges_btw_roots] == r_label) # 0,1之间是否为r_label 152 | if rel_link.squeeze().nelement() == 0: 153 | subgraph.add_edge(0, 1) 154 | subgraph.edata['type'][-1] = torch.tensor(r_label).type(torch.LongTensor) 155 | subgraph.edata['label'][-1] = torch.tensor(r_label).type(torch.LongTensor) 156 | 157 | # map the id read by GraIL to the entity IDs as registered by the KGE embeddings 158 | kge_nodes = [self.kge_entity2id[self.id2entity[n]] for n in nodes] if self.kge_entity2id else None # 默认None 159 | n_feats = self.node_features[kge_nodes] if self.node_features is not None else None # 默认None 160 | subgraph = self._prepare_features_new(subgraph, n_labels, n_feats) 161 | 162 | return subgraph 163 | 164 | def _prepare_features(self, subgraph, n_labels, n_feats=None): 165 | # One hot encode the node label feature and concat to n_feature 166 | n_nodes = subgraph.number_of_nodes() 167 | label_feats = np.zeros((n_nodes, self.max_n_label[0] + 1)) 168 | label_feats[np.arange(n_nodes), n_labels] = 1 169 | label_feats[np.arange(n_nodes), self.max_n_label[0] + 1 + n_labels[:, 1]] = 1 170 | n_feats = np.concatenate((label_feats, n_feats), axis=1) if n_feats else label_feats 171 | subgraph.ndata['feat'] = torch.FloatTensor(n_feats) 172 | self.n_feat_dim = n_feats.shape[1] # Find cleaner way to do 
this -- i.e. set the n_feat_dim 173 | return subgraph 174 | 175 | def _prepare_features_new(self, subgraph, n_labels, n_feats=None): 176 | # One hot encode the node label feature and concat to n_feature 177 | n_nodes = subgraph.number_of_nodes() 178 | label_feats = np.zeros((n_nodes, self.max_n_label[0] + 1 + self.max_n_label[1] + 1)) 179 | label_feats[np.arange(n_nodes), n_labels[:, 0]] = 1 # d(i,u) 180 | label_feats[np.arange(n_nodes), self.max_n_label[0] + 1 + n_labels[:, 1]] = 1 # d(i,v) 181 | # label_feats = np.zeros((n_nodes, self.max_n_label[0] + 1 + self.max_n_label[1] + 1)) 182 | # label_feats[np.arange(n_nodes), 0] = 1 183 | # label_feats[np.arange(n_nodes), self.max_n_label[0] + 1] = 1 184 | n_feats = np.concatenate((label_feats, n_feats), axis=1) if n_feats is not None else label_feats 185 | subgraph.ndata['feat'] = torch.FloatTensor(n_feats) 186 | 187 | head_id = np.argwhere([label[0] == 0 and label[1] == 1 for label in n_labels]) 188 | tail_id = np.argwhere([label[0] == 1 and label[1] == 0 for label in n_labels]) 189 | n_ids = np.zeros(n_nodes) 190 | n_ids[head_id] = 1 # head 191 | n_ids[tail_id] = 2 # tail 192 | subgraph.ndata['id'] = torch.FloatTensor(n_ids) 193 | 194 | self.n_feat_dim = n_feats.shape[1] # Find cleaner way to do this -- i.e. 
def generate_subgraph_datasets(params, splits=['train', 'valid'], saved_relation2id=None, max_label_value=None):
    """Sample negative links for each split and extract/stash the enclosing
    subgraphs into the LMDB store via links2subgraphs."""
    testing = 'test' in splits
    adj_list, triplets, entity2id, relation2id, id2entity, id2relation = process_files(params.file_paths, saved_relation2id)

    # plot_rel_dist(adj_list, os.path.join(params.main_dir, f'data/{params.dataset}/rel_dist.png'))

    data_path = os.path.join(params.main_dir, f'data/{params.dataset}/relation2id.json')
    # BUGFIX: data_path is a file, so os.path.isdir() was always False and the
    # mapping was rewritten on every call; os.path.exists() is what the guard
    # intended (write the mapping only once, and never during testing).
    if not os.path.exists(data_path) and not testing:
        with open(data_path, 'w') as f:
            json.dump(relation2id, f)

    graphs = {}
    for split_name in splits:
        graphs[split_name] = {'triplets': triplets[split_name], 'max_size': params.max_links}

    # Sample train and valid/test links.
    for split_name, split in graphs.items():
        logging.info(f"Sampling negative links for {split_name}")
        split['pos'], split['neg'] = sample_neg(adj_list, split['triplets'], params.num_neg_samples_per_link, max_size=split['max_size'], constrained_neg_prob=params.constrained_neg_prob)

    if testing:
        # Persist the sampled negative test triples for later scoring scripts.
        directory = os.path.join(params.main_dir, 'data/{}/'.format(params.dataset))
        save_to_file(directory, f'neg_{params.test_file}_{params.constrained_neg_prob}.txt', graphs['test']['neg'], id2entity, id2relation)

    links2subgraphs(adj_list, graphs, params, max_label_value)


def get_kge_embeddings(dataset, kge_model):
    """Load pretrained KGE entity embeddings plus the entity-name -> row-index map."""
    path = './experiments/kge_baselines/{}_{}'.format(kge_model, dataset)
    node_features = np.load(os.path.join(path, 'entity_embedding.npy'))
    with open(os.path.join(path, 'id2entity.json')) as json_file:
        kge_id2entity = json.load(json_file)
    kge_entity2id = {v: int(k) for k, v in kge_id2entity.items()}

    return node_features, kge_entity2id
    def __init__(self, db_path, db_name_pos, db_name_neg, raw_data_paths, included_relations=None, add_traspose_rels=False, num_neg_samples_per_link=1, use_kge_embeddings=False, dataset='', kge_model='', file_name=''):
        """Open the LMDB store of pre-extracted subgraphs and cache dataset stats.

        db_path: LMDB environment directory produced by generate_subgraph_datasets.
        db_name_pos / db_name_neg: names of the positive / negative sub-databases.
        raw_data_paths: dict of split-name -> triplet file, re-parsed to rebuild
            the full training graph the subgraphs were extracted from.
        included_relations: optional saved relation2id to filter relations with.
        """
        # Open the LMDB environment read-only (max 3 named DBs, no file lock).
        self.main_env = lmdb.open(db_path, readonly=True, max_dbs=3, lock=False)
        # Open the corresponding positive / negative sub-databases.
        self.db_pos = self.main_env.open_db(db_name_pos.encode())
        self.db_neg = self.main_env.open_db(db_name_neg.encode())
        self.node_features, self.kge_entity2id = get_kge_embeddings(dataset, kge_model) if use_kge_embeddings else (None, None)
        self.num_neg_samples_per_link = num_neg_samples_per_link
        self.file_name = file_name

        ssp_graph, __, __, __, id2entity, id2relation = process_files(raw_data_paths, included_relations)
        self.num_rels = len(ssp_graph)

        # Add transpose matrices to handle both directions of relations
        # (treats the directed graph as undirected).
        if add_traspose_rels:
            ssp_graph_t = [adj.T for adj in ssp_graph]
            ssp_graph += ssp_graph_t

        # The effective number of relations after adding symmetric adjacency
        # matrices and/or self connections.
        self.aug_num_rels = len(ssp_graph)
        self.graph = ssp_multigraph_to_dgl(ssp_graph)
        self.ssp_graph = ssp_graph
        self.id2entity = id2entity
        self.id2relation = id2relation

        # Max double-radius labels [d(·,u), d(·,v)] observed during extraction.
        self.max_n_label = np.array([0, 0])
        with self.main_env.begin() as txn:
            self.max_n_label[0] = int.from_bytes(txn.get('max_n_label_sub'.encode()), byteorder='little')
            self.max_n_label[1] = int.from_bytes(txn.get('max_n_label_obj'.encode()), byteorder='little')

            # NOTE(review): struct.unpack returns a 1-tuple, so every stat
            # below is a tuple rather than a float; they appear to be used
            # only by the commented-out logging — confirm before using them
            # numerically.
            self.avg_subgraph_size = struct.unpack('f', txn.get('avg_subgraph_size'.encode()))
            self.min_subgraph_size = struct.unpack('f', txn.get('min_subgraph_size'.encode()))
            self.max_subgraph_size = struct.unpack('f', txn.get('max_subgraph_size'.encode()))
            self.std_subgraph_size = struct.unpack('f', txn.get('std_subgraph_size'.encode()))

            self.avg_enc_ratio = struct.unpack('f', txn.get('avg_enc_ratio'.encode()))
            self.min_enc_ratio = struct.unpack('f', txn.get('min_enc_ratio'.encode()))
            self.max_enc_ratio = struct.unpack('f', txn.get('max_enc_ratio'.encode()))
            self.std_enc_ratio = struct.unpack('f', txn.get('std_enc_ratio'.encode()))

            self.avg_num_pruned_nodes = struct.unpack('f', txn.get('avg_num_pruned_nodes'.encode()))
            self.min_num_pruned_nodes = struct.unpack('f', txn.get('min_num_pruned_nodes'.encode()))
            self.max_num_pruned_nodes = struct.unpack('f', txn.get('max_num_pruned_nodes'.encode()))
            self.std_num_pruned_nodes = struct.unpack('f', txn.get('std_num_pruned_nodes'.encode()))

        logging.info(f"Max distance from sub : {self.max_n_label[0]}, Max distance from obj : {self.max_n_label[1]}")

        # logging.info('=====================')
        # logging.info(f"Subgraph size stats: \n Avg size {self.avg_subgraph_size}, \n Min size {self.min_subgraph_size}, \n Max size {self.max_subgraph_size}, \n Std {self.std_subgraph_size}")

        # logging.info('=====================')
        # logging.info(f"Enclosed nodes ratio stats: \n Avg size {self.avg_enc_ratio}, \n Min size {self.min_enc_ratio}, \n Max size {self.max_enc_ratio}, \n Std {self.std_enc_ratio}")

        # logging.info('=====================')
        # logging.info(f"# of pruned nodes stats: \n Avg size {self.avg_num_pruned_nodes}, \n Min size {self.min_num_pruned_nodes}, \n Max size {self.max_num_pruned_nodes}, \n Std {self.std_num_pruned_nodes}")

        with self.main_env.begin(db=self.db_pos) as txn:
            self.num_graphs_pos = int.from_bytes(txn.get('num_graphs'.encode()), byteorder='little')
        with self.main_env.begin(db=self.db_neg) as txn:
            self.num_graphs_neg = int.from_bytes(txn.get('num_graphs'.encode()), byteorder='little')

        # Warm-up access: sets self.n_feat_dim as a side effect of
        # _prepare_features_new (called again with fresh indices per batch).
        self.__getitem__(0)
{self.avg_subgraph_size}, \n Min size {self.min_subgraph_size}, \n Max size {self.max_subgraph_size}, \n Std {self.std_subgraph_size}") 107 | 108 | # logging.info('=====================') 109 | # logging.info(f"Enclosed nodes ratio stats: \n Avg size {self.avg_enc_ratio}, \n Min size {self.min_enc_ratio}, \n Max size {self.max_enc_ratio}, \n Std {self.std_enc_ratio}") 110 | 111 | # logging.info('=====================') 112 | # logging.info(f"# of pruned nodes stats: \n Avg size {self.avg_num_pruned_nodes}, \n Min size {self.min_num_pruned_nodes}, \n Max size {self.max_num_pruned_nodes}, \n Std {self.std_num_pruned_nodes}") 113 | 114 | with self.main_env.begin(db=self.db_pos) as txn: 115 | self.num_graphs_pos = int.from_bytes(txn.get('num_graphs'.encode()), byteorder='little') 116 | with self.main_env.begin(db=self.db_neg) as txn: 117 | self.num_graphs_neg = int.from_bytes(txn.get('num_graphs'.encode()), byteorder='little') 118 | 119 | self.__getitem__(0) 120 | 121 | def __getitem__(self, index): # 当访问不存在的属性时调用,index = 0,batch 时会更新 122 | with self.main_env.begin(db=self.db_pos) as txn: 123 | str_id = '{:08}'.format(index).encode('ascii') 124 | nodes_pos, r_label_pos, g_label_pos, n_labels_pos = deserialize(txn.get(str_id)).values() 125 | subgraph_pos = self._prepare_subgraphs(nodes_pos, r_label_pos, n_labels_pos) 126 | subgraphs_neg = [] 127 | r_labels_neg = [] 128 | g_labels_neg = [] 129 | with self.main_env.begin(db=self.db_neg) as txn: 130 | for i in range(self.num_neg_samples_per_link): 131 | str_id = '{:08}'.format(index + i * (self.num_graphs_pos)).encode('ascii') 132 | nodes_neg, r_label_neg, g_label_neg, n_labels_neg = deserialize(txn.get(str_id)).values() 133 | subgraphs_neg.append(self._prepare_subgraphs(nodes_neg, r_label_neg, n_labels_neg)) 134 | r_labels_neg.append(r_label_neg) 135 | g_labels_neg.append(g_label_neg) 136 | 137 | # 考虑在这里加入paths 138 | path = obtain_dglpaths(subgraph_pos, 0, 1, 3) 139 | all_rel_paths = obtain_dgl_relation_paths(path, 
subgraph_pos) 140 | 141 | print("data: "+str(index)) 142 | return subgraph_pos, g_label_pos, r_label_pos, subgraphs_neg, g_labels_neg, r_labels_neg 143 | 144 | def __len__(self): 145 | return self.num_graphs_pos 146 | 147 | def _prepare_subgraphs(self, nodes, r_label, n_labels): 148 | subgraph = dgl.DGLGraph(self.graph.subgraph(nodes)) # 将nodes构成的子图转换为dgl图 149 | subgraph.edata['type'] = self.graph.edata['type'][self.graph.subgraph(nodes).parent_eid] # 子图中对应边是什么 150 | subgraph.edata['label'] = torch.tensor(r_label * np.ones(subgraph.edata['type'].shape), dtype=torch.long) # r_label 151 | 152 | edges_btw_roots = subgraph.edge_id(0, 1) # roots间的边 153 | rel_link = np.nonzero(subgraph.edata['type'][edges_btw_roots] == r_label) # 0,1之间是否为r_label 154 | if rel_link.squeeze().nelement() == 0: 155 | subgraph.add_edge(0, 1) 156 | subgraph.edata['type'][-1] = torch.tensor(r_label).type(torch.LongTensor) 157 | subgraph.edata['label'][-1] = torch.tensor(r_label).type(torch.LongTensor) 158 | 159 | # map the id read by GraIL to the entity IDs as registered by the KGE embeddings 160 | kge_nodes = [self.kge_entity2id[self.id2entity[n]] for n in nodes] if self.kge_entity2id else None # 默认None 161 | n_feats = self.node_features[kge_nodes] if self.node_features is not None else None # 默认None 162 | subgraph = self._prepare_features_new(subgraph, n_labels, n_feats) 163 | 164 | return subgraph 165 | 166 | def _prepare_features(self, subgraph, n_labels, n_feats=None): 167 | # One hot encode the node label feature and concat to n_feature 168 | n_nodes = subgraph.number_of_nodes() 169 | label_feats = np.zeros((n_nodes, self.max_n_label[0] + 1)) 170 | label_feats[np.arange(n_nodes), n_labels] = 1 171 | label_feats[np.arange(n_nodes), self.max_n_label[0] + 1 + n_labels[:, 1]] = 1 172 | n_feats = np.concatenate((label_feats, n_feats), axis=1) if n_feats else label_feats 173 | subgraph.ndata['feat'] = torch.FloatTensor(n_feats) 174 | self.n_feat_dim = n_feats.shape[1] # Find cleaner way to do 
this -- i.e. set the n_feat_dim 175 | return subgraph 176 | 177 | def _prepare_features_new(self, subgraph, n_labels, n_feats=None): 178 | # One hot encode the node label feature and concat to n_feature 179 | n_nodes = subgraph.number_of_nodes() 180 | label_feats = np.zeros((n_nodes, self.max_n_label[0] + 1 + self.max_n_label[1] + 1)) 181 | label_feats[np.arange(n_nodes), n_labels[:, 0]] = 1 # d(i,u) 182 | label_feats[np.arange(n_nodes), self.max_n_label[0] + 1 + n_labels[:, 1]] = 1 # d(i,v) 183 | # label_feats = np.zeros((n_nodes, self.max_n_label[0] + 1 + self.max_n_label[1] + 1)) 184 | # label_feats[np.arange(n_nodes), 0] = 1 185 | # label_feats[np.arange(n_nodes), self.max_n_label[0] + 1] = 1 186 | n_feats = np.concatenate((label_feats, n_feats), axis=1) if n_feats is not None else label_feats 187 | subgraph.ndata['feat'] = torch.FloatTensor(n_feats) 188 | 189 | head_id = np.argwhere([label[0] == 0 and label[1] == 1 for label in n_labels]) 190 | tail_id = np.argwhere([label[0] == 1 and label[1] == 0 for label in n_labels]) 191 | n_ids = np.zeros(n_nodes) 192 | n_ids[head_id] = 1 # head 193 | n_ids[tail_id] = 2 # tail 194 | subgraph.ndata['id'] = torch.FloatTensor(n_ids) 195 | 196 | self.n_feat_dim = n_feats.shape[1] # Find cleaner way to do this -- i.e. 
def main(params):
    """Run `params.runs` evaluation rounds on the test split and log the mean
    and std of AUC and AUC-PR across rounds."""
    simplefilter(action='ignore', category=UserWarning)
    simplefilter(action='ignore', category=SparseEfficiencyWarning)

    graph_classifier = initialize_model(params, None, load_model=True)

    logging.info(f"Device: {params.device}")

    all_auc = []
    auc_mean = 0

    all_auc_pr = []
    auc_pr_mean = 0
    max_result_aucpr = 0
    for r in range(1, params.runs + 1):

        params.db_path = os.path.join(params.main_dir, f'data/{params.dataset}/test_subgraphs_{params.experiment_name}_{params.constrained_neg_prob}_en_{params.enclosing_sub_graph}')

        generate_subgraph_datasets(params, splits=['test'],
                                   saved_relation2id=graph_classifier.relation2id,
                                   max_label_value=graph_classifier.gnn.max_label_value)

        test = SubgraphDataset(params.db_path, 'test_pos', 'test_neg', params.file_paths, graph_classifier.relation2id,
                               add_traspose_rels=params.add_traspose_rels,
                               num_neg_samples_per_link=params.num_neg_samples_per_link,
                               use_kge_embeddings=params.use_kge_embeddings, dataset=params.dataset,
                               kge_model=params.kge_model, file_name=params.test_file)

        params.num_rels = test.num_rels

        test_evaluator = Evaluator(params, graph_classifier, test)

        result = test_evaluator.eval(save=True)
        logging.info('\nTest Set Performance:' + str(result))

        if result['auc_pr'] > max_result_aucpr:
            max_result_aucpr = result['auc_pr']

        logging.info('\nMax Auc_Pr:' + str(max_result_aucpr))

        # Incremental (running) means, updated per round.
        all_auc.append(result['auc'])
        auc_mean = auc_mean + (result['auc'] - auc_mean) / r

        all_auc_pr.append(result['auc_pr'])
        auc_pr_mean = auc_pr_mean + (result['auc_pr'] - auc_pr_mean) / r

    auc_std = np.std(all_auc)
    auc_pr_std = np.std(all_auc_pr)

    # FIX: auc_std / auc_pr_std were computed but never used (np.std was
    # recomputed inline in the log statements); use the locals instead.
    logging.info('\nAvg test Set Performance -- mean auc :' + str(np.mean(all_auc)) + ' std auc: ' + str(auc_std))
    logging.info('\nAvg test Set Performance -- mean auc_pr :' + str(np.mean(all_auc_pr)) + ' std auc_pr: ' + str(auc_pr_std))


if __name__ == '__main__':

    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser(description='TransE model')

    # Experiment setup params
    parser.add_argument("--experiment_name", "-e", type=str, default="wn_v4_test",
                        help="A folder with this name would be created to dump saved models and log files")
    parser.add_argument("--dataset", "-d", type=str, default="WN18RR_v4_ind",
                        help="Dataset string")
    parser.add_argument("--train_file", "-tf", type=str, default="train",
                        help="Name of file containing training triplets")
    parser.add_argument("--test_file", "-t", type=str, default="test",
                        help="Name of file containing test triplets")
    parser.add_argument("--runs", type=int, default=50,
                        help="How many runs to perform for mean and std?")
    parser.add_argument("--gpu", type=int, default=2,
                        help="Which GPU to use?")
    parser.add_argument('--disable_cuda', action='store_true',
                        help='Disable CUDA')

    # Data processing pipeline params
    parser.add_argument("--max_links", type=int, default=100000,
                        help="Set maximum number of links (to fit into memory)")
    parser.add_argument("--max_path_len", type=int, default=3,
                        help="Max length of the path")
    parser.add_argument("--hop", type=int, default=3,
                        help="Enclosing subgraph hop number")
    parser.add_argument("--max_nodes_per_hop", "-max_h", type=int, default=None,
                        help="if > 0, upper bound the # nodes per hop by subsampling")
    parser.add_argument("--use_kge_embeddings", "-kge", type=bool, default=False,
                        help='whether to use pretrained KGE embeddings')
    parser.add_argument("--kge_model", type=str, default="TransE",
                        help="Which KGE model to load entity embeddings from")
    parser.add_argument('--model_type', '-m', type=str, choices=['dgl'], default='dgl',
                        help='what format to store subgraphs in for model')
    parser.add_argument('--constrained_neg_prob', '-cn', type=float, default=0,
                        help='with what probability to sample constrained heads/tails while neg sampling')
    parser.add_argument("--num_neg_samples_per_link", '-neg', type=int, default=1,
                        help="Number of negative examples to sample per positive link")
    parser.add_argument("--batch_size", type=int, default=16,
                        help="Batch size")
    parser.add_argument("--num_workers", type=int, default=8,
                        help="Number of dataloading processes")
    parser.add_argument('--add_traspose_rels', '-tr', type=bool, default=False,
                        help='whether to append adj matrix list with symmetric relations')
    parser.add_argument('--enclosing_sub_graph', '-en', type=bool, default=True,
                        help='whether to only consider enclosing subgraph')

    params = parser.parse_args()
    initialize_experiment(params, __file__)

    params.file_paths = {
        'train': os.path.join(params.main_dir, 'data/{}/{}.txt'.format(params.dataset, params.train_file)),
        'test': os.path.join(params.main_dir, 'data/{}/{}.txt'.format(params.dataset, params.test_file))
    }

    if not params.disable_cuda and torch.cuda.is_available():
        params.device = torch.device('cuda:%d' % params.gpu)
    else:
        params.device = torch.device('cpu')

    params.collate_fn = collate_dgl
    params.move_batch_to_device = move_batch_to_device_dgl

    main(params)
def main(params):
    """Build (or load) the cached subgraph datasets, initialize the RGCN-based
    graph classifier, and run training with periodic validation."""
    simplefilter(action='ignore', category=UserWarning)
    simplefilter(action='ignore', category=SparseEfficiencyWarning)

    params.db_path = os.path.join(params.main_dir, f'data/{params.dataset}/subgraphs_en_{params.enclosing_sub_graph}_neg_{params.num_neg_samples_per_link}_hop_{params.hop}')

    # No cached subgraph data yet -- extract and store the subgraph datasets.
    if not os.path.isdir(params.db_path):
        generate_subgraph_datasets(params)

    train = SubgraphDataset(params.db_path, 'train_pos', 'train_neg', params.file_paths,
                            add_traspose_rels=params.add_traspose_rels,
                            num_neg_samples_per_link=params.num_neg_samples_per_link,
                            use_kge_embeddings=params.use_kge_embeddings, dataset=params.dataset,
                            kge_model=params.kge_model, file_name=params.train_file)
    valid = SubgraphDataset(params.db_path, 'valid_pos', 'valid_neg', params.file_paths,
                            add_traspose_rels=params.add_traspose_rels,
                            num_neg_samples_per_link=params.num_neg_samples_per_link,
                            use_kge_embeddings=params.use_kge_embeddings, dataset=params.dataset,
                            kge_model=params.kge_model, file_name=params.valid_file)

    params.num_rels = train.num_rels
    params.aug_num_rels = train.aug_num_rels
    params.inp_dim = train.n_feat_dim

    # Log the max label value to save it in the model. This will be used to cap
    # the labels generated on test set.
    params.max_label_value = train.max_n_label

    # Initialize the RGCN graph-classifier module.
    graph_classifier = initialize_model(params, dgl_model, params.load_model)

    logging.info(f"Device: {params.device}")
    logging.info(f"Input dim : {params.inp_dim}, # Relations : {params.num_rels}, # Augmented relations : {params.aug_num_rels}")

    # Initialize the validation evaluator.
    valid_evaluator = Evaluator(params, graph_classifier, valid)

    trainer = Trainer(params, graph_classifier, train, valid_evaluator)

    logging.info('Starting training with full batch...')

    trainer.train()


if __name__ == '__main__':

    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser(description='TransE model')

    # Experiment setup params
    parser.add_argument("--experiment_name", "-e", type=str, default="nell_v1",
                        help="A folder with this name would be created to dump saved models and log files")
    parser.add_argument("--dataset", "-d", type=str, default="nell_v1",
                        help="Dataset string")
    parser.add_argument("--gpu", type=int, default=3,
                        help="Which GPU to use?")
    parser.add_argument('--disable_cuda', action='store_true',
                        help='Disable CUDA')
    parser.add_argument('--load_model', action='store_true',
                        help='Load existing model?')
    parser.add_argument("--train_file", "-tf", type=str, default="train",
                        help="Name of file containing training triplets")
    parser.add_argument("--valid_file", "-vf", type=str, default="valid",
                        help="Name of file containing validation triplets")

    # Training regime params
    parser.add_argument("--lambda_cross", type=float, default=1,
                        help="Weight of the cross entropy loss")
    parser.add_argument("--lambda_contrast", type=float, default=1,
                        help="Weight of the contrast loss")
    # BUGFIX: help text was a copy-paste of --lr's ("Learning rate of the
    # optimizer"); it actually sets the number of training epochs.
    parser.add_argument("--num_epochs", "-ne", type=int, default=50,
                        help="Number of training epochs")
    parser.add_argument("--eval_every", type=int, default=3,
                        help="Interval of epochs to evaluate the model?")
    parser.add_argument("--eval_every_iter", type=int, default=500,
                        help="Interval of iterations to evaluate the model?")
    parser.add_argument("--save_every", type=int, default=10,
                        help="Interval of epochs to save a checkpoint of the model?")
    parser.add_argument("--early_stop", type=int, default=100,
                        help="Early stopping patience")
    parser.add_argument("--optimizer", type=str, default="Adam",
                        help="Which optimizer to use?")
    parser.add_argument("--lr", type=float, default=0.001,
                        help="Learning rate of the optimizer")
    parser.add_argument("--clip", type=int, default=1000,
                        help="Maximum gradient norm allowed")
    parser.add_argument("--l2", type=float, default=5e-4,
                        help="Regularization constant for GNN weights")
    parser.add_argument("--margin", type=float, default=10,
                        help="The margin between positive and negative samples in the max-margin loss")

    # Data processing pipeline params
    parser.add_argument("--max_links", type=int, default=1000000,
                        help="Set maximum number of train links (to fit into memory)")
    parser.add_argument("--hop", type=int, default=3,
                        help="Enclosing subgraph hop number")
    parser.add_argument("--max_path_len", type=int, default=4,
                        help="Max length of the path")
    parser.add_argument("--max_nodes_per_hop", "-max_h", type=int, default=None,
                        help="if > 0, upper bound the # nodes per hop by subsampling")
    parser.add_argument("--use_kge_embeddings", "-kge", type=bool, default=False,
                        help='whether to use pretrained KGE embeddings')
    parser.add_argument("--kge_model", type=str, default="TransE",
                        help="Which KGE model to load entity embeddings from")
    parser.add_argument('--model_type', '-m', type=str, choices=['ssp', 'dgl'], default='dgl',
                        help='what format to store subgraphs in for model')
    parser.add_argument('--constrained_neg_prob', '-cn', type=float, default=0.0,
                        help='with what probability to sample constrained heads/tails while neg sampling')
    parser.add_argument("--batch_size", type=int, default=16,
                        help="Batch size")
    parser.add_argument("--num_neg_samples_per_link", '-neg', type=int, default=1,
                        help="Number of negative examples to sample per positive link")
    parser.add_argument("--num_workers", type=int, default=8,
                        help="Number of dataloading processes")
    parser.add_argument('--add_traspose_rels', '-tr', type=bool, default=False,
                        help='whether to append adj matrix list with symmetric relations')
    parser.add_argument('--enclosing_sub_graph', '-en', type=bool, default=True,
                        help='whether to only consider enclosing subgraph')

    # Model params
    parser.add_argument("--rel_emb_dim", "-r_dim", type=int, default=32,
                        help="Relation embedding size")
    parser.add_argument("--attn_rel_emb_dim", "-ar_dim", type=int, default=32,
                        help="Relation embedding size for attention")
    parser.add_argument("--emb_dim", "-dim", type=int, default=32,
                        help="Entity embedding size")
    parser.add_argument("--num_gcn_layers", "-l", type=int, default=3,
                        help="Number of GCN layers")
    parser.add_argument("--num_bases", "-b", type=int, default=4,
                        help="Number of basis functions to use for GCN weights")
    parser.add_argument("--dropout", type=float, default=0,
                        help="Dropout rate in GNN layers")
    parser.add_argument("--edge_dropout", type=float, default=0.5,
                        help="Dropout rate in edges of the subgraphs")
    parser.add_argument('--gnn_agg_type', '-a', type=str, choices=['sum', 'mlp', 'gru'], default='sum',
                        help='what type of aggregation to do in gnn msg passing')
    parser.add_argument('--add_ht_emb', '-ht', type=bool, default=True,
                        help='whether to concatenate head/tail embedding with pooled graph representation')
    parser.add_argument('--add_pt_emb', '-pt', type=bool, default=True,
                        help='whether to concatenate path embedding with pooled graph representation')
    parser.add_argument('--has_attn', '-attn', type=bool, default=True,
                        help='whether to have attn in model or not')
    parser.add_argument('--self_attn', '-st', type=bool, default=True,
                        help='whether to have self attn in model or not')

    params = parser.parse_args()
    initialize_experiment(params, __file__)

    params.file_paths = {
        'train': os.path.join(params.main_dir, 'data/{}/{}.txt'.format(params.dataset, params.train_file)),
        'valid': os.path.join(params.main_dir, 'data/{}/{}.txt'.format(params.dataset, params.valid_file))
    }

    if not params.disable_cuda and torch.cuda.is_available():
        params.device = torch.device('cuda:%d' % params.gpu)
    else:
        params.device = torch.device('cpu')

    params.collate_fn = collate_dgl
    params.move_batch_to_device = move_batch_to_device_dgl

    main(params)
def _absorb_unseen(train_data, eval_data):
    """Split `eval_data` into triples whose head, relation and tail are all
    already known in `train_data` (returned) and triples containing unseen
    entities or relations, which are moved into `train_data` (mutated in place).
    """
    train_ent = set(d[0] for d in train_data) | set(d[2] for d in train_data)
    train_rels = set(d[1] for d in train_data)
    filtered = []
    for d in eval_data:
        if d[0] in train_ent and d[1] in train_rels and d[2] in train_ent:
            filtered.append(d)
        else:
            # Unseen entity/relation: absorb the triple into train so the
            # vocabulary grows as we scan (matches the original incremental
            # behaviour).
            train_data.append(d)
            train_ent.update((d[0], d[2]))
            train_rels.add(d[1])
    return filtered


def main(params):
    """Move valid/test triples containing entities or relations unseen in
    train into the train split, for both the plain dataset and its `_meta`
    variant, rewriting the split files in place.

    The previously duplicated plain/meta code paths are now one loop over the
    directory suffixes, sharing the `_absorb_unseen` helper.
    """
    for suffix in ('', '_meta'):
        data_dir = os.path.join(params.main_dir, 'data', params.dataset + suffix)
        with open(os.path.join(data_dir, 'train.txt')) as f:
            train_data = [line.split() for line in f.read().split('\n')[:-1]]
        with open(os.path.join(data_dir, 'valid.txt')) as f:
            valid_data = [line.split() for line in f.read().split('\n')[:-1]]
        with open(os.path.join(data_dir, 'test.txt')) as f:
            test_data = [line.split() for line in f.read().split('\n')[:-1]]

        # valid is absorbed first, so the test pass sees entities/relations
        # introduced by absorbed valid triples -- same order as before.
        filtered_valid_data = _absorb_unseen(train_data, valid_data)
        filtered_test_data = _absorb_unseen(train_data, test_data)

        write_to_file(os.path.join(data_dir, 'train.txt'), train_data)
        write_to_file(os.path.join(data_dir, 'valid.txt'), filtered_valid_data)
        write_to_file(os.path.join(data_dir, 'test.txt'), filtered_test_data)
def plot_rel_dist(adj_list, filename):
    """Plot the per-relation edge counts and save the figure to `filename`."""
    rel_count = [adj.count_nonzero() for adj in adj_list]

    fig = plt.figure(figsize=(12, 8))
    plt.plot(rel_count)
    fig.savefig(filename, dpi=fig.dpi)
    # BUGFIX: close the figure so repeated calls don't accumulate open
    # matplotlib figures (memory leak + "too many open figures" warning).
    plt.close(fig)


def process_files(files, saved_relation2id=None):
    '''
    files: Dictionary map of file paths to read the triplets from.
    saved_relation2id: Saved relation2id (mostly passed from a trained model) which can be used to map relations to pre-defined indices and filter out the unknown ones.

    Returns (adj_list, triplets, entity2id, relation2id, id2entity, id2relation).
    Triplets are stored as (head_id, tail_id, relation_id) rows.
    '''
    entity2id = {}
    relation2id = {} if saved_relation2id is None else saved_relation2id

    triplets = {}

    ent = 0
    rel = 0

    for file_type, file_path in files.items():

        data = []
        with open(file_path) as f:
            file_data = [line.split() for line in f.read().split('\n')[:-1]]

        for triplet in file_data:
            if triplet[0] not in entity2id:
                entity2id[triplet[0]] = ent
                ent += 1
            if triplet[2] not in entity2id:
                entity2id[triplet[2]] = ent
                ent += 1
            # Only grow the relation vocabulary when none was supplied.
            if not saved_relation2id and triplet[1] not in relation2id:
                relation2id[triplet[1]] = rel
                rel += 1

            # Save the triplets corresponding to only the known relations.
            if triplet[1] in relation2id:
                data.append([entity2id[triplet[0]], entity2id[triplet[2]], relation2id[triplet[1]]])

        triplets[file_type] = np.array(data)

    id2entity = {v: k for k, v in entity2id.items()}
    id2relation = {v: k for k, v in relation2id.items()}

    # Construct the list of adjacency matrices, one per relation. Note that
    # this is constructed only from the train data.
    adj_list = []
    for i in range(len(relation2id)):
        idx = np.argwhere(triplets['train'][:, 2] == i)
        adj_list.append(csc_matrix((np.ones(len(idx), dtype=np.uint8), (triplets['train'][:, 0][idx].squeeze(1), triplets['train'][:, 1][idx].squeeze(1))), shape=(len(entity2id), len(entity2id))))

    return adj_list, triplets, entity2id, relation2id, id2entity, id2relation
58 | adj_list = [] 59 | for i in range(len(relation2id)): 60 | idx = np.argwhere(triplets['train'][:, 2] == i) 61 | adj_list.append(csc_matrix((np.ones(len(idx), dtype=np.uint8), (triplets['train'][:, 0][idx].squeeze(1), triplets['train'][:, 1][idx].squeeze(1))), shape=(len(entity2id), len(entity2id)))) 62 | 63 | return adj_list, triplets, entity2id, relation2id, id2entity, id2relation 64 | 65 | 66 | def save_to_file(directory, file_name, triplets, id2entity, id2relation): 67 | file_path = os.path.join(directory, file_name) 68 | with open(file_path, "w") as f: 69 | for s, o, r in triplets: 70 | f.write('\t'.join([id2entity[s], id2relation[r], id2entity[o]]) + '\n') 71 | -------------------------------------------------------------------------------- /utils/dgl_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.sparse as ssp 3 | import random 4 | 5 | """All functions in this file are from dgl.contrib.data.knowledge_graph""" 6 | 7 | 8 | def _bfs_relational(adj, roots, max_nodes_per_hop=None): 9 | """ 10 | BFS for graphs. 11 | Modified from dgl.contrib.data.knowledge_graph to accomodate node sampling 12 | """ 13 | visited = set() 14 | current_lvl = set(roots) 15 | 16 | next_lvl = set() 17 | 18 | while current_lvl: 19 | 20 | for v in current_lvl: 21 | visited.add(v) 22 | 23 | next_lvl = _get_neighbors(adj, current_lvl) 24 | next_lvl -= visited # set difference 25 | 26 | if max_nodes_per_hop and max_nodes_per_hop < len(next_lvl): 27 | next_lvl = set(random.sample(next_lvl, max_nodes_per_hop)) 28 | 29 | yield next_lvl 30 | 31 | current_lvl = set.union(next_lvl) 32 | 33 | 34 | def _get_neighbors(adj, nodes): 35 | """Takes a set of nodes and a graph adjacency matrix and returns a set of neighbors. 
36 | Directly copied from dgl.contrib.data.knowledge_graph""" 37 | sp_nodes = _sp_row_vec_from_idx_list(list(nodes), adj.shape[1]) 38 | sp_neighbors = sp_nodes.dot(adj) 39 | neighbors = set(ssp.find(sp_neighbors)[1]) # convert to set of indices 40 | return neighbors 41 | 42 | 43 | def _sp_row_vec_from_idx_list(idx_list, dim): 44 | """Create sparse vector of dimensionality dim from a list of indices.""" 45 | shape = (1, dim) 46 | data = np.ones(len(idx_list)) 47 | row_ind = np.zeros(len(idx_list)) 48 | col_ind = list(idx_list) 49 | return ssp.csr_matrix((data, (row_ind, col_ind)), shape=shape) 50 | -------------------------------------------------------------------------------- /utils/graph_utils.py: -------------------------------------------------------------------------------- 1 | import statistics 2 | import numpy as np 3 | import scipy.sparse as ssp 4 | import torch 5 | import networkx as nx 6 | import dgl 7 | import pickle 8 | 9 | 10 | def serialize(data): 11 | data_tuple = tuple(data.values()) 12 | return pickle.dumps(data_tuple) 13 | 14 | 15 | def deserialize(data): 16 | data_tuple = pickle.loads(data) 17 | keys = ('nodes', 'r_label', 'g_label', 'n_label') 18 | return dict(zip(keys, data_tuple)) 19 | 20 | 21 | def get_edge_count(adj_list): 22 | count = [] 23 | for adj in adj_list: 24 | count.append(len(adj.tocoo().row.tolist())) 25 | return np.array(count) 26 | 27 | 28 | def incidence_matrix(adj_list): 29 | ''' 30 | adj_list: List of sparse adjacency matrices 31 | ''' 32 | 33 | rows, cols, dats = [], [], [] 34 | dim = adj_list[0].shape 35 | for adj in adj_list: 36 | adjcoo = adj.tocoo() 37 | rows += adjcoo.row.tolist() 38 | cols += adjcoo.col.tolist() 39 | dats += adjcoo.data.tolist() 40 | row = np.array(rows) 41 | col = np.array(cols) 42 | data = np.array(dats) 43 | return ssp.csc_matrix((data, (row, col)), shape=dim) 44 | 45 | 46 | def remove_nodes(A_incidence, nodes): 47 | idxs_wo_nodes = list(set(range(A_incidence.shape[1])) - set(nodes)) 48 | return 
A_incidence[idxs_wo_nodes, :][:, idxs_wo_nodes] 49 | 50 | 51 | def ssp_to_torch(A, device, dense=False): 52 | ''' 53 | A : Sparse adjacency matrix 54 | ''' 55 | idx = torch.LongTensor([A.tocoo().row, A.tocoo().col]) 56 | dat = torch.FloatTensor(A.tocoo().data) 57 | A = torch.sparse.FloatTensor(idx, dat, torch.Size([A.shape[0], A.shape[1]])).to(device=device) 58 | return A 59 | 60 | 61 | def ssp_multigraph_to_dgl(graph, n_feats=None): 62 | """ 63 | Converting ssp multigraph (i.e. list of adjs) to dgl multigraph. 64 | """ 65 | 66 | g_nx = nx.MultiDiGraph() 67 | g_nx.add_nodes_from(list(range(graph[0].shape[0]))) 68 | # Add edges 69 | for rel, adj in enumerate(graph): 70 | # Convert adjacency matrix to tuples for nx0 71 | nx_triplets = [] 72 | for src, dst in list(zip(adj.tocoo().row, adj.tocoo().col)): 73 | nx_triplets.append((src, dst, {'type': rel})) 74 | g_nx.add_edges_from(nx_triplets) 75 | 76 | # make dgl graph 77 | g_dgl = dgl.DGLGraph(multigraph=True) 78 | g_dgl.from_networkx(g_nx, edge_attrs=['type']) 79 | # add node features 80 | if n_feats is not None: 81 | g_dgl.ndata['feat'] = torch.tensor(n_feats) 82 | 83 | return g_dgl 84 | 85 | 86 | def collate_dgl(samples): 87 | # The input `samples` is a list of pairs 88 | graphs_pos, g_labels_pos, r_labels_pos, graphs_negs, g_labels_negs, r_labels_negs = map(list, zip(*samples)) 89 | batched_graph_pos = dgl.batch(graphs_pos) 90 | 91 | graphs_neg = [item for sublist in graphs_negs for item in sublist] 92 | g_labels_neg = [item for sublist in g_labels_negs for item in sublist] 93 | r_labels_neg = [item for sublist in r_labels_negs for item in sublist] 94 | 95 | batched_graph_neg = dgl.batch(graphs_neg) 96 | return (batched_graph_pos, r_labels_pos), g_labels_pos, (batched_graph_neg, r_labels_neg), g_labels_neg 97 | 98 | 99 | def move_batch_to_device_dgl(batch, device): 100 | ((g_dgl_pos, r_labels_pos), targets_pos, (g_dgl_neg, r_labels_neg), targets_neg) = batch 101 | 102 | targets_pos = 
torch.LongTensor(targets_pos).to(device=device) 103 | r_labels_pos = torch.LongTensor(r_labels_pos).to(device=device) 104 | 105 | targets_neg = torch.LongTensor(targets_neg).to(device=device) 106 | r_labels_neg = torch.LongTensor(r_labels_neg).to(device=device) 107 | 108 | g_dgl_pos = send_graph_to_device(g_dgl_pos, device) 109 | g_dgl_neg = send_graph_to_device(g_dgl_neg, device) 110 | 111 | return ((g_dgl_pos, r_labels_pos), targets_pos, (g_dgl_neg, r_labels_neg), targets_neg) 112 | 113 | 114 | def send_graph_to_device(g, device): 115 | # nodes 116 | labels = g.node_attr_schemes() 117 | for l in labels.keys(): 118 | g.ndata[l] = g.ndata.pop(l).to(device) 119 | 120 | # edges 121 | labels = g.edge_attr_schemes() 122 | for l in labels.keys(): 123 | g.edata[l] = g.edata.pop(l).to(device) 124 | return g 125 | 126 | # The following three functions are modified from networks source codes to 127 | # accomodate diameter and radius for dirercted graphs 128 | 129 | 130 | def eccentricity(G): 131 | e = {} 132 | for n in G.nbunch_iter(): 133 | length = nx.single_source_shortest_path_length(G, n) 134 | e[n] = max(length.values()) 135 | return e 136 | 137 | 138 | def radius(G): 139 | e = eccentricity(G) 140 | e = np.where(np.array(list(e.values())) > 0, list(e.values()), np.inf) 141 | return min(e) 142 | 143 | 144 | def diameter(G): 145 | e = eccentricity(G) 146 | return max(e.values()) 147 | -------------------------------------------------------------------------------- /utils/initialization_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import json 4 | import torch 5 | 6 | 7 | def initialize_experiment(params, file_name): 8 | ''' 9 | Makes the experiment directory, sets standard paths and initializes the logger 10 | ''' 11 | params.main_dir = os.path.join(os.path.relpath(os.path.dirname(os.path.abspath(__file__))), '..') 12 | exps_dir = os.path.join(params.main_dir, 'experiments') 13 | if not 
os.path.exists(exps_dir): 14 | os.makedirs(exps_dir) 15 | 16 | params.exp_dir = os.path.join(exps_dir, params.experiment_name) 17 | 18 | if not os.path.exists(params.exp_dir): 19 | os.makedirs(params.exp_dir) 20 | 21 | if file_name == 'test_auc.py': 22 | params.test_exp_dir = os.path.join(params.exp_dir, f"test_{params.dataset}_{params.constrained_neg_prob}") 23 | if not os.path.exists(params.test_exp_dir): 24 | os.makedirs(params.test_exp_dir) 25 | file_handler = logging.FileHandler(os.path.join(params.test_exp_dir, f"log_test.txt")) 26 | else: 27 | file_handler = logging.FileHandler(os.path.join(params.exp_dir, "log_train.txt")) 28 | logger = logging.getLogger() 29 | logger.addHandler(file_handler) 30 | 31 | logger.info('============ Initialized logger ============') 32 | logger.info('\n'.join('%s: %s' % (k, str(v)) for k, v 33 | in sorted(dict(vars(params)).items()))) 34 | logger.info('============================================') 35 | 36 | with open(os.path.join(params.exp_dir, "params.json"), 'w') as fout: 37 | json.dump(vars(params), fout) 38 | 39 | 40 | def initialize_model(params, model, load_model=False): 41 | ''' 42 | relation2id: the relation to id mapping, this is stored in the model and used when testing 43 | model: the type of model to initialize/load 44 | load_model: flag which decide to initialize the model or load a saved model 45 | ''' 46 | 47 | if load_model and os.path.exists(os.path.join(params.exp_dir, 'best_graph_classifier.pth')): 48 | logging.info('Loading existing model from %s' % os.path.join(params.exp_dir, 'best_graph_classifier.pth')) 49 | graph_classifier = torch.load(os.path.join(params.exp_dir, 'best_graph_classifier.pth')).to(device=params.device) 50 | else: 51 | relation2id_path = os.path.join(params.main_dir, f'data/{params.dataset}/relation2id.json') 52 | with open(relation2id_path) as f: 53 | relation2id = json.load(f) 54 | 55 | logging.info('No existing model found. 
Initializing new model..') 56 | graph_classifier = model(params, relation2id).to(device=params.device) 57 | 58 | return graph_classifier 59 | -------------------------------------------------------------------------------- /utils/prepare_meta_data.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import os 3 | import math 4 | import random 5 | import argparse 6 | import numpy as np 7 | 8 | from .graph_utils import incidence_matrix, get_edge_count 9 | from .dgl_utils import _bfs_relational 10 | from .data_utils import process_files, save_to_file 11 | 12 | 13 | def get_active_relations(adj_list): 14 | act_rels = [] 15 | for r, adj in enumerate(adj_list): 16 | if len(adj.tocoo().row.tolist()) > 0: 17 | act_rels.append(r) 18 | return act_rels 19 | 20 | 21 | def get_avg_degree(adj_list): 22 | adj_mat = incidence_matrix(adj_list) 23 | degree = [] 24 | for node in range(adj_list[0].shape[0]): 25 | degree.append(np.sum(adj_mat[node, :])) 26 | return np.mean(degree) 27 | 28 | 29 | def get_splits(adj_list, nodes, valid_rels=None, valid_ratio=0.1, test_ratio=0.1): 30 | ''' 31 | Get train/valid/test splits of the sub-graph defined by the given set of nodes. The relations in this subbgraph are limited to be among the given valid_rels. 
32 | ''' 33 | 34 | # Extract the subgraph 35 | subgraph = [adj[nodes, :][:, nodes] for adj in adj_list] 36 | 37 | # Get the relations that are allowed to be sampled 38 | active_rels = get_active_relations(subgraph) 39 | common_rels = list(set(active_rels).intersection(set(valid_rels))) 40 | 41 | print('Average degree : ', get_avg_degree(subgraph)) 42 | print('Nodes: ', len(nodes)) 43 | print('Links: ', np.sum(get_edge_count(subgraph))) 44 | print('Active relations: ', len(common_rels)) 45 | 46 | # get all the triplets satisfying the given constraints 47 | all_triplets = [] 48 | for r in common_rels: 49 | # print(r, len(subgraph[r].tocoo().row)) 50 | for (i, j) in zip(subgraph[r].tocoo().row, subgraph[r].tocoo().col): 51 | all_triplets.append([nodes[i], nodes[j], r]) 52 | all_triplets = np.array(all_triplets) 53 | 54 | # delete the triplets which correspond to self connections 55 | ind = np.argwhere(all_triplets[:, 0] == all_triplets[:, 1]) 56 | all_triplets = np.delete(all_triplets, ind, axis=0) 57 | print('Links after deleting self connections : %d' % len(all_triplets)) 58 | 59 | # get the splits according to the given ratio 60 | np.random.shuffle(all_triplets) 61 | train_split = int(math.ceil(len(all_triplets) * (1 - valid_ratio - test_ratio))) 62 | valid_split = int(math.ceil(len(all_triplets) * (1 - test_ratio))) 63 | 64 | train_triplets = all_triplets[:train_split] 65 | valid_triplets = all_triplets[train_split: valid_split] 66 | test_triplets = all_triplets[valid_split:] 67 | 68 | return train_triplets, valid_triplets, test_triplets, common_rels 69 | 70 | 71 | def get_subgraph(adj_list, hops, max_nodes_per_hop): 72 | ''' 73 | Samples a subgraph around randomly chosen root nodes upto hops with a limit on the nodes selected per hop given by max_nodes_per_hop 74 | ''' 75 | 76 | # collapse the list of adj mattricees to a single matrix 77 | A_incidence = incidence_matrix(adj_list) 78 | 79 | # chose a set of random root nodes 80 | idx = 
np.random.choice(range(len(A_incidence.tocoo().row)), size=params.n_roots, replace=False) 81 | roots = set([A_incidence.tocoo().row[id] for id in idx] + [A_incidence.tocoo().col[id] for id in idx]) 82 | 83 | # get the neighbor nodes within a limit of hops 84 | bfs_generator = _bfs_relational(A_incidence, roots, max_nodes_per_hop) 85 | lvls = list() 86 | for _ in range(hops): 87 | lvls.append(next(bfs_generator)) 88 | 89 | nodes = list(roots) + list(set().union(*lvls)) 90 | 91 | return nodes 92 | 93 | 94 | def mask_nodes(adj_list, nodes): 95 | ''' 96 | mask a set of nodes from a given graph 97 | ''' 98 | 99 | masked_adj_list = [adj.copy() for adj in adj_list] 100 | for node in nodes: 101 | for adj in masked_adj_list: 102 | adj.data[adj.indptr[node]:adj.indptr[node + 1]] = 0 103 | adj = adj.tocsr() 104 | adj.data[adj.indptr[node]:adj.indptr[node + 1]] = 0 105 | adj = adj.tocsc() 106 | for adj in masked_adj_list: 107 | adj.eliminate_zeros() 108 | return masked_adj_list 109 | 110 | 111 | def main(params): 112 | 113 | adj_list, triplets, entity2id, relation2id, id2entity, id2relation = process_files(files) 114 | 115 | meta_train_nodes = get_subgraph(adj_list, params.hops, params.max_nodes_per_hop) # list(range(750, 8500)) # 116 | 117 | masked_adj_list = mask_nodes(adj_list, meta_train_nodes) 118 | 119 | meta_test_nodes = get_subgraph(masked_adj_list, params.hops_test + 1, params.max_nodes_per_hop_test) # list(range(0, 750)) # 120 | 121 | print('Common nodes among the two disjoint datasets (should ideally be zero): ', set(meta_train_nodes).intersection(set(meta_test_nodes))) 122 | tmp = [adj[meta_train_nodes, :][:, meta_train_nodes] for adj in masked_adj_list] 123 | print('Residual edges (should be zero) : ', np.sum(get_edge_count(tmp))) 124 | 125 | print("================") 126 | print("Train graph stats") 127 | print("================") 128 | train_triplets, valid_triplets, test_triplets, train_active_rels = get_splits(adj_list, meta_train_nodes, range(len(adj_list))) 
129 | print("================") 130 | print("Meta-test graph stats") 131 | print("================") 132 | meta_train_triplets, meta_valid_triplets, meta_test_triplets, meta_active_rels = get_splits(adj_list, meta_test_nodes, train_active_rels) 133 | 134 | print("================") 135 | print('Extra rels (should be empty): ', set(meta_active_rels) - set(train_active_rels)) 136 | 137 | # TODO: ABSTRACT THIS INTO A METHOD 138 | data_dir = os.path.join(params.main_dir, 'data/{}'.format(params.new_dataset)) 139 | if not os.path.exists(data_dir): 140 | os.makedirs(data_dir) 141 | 142 | save_to_file(data_dir, 'train.txt', train_triplets, id2entity, id2relation) 143 | save_to_file(data_dir, 'valid.txt', valid_triplets, id2entity, id2relation) 144 | save_to_file(data_dir, 'test.txt', test_triplets, id2entity, id2relation) 145 | 146 | meta_data_dir = os.path.join(params.main_dir, 'data/{}'.format(params.new_dataset + '_meta')) 147 | if not os.path.exists(meta_data_dir): 148 | os.makedirs(meta_data_dir) 149 | 150 | save_to_file(meta_data_dir, 'train.txt', meta_train_triplets, id2entity, id2relation) 151 | save_to_file(meta_data_dir, 'valid.txt', meta_valid_triplets, id2entity, id2relation) 152 | save_to_file(meta_data_dir, 'test.txt', meta_test_triplets, id2entity, id2relation) 153 | 154 | 155 | if __name__ == '__main__': 156 | 157 | parser = argparse.ArgumentParser(description='Save adjacency matrtices and triplets') 158 | 159 | parser.add_argument("--dataset", "-d", type=str, default="FB15K237", 160 | help="Dataset string") 161 | parser.add_argument("--new_dataset", "-nd", type=str, default="fb_v3", 162 | help="Dataset string") 163 | parser.add_argument("--n_roots", "-n", type=int, default="1", 164 | help="Number of roots to sample the neighborhood from") 165 | parser.add_argument("--hops", "-H", type=int, default="3", 166 | help="Number of hops to sample the neighborhood") 167 | parser.add_argument("--max_nodes_per_hop", "-m", type=int, default="2500", 168 | help="Number 
of nodes in the neighborhood") 169 | parser.add_argument("--hops_test", "-HT", type=int, default="3", 170 | help="Number of hops to sample the neighborhood") 171 | parser.add_argument("--max_nodes_per_hop_test", "-mt", type=int, default="2500", 172 | help="Number of nodes in the neighborhood") 173 | parser.add_argument("--seed", "-s", type=int, default="28", 174 | help="Numpy random seed") 175 | 176 | params = parser.parse_args() 177 | 178 | np.random.seed(params.seed) 179 | random.seed(params.seed) 180 | 181 | params.main_dir = os.path.join(os.path.relpath(os.path.dirname(os.path.abspath(__file__))), '..') 182 | 183 | files = { 184 | 'train': os.path.join(params.main_dir, 'data/{}/train.txt'.format(params.dataset)), 185 | 'valid': os.path.join(params.main_dir, 'data/{}/valid.txt'.format(params.dataset)), 186 | 'test': os.path.join(params.main_dir, 'data/{}/test.txt'.format(params.dataset)) 187 | } 188 | 189 | main(params) 190 | --------------------------------------------------------------------------------