├── .gitignore ├── README.md ├── args.py ├── benchmarks ├── DB15K │ ├── 1-1.txt │ ├── 1-n.txt │ ├── entity2id.txt │ ├── n-1.txt │ ├── n-n.py │ ├── n-n.txt │ ├── relation2id.txt │ ├── test2id.txt │ ├── test2id_all.txt │ ├── train2id.txt │ ├── type_constrain.txt │ └── valid2id.txt ├── MKG-W │ ├── 1-1.txt │ ├── 1-n.txt │ ├── entity2id.txt │ ├── n-1.txt │ ├── n-n.py │ ├── n-n.txt │ ├── relation2id.txt │ ├── test2id.txt │ ├── test2id_all.txt │ ├── train2id.txt │ ├── type_constrain.txt │ └── valid2id.txt └── MKG-Y │ ├── 1-1.txt │ ├── 1-n.txt │ ├── entity2id.txt │ ├── n-1.txt │ ├── n-n.py │ ├── n-n.txt │ ├── relation2id.txt │ ├── test2id.txt │ ├── test2id_all.txt │ ├── train2id.txt │ ├── type_constrain.txt │ └── valid2id.txt ├── embeddings ├── .DS_Store └── .keep ├── figure └── model.png ├── mmkgc ├── .DS_Store ├── __init__.py ├── adv │ └── modules.py ├── base │ ├── Base.cpp │ ├── Corrupt.h │ ├── Random.h │ ├── Reader.h │ ├── Setting.h │ ├── Test.h │ └── Triple.h ├── config │ ├── AdvConMixTrainer.py │ ├── AdvConTrainer.py │ ├── AdvMixTrainer.py │ ├── AdvTrainer.py │ ├── MMKRLTrainer.py │ ├── MultiAdvMixTrainer.py │ ├── RSMEAdvTrainer.py │ ├── Tester.py │ ├── Trainer.py │ └── __init__.py ├── data │ ├── PyTorchTrainDataLoader.py │ ├── TestDataLoader.py │ ├── TrainDataLoader.py │ └── __init__.py ├── make.sh ├── module │ ├── BaseModule.py │ ├── __init__.py │ ├── loss │ │ ├── Loss.py │ │ ├── MarginLoss.py │ │ ├── SigmoidLoss.py │ │ ├── SoftplusLoss.py │ │ └── __init__.py │ ├── model │ │ ├── AdvMixRotatE.py │ │ ├── EnsembleComplEx.py │ │ ├── EnsembleMMKGE.py │ │ ├── IKRL.py │ │ ├── MMKRL.py │ │ ├── MMRotatE.py │ │ ├── Model.py │ │ ├── RSME.py │ │ ├── RotatE.py │ │ ├── TBKGC.py │ │ ├── TransAE.py │ │ ├── TransE.py │ │ ├── VBRotatE.py │ │ ├── VBTransE.py │ │ └── __init__.py │ └── strategy │ │ ├── MMKRLNegativeSampling.py │ │ ├── NegativeSampling.py │ │ ├── Strategy.py │ │ ├── TransAENegativeSampling.py │ │ └── __init__.py └── release │ └── Base.so ├── requirements.txt ├── 
run_adamf_mat.py └── scripts ├── run_db15k.sh ├── run_mkgw.sh └── run_mkgy.sh /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .DS_Store -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Unleashing the Power of Imbalanced Modality Information for Multi-modal Knowledge Graph Completion 2 | 3 | ![](https://img.shields.io/badge/version-1.0.1-blue) 4 | [![license](https://img.shields.io/github/license/mashape/apistatus.svg?maxAge=2592000)](https://github.com/zjukg/AdaMF-MAT/main/LICENSE) 5 | [![Preprint](https://img.shields.io/badge/Preprint'24-brightgreen)](https://arxiv.org/abs/2402.15444) 6 | [![Pytorch](https://img.shields.io/badge/PyTorch-%23EE4C2C.svg?e&logo=PyTorch&logoColor=white)](https://pytorch.org/) 7 | [![COLING2024](https://img.shields.io/badge/COLING-2024-%23bd9f65?labelColor=%2377BBDD&color=3388bb)](https://lrec-coling-2024.org/) 8 | - [Unleashing the Power of Imbalanced Modality Information for Multi-modal Knowledge Graph Completion](https://arxiv.org/abs/2402.15444) 9 | 10 | > Multi-modal knowledge graph completion (MMKGC) aims to predict the missing triples in the multi-modal knowledge graphs by incorporating structural, visual, and textual information of entities into the discriminant models. The information from different modalities will work together to measure the triple plausibility. Existing MMKGC methods overlook the imbalance problem of modality information among entities, resulting in inadequate modal fusion and inefficient utilization of the raw modality information. To address the mentioned problems, we propose Adaptive Multi-modal Fusion and Modality Adversarial Training (AdaMF-MAT) to unleash the power of imbalanced modality information for MMKGC. 
AdaMF-MAT achieves multi-modal fusion with adaptive modality weights and further generates adversarial samples by modality-adversarial training to enhance the imbalanced modality information. Our approach is a co-design of the MMKGC model and training strategy, which can outperform 19 recent MMKGC methods and achieve new state-of-the-art results on three public MMKGC benchmarks. 11 | 12 | ## 🌈 Model Architecture 13 | ![Model_architecture](figure/model.png) 14 | 15 | ## 🔔 News 16 | - `2024-04` We preprint a new paper [MyGO: Discrete Modality Information as Fine-Grained Tokens for Multi-modal Knowledge Graph Completion](https://arxiv.org/abs/2404.09468). 17 | - `2024-03` We release the [Repo](https://github.com/zjukg/NATIVE) for our paper: [NativE: Multi-modal Knowledge Graph Completion in the Wild](https://www.techrxiv.org/doi/full/10.36227/techrxiv.171259566.60211714), **SIGIR 2024**. 18 | - `2024-02` Our paper has been accepted by **LREC-COLING 2024**. 19 | 20 | 21 | ## 💻 Data preparation 22 | We use the MMKG datasets proposed in [MMRNS](https://github.com/quqxui/MMRNS). You can refer to this repo to download the multi-modal embeddings of the MMKGs and put them in `embeddings/`. We also prepare a processed version of the multi-modal embeddings, which you can download from [Google Drive](https://drive.google.com/drive/folders/1UJSfnb8DEx2s-k8zaQx1fWUw5f45GBpI?usp=sharing). 23 | 24 | ## 🚀 Training and Inference 25 | 26 | You can use the shell scripts in the `scripts/` directory to conduct the experiments.
For example, the following script runs an experiment on DB15K: 27 | 28 | ```shell 29 | DATA=DB15K 30 | EMB_DIM=250 31 | NUM_BATCH=1024 32 | MARGIN=12 33 | LR=1e-4 34 | LRG=1e-4 35 | NEG_NUM=128 36 | EPOCH=1000 37 | 38 | CUDA_VISIBLE_DEVICES=0 nohup python run_adamf_mat.py -dataset=$DATA \ 39 | -batch_size=$NUM_BATCH \ 40 | -margin=$MARGIN \ 41 | -epoch=$EPOCH \ 42 | -dim=$EMB_DIM \ 43 | -lrg=$LRG \ 44 | -mu=0 \ 45 | -save=./checkpoint/$DATA-$NUM_BATCH-$EMB_DIM-$NEG_NUM-$MARGIN-$LR-$EPOCH \ 46 | -neg_num=$NEG_NUM \ 47 | -learning_rate=$LR > $DATA-$EMB_DIM-$NUM_BATCH-$NEG_NUM-$MARGIN-$EPOCH.txt & 48 | 49 | ``` 50 | 51 | 52 | ## 🤝 Cite: 53 | Please consider citing this paper if you use the code from our work. 54 | Thanks a lot :) 55 | 56 | ```bibtex 57 | 58 | @misc{zhang2024unleashing, 59 | title={Unleashing the Power of Imbalanced Modality Information for Multi-modal Knowledge Graph Completion}, 60 | author={Yichi Zhang and Zhuo Chen and Lei Liang and Huajun Chen and Wen Zhang}, 61 | year={2024}, 62 | eprint={2402.15444}, 63 | archivePrefix={arXiv}, 64 | primaryClass={cs.AI} 65 | } 66 | 67 | ``` 68 | 69 | -------------------------------------------------------------------------------- /args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_args(): 5 | arg = argparse.ArgumentParser() 6 | arg.add_argument('-dataset', type=str, default='FB15K') 7 | arg.add_argument('-batch_size', type=int, default=1024) 8 | arg.add_argument('-margin', type=float, default=6.0) 9 | arg.add_argument('-dim', type=int, default=128) 10 | arg.add_argument('-epoch', type=int, default=1000) 11 | arg.add_argument('-save', type=str) 12 | arg.add_argument('-img_dim', type=int, default=4096) 13 | arg.add_argument('-neg_num', type=int, default=1) 14 | arg.add_argument('-learning_rate', type=float, default=0.001) 15 | arg.add_argument('-lrg', type=float, default=0.0001) 16 | arg.add_argument('-adv_temp', type=float,
default=2.0) 17 | arg.add_argument('-visual', type=str, default='random') 18 | arg.add_argument('-seed', type=int, default=42) 19 | arg.add_argument('-missing_rate', type=float, default=0.8) 20 | arg.add_argument('-postfix', type=str, default='') 21 | arg.add_argument('-con_temp', type=float, default=0) 22 | arg.add_argument('-lamda', type=float, default=0) 23 | arg.add_argument('-mu', type=float, required=True) 24 | arg.add_argument('-adv_num', type=int, default=1) 25 | return arg.parse_args() 26 | 27 | 28 | if __name__ == "__main__": 29 | args = get_args() 30 | print(args) 31 | -------------------------------------------------------------------------------- /benchmarks/DB15K/1-1.txt: -------------------------------------------------------------------------------- 1 | 385 2 | 2038 4442 162 3 | 6746 6616 62 4 | 10282 5516 152 5 | 10224 10224 48 6 | 788 714 4 7 | 6286 7897 61 8 | 6458 6459 62 9 | 933 2349 3 10 | 2164 1604 3 11 | 11220 228 133 12 | 10312 8013 152 13 | 5371 193 38 14 | 8725 520 25 15 | 3926 2394 4 16 | 8504 7835 61 17 | 503 1025 135 18 | 256 258 72 19 | 11307 1551 234 20 | 7222 3177 216 21 | 11540 6095 72 22 | 11387 3710 152 23 | 2975 2969 72 24 | 10111 6271 135 25 | 12568 174 272 26 | 209 211 64 27 | 467 468 4 28 | 1524 2592 3 29 | 5723 5725 156 30 | 1259 1048 64 31 | 5517 5516 3 32 | 5798 5465 72 33 | 8 4 8 34 | 2764 2766 10 35 | 12 209 64 36 | 10529 10530 25 37 | 8725 2785 25 38 | 3585 74 64 39 | 7491 2561 94 40 | 596 595 3 41 | 31 330 3 42 | 2084 2084 64 43 | 11185 868 121 44 | 12017 6281 121 45 | 9802 8153 234 46 | 4258 4184 72 47 | 5633 6258 48 48 | 5314 513 198 49 | 8389 1359 64 50 | 9739 1981 121 51 | 9298 9299 64 52 | 6315 409 64 53 | 10545 7155 162 54 | 6970 6979 121 55 | 7563 6804 61 56 | 86 2916 3 57 | 5157 3850 112 58 | 8367 8577 200 59 | 918 916 156 60 | 2887 6283 133 61 | 6434 176 4 62 | 7484 7473 72 63 | 4055 4062 121 64 | 8428 185 3 65 | 9798 9799 159 66 | 12156 9788 120 67 | 2788 2790 8 68 | 2485 4245 3 69 | 1030 176 3 70 | 9718 9717 
61 71 | 10142 195 152 72 | 12105 1507 234 73 | 4924 3681 48 74 | 7819 3991 48 75 | 5888 5889 4 76 | 10207 3041 48 77 | 247 248 64 78 | 248 247 64 79 | 1055 1056 3 80 | 4610 4609 3 81 | 3577 484 64 82 | 6027 8793 25 83 | 816 817 3 84 | 6983 7200 72 85 | 4435 4436 4 86 | 3786 861 25 87 | 11140 4960 64 88 | 9417 6927 25 89 | 3990 4376 25 90 | 6401 10411 135 91 | 9837 8712 135 92 | 2803 2257 162 93 | 2121 5047 231 94 | 2673 2674 64 95 | 4 128 4 96 | 1100 1095 8 97 | 1769 2660 64 98 | 1549 178 25 99 | 7010 7025 48 100 | 8263 8264 152 101 | 2664 2665 3 102 | 3720 5180 62 103 | 806 48 45 104 | 3256 2185 156 105 | 918 916 4 106 | 8261 8262 152 107 | 587 3330 3 108 | 12496 1958 112 109 | 262 6544 25 110 | 10857 97 152 111 | 290 466 3 112 | 555 1910 61 113 | 5223 5222 186 114 | 572 566 8 115 | 10139 121 152 116 | 11900 2518 4 117 | 9031 4573 234 118 | 1110 1109 156 119 | 4209 4210 156 120 | 10375 8351 121 121 | 8026 616 8 122 | 4321 3602 61 123 | 3926 2394 156 124 | 2006 1717 64 125 | 5680 939 4 126 | 10391 18 152 127 | 3299 5381 135 128 | 5322 7 4 129 | 12425 12273 64 130 | 948 7645 245 131 | 334 125 64 132 | 5740 1921 4 133 | 4290 6149 61 134 | 5240 715 64 135 | 7432 2530 152 136 | 1885 1882 162 137 | 2346 2347 4 138 | 9104 5989 121 139 | 653 268 64 140 | 9946 3781 234 141 | 4729 4730 3 142 | 411 410 61 143 | 9668 1990 254 144 | 3896 5703 262 145 | 4288 6790 61 146 | 3339 3338 186 147 | 3561 1618 185 148 | 3719 4841 62 149 | 2044 2036 61 150 | 3338 3340 186 151 | 3256 2185 4 152 | 3884 3887 62 153 | 4764 1783 64 154 | 8133 8134 61 155 | 477 478 3 156 | 5706 693 156 157 | 10408 937 152 158 | 29 28 3 159 | 8697 1914 152 160 | 1692 1693 64 161 | 3647 3049 123 162 | 10131 1940 241 163 | 12140 5344 72 164 | 1100 1413 10 165 | 10939 632 25 166 | 707 709 72 167 | 7250 7251 72 168 | 4891 289 64 169 | 5036 3749 61 170 | 5811 5812 156 171 | 9383 9384 62 172 | 9747 2178 152 173 | 5189 1787 25 174 | 4937 565 3 175 | 2846 4204 4 176 | 2073 2074 61 177 | 4515 4514 3 178 | 12757 830 274 
179 | 3302 2526 61 180 | 654 652 61 181 | 2579 1580 3 182 | 955 789 3 183 | 5381 6055 61 184 | 58 3897 216 185 | 1673 145 10 186 | 2806 1099 10 187 | 7116 7115 61 188 | 1247 1250 62 189 | 2972 2967 72 190 | 1037 233 60 191 | 5739 569 156 192 | 10180 2761 3 193 | 3941 5123 156 194 | 1272 378 3 195 | 5732 569 64 196 | 159 160 200 197 | 6017 734 25 198 | 1 1606 4 199 | 3009 3010 4 200 | 8909 8910 72 201 | 3058 3933 3 202 | 5202 2634 197 203 | 7730 3082 72 204 | 1352 2261 3 205 | 6930 6929 61 206 | 3197 2835 72 207 | 1784 6954 61 208 | 1910 415 61 209 | 5386 7617 61 210 | 1670 1672 8 211 | 969 2809 61 212 | 5796 881 64 213 | 106 959 55 214 | 7096 1059 162 215 | 3358 3356 64 216 | 707 710 121 217 | 6171 6172 64 218 | 4970 2433 162 219 | 292 292 48 220 | 6447 215 64 221 | 12391 11171 112 222 | 7810 7811 61 223 | 11869 10895 143 224 | 9180 4047 162 225 | 338 2681 3 226 | 4281 4282 156 227 | 10240 7374 152 228 | 6835 1655 234 229 | 6938 7247 25 230 | 4293 3613 133 231 | 4393 322 156 232 | 3982 1519 64 233 | 5840 5841 156 234 | 5401 5402 61 235 | 6179 3962 61 236 | 4571 2598 64 237 | 6564 6565 72 238 | 5840 2483 156 239 | 9812 682 121 240 | 3431 4642 3 241 | 2180 2621 124 242 | 6377 7083 135 243 | 5864 5865 4 244 | 2006 919 64 245 | 3064 576 64 246 | 1245 1244 64 247 | 4217 314 62 248 | 5802 925 4 249 | 9197 9194 48 250 | 3695 1453 61 251 | 5879 2392 64 252 | 10500 4059 72 253 | 5866 5867 156 254 | 11900 2517 4 255 | 606 6916 135 256 | 12128 5170 112 257 | 3750 3751 61 258 | 9194 9195 61 259 | 1961 1962 62 260 | 601 602 61 261 | 12308 5095 152 262 | 3637 3638 61 263 | 681 3299 135 264 | 5230 5232 156 265 | 3332 1409 64 266 | 136 56 4 267 | 6920 5984 143 268 | 9714 817 152 269 | 4634 7749 61 270 | 9846 1604 152 271 | 10637 265 8 272 | 6055 5381 61 273 | 4990 4991 156 274 | 1453 2451 61 275 | 6772 6771 48 276 | 216 214 57 277 | 4509 289 4 278 | 2014 3337 186 279 | 3134 4501 48 280 | 2002 2001 3 281 | 10605 982 152 282 | 4525 3881 156 283 | 12022 3323 120 284 | 11065 6433 240 
285 | 7573 7321 61 286 | 7828 413 94 287 | 571 1035 8 288 | 3638 3637 61 289 | 5671 574 4 290 | 12030 9050 112 291 | 13 10 3 292 | 9990 7157 162 293 | 795 1889 3 294 | 4070 4071 156 295 | 9806 221 152 296 | 2857 2868 38 297 | 811 812 3 298 | 7941 5050 160 299 | 5356 4084 4 300 | 5369 5368 48 301 | 4112 4111 4 302 | 1312 2380 8 303 | 3060 2221 72 304 | 881 880 3 305 | 94 94 64 306 | 3975 4665 64 307 | 10289 8125 112 308 | 2358 2359 171 309 | 9802 8154 234 310 | 55 58 25 311 | 5230 5232 4 312 | 4223 4509 213 313 | 8919 3014 61 314 | 5035 5033 61 315 | 9805 221 152 316 | 7444 7440 72 317 | 5303 5305 112 318 | 9839 6474 135 319 | 8802 8801 45 320 | 9758 9757 135 321 | 6402 9049 61 322 | 11078 769 152 323 | 1211 6854 4 324 | 10224 10224 62 325 | 4283 4284 4 326 | 11236 235 112 327 | 3345 3344 64 328 | 87 379 4 329 | 11662 4573 234 330 | 186 187 10 331 | 7286 4440 64 332 | 12156 9787 120 333 | 218 130 204 334 | 10497 1571 152 335 | 7718 7718 245 336 | 123 125 40 337 | 5613 382 62 338 | 6378 125 61 339 | 8818 7199 123 340 | 391 390 3 341 | 1040 134 8 342 | 2195 905 45 343 | 6398 289 61 344 | 4579 4581 156 345 | 2570 645 165 346 | 4065 200 198 347 | 3224 1601 10 348 | 9262 3971 240 349 | 11154 4863 164 350 | 1427 1430 61 351 | 7452 828 216 352 | 64 2627 64 353 | 5808 2543 64 354 | 5027 3259 4 355 | 3938 4701 25 356 | 6017 4376 25 357 | 9027 3855 121 358 | 5634 820 64 359 | 10022 10023 135 360 | 10482 8847 61 361 | 6983 7201 72 362 | 9767 6518 72 363 | 4372 1352 64 364 | 4840 5813 156 365 | 9459 5038 160 366 | 7510 7509 61 367 | 623 624 72 368 | 11539 7940 121 369 | 316 1811 3 370 | 5722 1959 4 371 | 2040 6392 61 372 | 496 145 124 373 | 1446 1445 4 374 | 4845 7445 61 375 | 3346 3345 64 376 | 9235 8520 152 377 | 4214 7572 25 378 | 687 9224 25 379 | 12050 3008 120 380 | 10890 1445 61 381 | 2488 9016 48 382 | 2900 2899 4 383 | 6518 9590 61 384 | 7566 4321 61 385 | 406 280 4 386 | 9990 5165 162 387 | 
-------------------------------------------------------------------------------- /benchmarks/DB15K/1-n.txt: -------------------------------------------------------------------------------- 1 | 195 2 | 9736 7548 63 3 | 5382 3806 113 4 | 7740 6399 74 5 | 6502 6756 100 6 | 8482 8482 184 7 | 5335 3143 74 8 | 6874 5024 63 9 | 7135 9863 179 10 | 7135 9002 74 11 | 5106 9945 100 12 | 11645 11644 100 13 | 2805 174 31 14 | 11339 1628 63 15 | 3056 3295 184 16 | 5669 2486 100 17 | 8813 6350 106 18 | 4784 3875 63 19 | 6899 6900 106 20 | 6702 4226 106 21 | 2972 2967 74 22 | 487 2932 106 23 | 3207 804 144 24 | 2912 10092 106 25 | 12672 3276 113 26 | 1480 632 176 27 | 4445 8410 106 28 | 1658 13 206 29 | 876 874 74 30 | 10094 10093 63 31 | 6562 9927 63 32 | 8026 2628 184 33 | 8678 8677 74 34 | 11508 1585 106 35 | 4631 2915 100 36 | 9924 205 206 37 | 10131 10132 74 38 | 2561 7135 184 39 | 4631 6940 100 40 | 7904 2929 148 41 | 3404 9024 148 42 | 11084 5285 106 43 | 8610 6954 179 44 | 4117 6221 119 45 | 2975 2969 74 46 | 8118 627 74 47 | 10787 10783 63 48 | 7829 7001 63 49 | 6982 6976 179 50 | 4539 4483 106 51 | 2968 4381 113 52 | 492 2644 144 53 | 12237 9632 106 54 | 4409 9329 100 55 | 5629 5627 63 56 | 8946 467 31 57 | 6677 6679 106 58 | 9112 3041 74 59 | 11240 7424 179 60 | 7425 3000 220 61 | 1658 2598 206 62 | 7233 1494 100 63 | 513 3025 184 64 | 11951 7356 106 65 | 8245 8249 74 66 | 11179 8914 100 67 | 6263 1221 106 68 | 3843 3887 63 69 | 6970 6975 179 70 | 10649 4968 106 71 | 7882 2229 31 72 | 8610 2430 179 73 | 3156 10769 100 74 | 12035 8250 74 75 | 3121 2978 63 76 | 2246 2248 113 77 | 11880 10034 136 78 | 11478 9173 246 79 | 7455 2998 220 80 | 4619 4634 74 81 | 3741 3742 113 82 | 1669 5604 63 83 | 7500 6399 74 84 | 8012 5605 106 85 | 4258 4260 179 86 | 12140 5344 74 87 | 113 188 139 88 | 11455 1987 206 89 | 4631 8657 100 90 | 1229 1235 139 91 | 4055 4056 74 92 | 4066 4944 74 93 | 2892 1401 31 94 | 10299 10300 100 95 | 8813 5932 106 96 | 8629 4259 179 97 | 8222 8221 184 98 | 
11179 9068 100 99 | 6283 2887 63 100 | 2246 2249 106 101 | 8616 8088 63 102 | 10301 710 106 103 | 5106 9943 100 104 | 7135 871 179 105 | 6181 6182 74 106 | 1304 943 74 107 | 2974 5010 106 108 | 8906 8472 63 109 | 4676 1014 206 110 | 7487 2675 246 111 | 12035 11207 74 112 | 1810 3622 144 113 | 10338 7056 74 114 | 11084 2171 106 115 | 2855 2853 113 116 | 113 2080 139 117 | 2458 2433 179 118 | 3498 99 31 119 | 6982 5171 179 120 | 2917 796 125 121 | 8129 2598 106 122 | 10785 1579 74 123 | 7412 4430 179 124 | 9924 174 206 125 | 7819 2479 63 126 | 8954 8955 63 127 | 12443 6644 113 128 | 7356 8397 106 129 | 12415 3949 139 130 | 10562 6757 179 131 | 2047 3264 106 132 | 10649 3535 106 133 | 9788 9632 106 134 | 10783 9652 136 135 | 2463 3289 113 136 | 6663 6664 136 137 | 113 101 168 138 | 11478 5928 246 139 | 11854 11857 106 140 | 1229 1233 139 141 | 3999 725 113 142 | 12528 8405 113 143 | 7436 7440 179 144 | 8146 8143 100 145 | 10322 7424 179 146 | 9788 7872 106 147 | 7524 919 206 148 | 2279 2280 100 149 | 3735 3895 113 150 | 8322 9880 136 151 | 1308 1954 136 152 | 7730 3082 74 153 | 2281 7842 179 154 | 8493 8494 113 155 | 12479 2909 113 156 | 8669 5066 106 157 | 5370 1595 235 158 | 12237 10162 106 159 | 3054 2803 184 160 | 11924 7528 74 161 | 3449 1800 119 162 | 1317 181 144 163 | 8547 8548 184 164 | 459 2924 100 165 | 7968 5419 63 166 | 5217 3127 148 167 | 9652 5290 74 168 | 12528 2058 113 169 | 1626 1938 106 170 | 7305 2446 74 171 | 9954 1352 206 172 | 9193 7178 106 173 | 5167 5171 179 174 | 4010 4012 106 175 | 7155 3054 148 176 | 1658 2351 206 177 | 6669 6671 100 178 | 8954 8956 63 179 | 11825 11135 74 180 | 1451 1452 74 181 | 4631 8658 100 182 | 7732 3081 100 183 | 6564 6569 74 184 | 8662 74 206 185 | 2972 2974 74 186 | 799 3129 106 187 | 8610 312 179 188 | 1658 933 206 189 | 8291 303 63 190 | 7209 9990 184 191 | 7250 7251 74 192 | 7359 7358 113 193 | 4316 2861 106 194 | 3084 1837 31 195 | 2281 7841 179 196 | 3156 6826 100 197 | 
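The benchmark files above (`train2id.txt`, `1-1.txt`, `1-n.txt`, and the other split files) follow the OpenKE convention used throughout this repo: the first line gives the number of triples, and each following line is a `head tail relation` triple of integer IDs. A minimal parsing sketch, assuming that layout (the helper name is illustrative, not part of the repo):

```python
def read_triples(path):
    """Parse an OpenKE-style triple file: the first line is the triple
    count, and each following line is 'head tail relation' integer IDs."""
    triples = []
    with open(path) as f:
        total = int(f.readline())
        for _ in range(total):
            h, t, r = f.readline().strip().split()
            triples.append((int(h), int(t), int(r)))
    return triples
```

The `n-n.py` script below reads the same layout to classify each test relation as 1-1, 1-n, n-1, or n-n based on its average head/tail degree.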
-------------------------------------------------------------------------------- /benchmarks/DB15K/n-n.py: -------------------------------------------------------------------------------- 1 | lef = {} 2 | rig = {} 3 | rellef = {} 4 | relrig = {} 5 | 6 | triple = open("train2id.txt", "r") 7 | valid = open("valid2id.txt", "r") 8 | test = open("test2id.txt", "r") 9 | 10 | tot = (int)(triple.readline()) 11 | for i in range(tot): 12 | content = triple.readline() 13 | h,t,r = content.strip().split() 14 | if not (h,r) in lef: 15 | lef[(h,r)] = [] 16 | if not (r,t) in rig: 17 | rig[(r,t)] = [] 18 | lef[(h,r)].append(t) 19 | rig[(r,t)].append(h) 20 | if not r in rellef: 21 | rellef[r] = {} 22 | if not r in relrig: 23 | relrig[r] = {} 24 | rellef[r][h] = 1 25 | relrig[r][t] = 1 26 | 27 | tot = (int)(valid.readline()) 28 | for i in range(tot): 29 | content = valid.readline() 30 | h,t,r = content.strip().split() 31 | if not (h,r) in lef: 32 | lef[(h,r)] = [] 33 | if not (r,t) in rig: 34 | rig[(r,t)] = [] 35 | lef[(h,r)].append(t) 36 | rig[(r,t)].append(h) 37 | if not r in rellef: 38 | rellef[r] = {} 39 | if not r in relrig: 40 | relrig[r] = {} 41 | rellef[r][h] = 1 42 | relrig[r][t] = 1 43 | 44 | tot = (int)(test.readline()) 45 | for i in range(tot): 46 | content = test.readline() 47 | h,t,r = content.strip().split() 48 | if not (h,r) in lef: 49 | lef[(h,r)] = [] 50 | if not (r,t) in rig: 51 | rig[(r,t)] = [] 52 | lef[(h,r)].append(t) 53 | rig[(r,t)].append(h) 54 | if not r in rellef: 55 | rellef[r] = {} 56 | if not r in relrig: 57 | relrig[r] = {} 58 | rellef[r][h] = 1 59 | relrig[r][t] = 1 60 | 61 | test.close() 62 | valid.close() 63 | triple.close() 64 | 65 | f = open("type_constrain.txt", "w") 66 | f.write("%d\n"%(len(rellef))) 67 | for i in rellef: 68 | f.write("%s\t%d"%(i,len(rellef[i]))) 69 | for j in rellef[i]: 70 | f.write("\t%s"%(j)) 71 | f.write("\n") 72 | f.write("%s\t%d"%(i,len(relrig[i]))) 73 | for j in relrig[i]: 74 | f.write("\t%s"%(j)) 75 | f.write("\n") 76 | 
f.close() 77 | 78 | rellef = {} 79 | totlef = {} 80 | relrig = {} 81 | totrig = {} 82 | # lef: (h, r) 83 | # rig: (r, t) 84 | for i in lef: 85 | if not i[1] in rellef: 86 | rellef[i[1]] = 0 87 | totlef[i[1]] = 0 88 | rellef[i[1]] += len(lef[i]) 89 | totlef[i[1]] += 1.0 90 | 91 | for i in rig: 92 | if not i[0] in relrig: 93 | relrig[i[0]] = 0 94 | totrig[i[0]] = 0 95 | relrig[i[0]] += len(rig[i]) 96 | totrig[i[0]] += 1.0 97 | 98 | s11=0 99 | s1n=0 100 | sn1=0 101 | snn=0 102 | f = open("test2id.txt", "r") 103 | tot = (int)(f.readline()) 104 | for i in range(tot): 105 | content = f.readline() 106 | h,t,r = content.strip().split() 107 | rign = rellef[r] / totlef[r] 108 | lefn = relrig[r] / totrig[r] 109 | if (rign < 1.5 and lefn < 1.5): 110 | s11+=1 111 | if (rign >= 1.5 and lefn < 1.5): 112 | s1n+=1 113 | if (rign < 1.5 and lefn >= 1.5): 114 | sn1+=1 115 | if (rign >= 1.5 and lefn >= 1.5): 116 | snn+=1 117 | f.close() 118 | 119 | 120 | f = open("test2id.txt", "r") 121 | f11 = open("1-1.txt", "w") 122 | f1n = open("1-n.txt", "w") 123 | fn1 = open("n-1.txt", "w") 124 | fnn = open("n-n.txt", "w") 125 | fall = open("test2id_all.txt", "w") 126 | tot = (int)(f.readline()) 127 | fall.write("%d\n"%(tot)) 128 | f11.write("%d\n"%(s11)) 129 | f1n.write("%d\n"%(s1n)) 130 | fn1.write("%d\n"%(sn1)) 131 | fnn.write("%d\n"%(snn)) 132 | for i in range(tot): 133 | content = f.readline() 134 | h,t,r = content.strip().split() 135 | rign = rellef[r] / totlef[r] 136 | lefn = relrig[r] / totrig[r] 137 | if (rign < 1.5 and lefn < 1.5): 138 | f11.write(content) 139 | fall.write("0"+"\t"+content) 140 | if (rign >= 1.5 and lefn < 1.5): 141 | f1n.write(content) 142 | fall.write("1"+"\t"+content) 143 | if (rign < 1.5 and lefn >= 1.5): 144 | fn1.write(content) 145 | fall.write("2"+"\t"+content) 146 | if (rign >= 1.5 and lefn >= 1.5): 147 | fnn.write(content) 148 | fall.write("3"+"\t"+content) 149 | fall.close() 150 | f.close() 151 | f11.close() 152 | f1n.close() 153 | fn1.close() 154 | 
fnn.close() 155 | -------------------------------------------------------------------------------- /benchmarks/MKG-W/1-1.txt: -------------------------------------------------------------------------------- 1 | 247 2 | 132 131 28 3 | 2611 2612 126 4 | 4112 2353 114 5 | 12924 10989 8 6 | 8873 13995 75 7 | 7401 9613 14 8 | 6078 10917 84 9 | 9254 2296 26 10 | 9939 6224 8 11 | 10171 8602 14 12 | 2861 4541 14 13 | 1935 4144 8 14 | 9966 9965 14 15 | 9923 5352 26 16 | 11908 4315 8 17 | 12490 6081 8 18 | 6928 6927 14 19 | 9758 8104 14 20 | 14227 14226 8 21 | 11000 408 14 22 | 10498 12468 90 23 | 7631 1271 14 24 | 12684 6084 8 25 | 11186 10338 14 26 | 9144 14142 90 27 | 4149 7661 14 28 | 7055 2866 8 29 | 7758 1217 77 30 | 6737 1622 84 31 | 10738 7001 8 32 | 12611 12610 14 33 | 2287 11093 76 34 | 12463 12126 39 35 | 11779 11838 14 36 | 4250 4032 90 37 | 3771 2698 17 38 | 9924 1381 85 39 | 5418 6734 14 40 | 5895 9852 8 41 | 10524 7168 14 42 | 3913 6214 84 43 | 8046 11409 14 44 | 11021 1112 8 45 | 8280 6348 14 46 | 14830 14831 8 47 | 4948 473 14 48 | 12688 1779 8 49 | 10909 4791 8 50 | 10240 7968 8 51 | 14481 11717 14 52 | 2578 11047 84 53 | 960 959 8 54 | 8372 957 24 55 | 13686 13685 14 56 | 13373 13372 8 57 | 11945 1629 14 58 | 1577 9898 90 59 | 13155 9400 8 60 | 9596 9595 8 61 | 5264 5263 8 62 | 12331 12330 14 63 | 850 1946 90 64 | 4943 13587 8 65 | 5816 5815 14 66 | 11160 11159 8 67 | 4222 3273 114 68 | 10663 1022 52 69 | 10696 7910 8 70 | 3822 5441 52 71 | 971 247 74 72 | 2464 1356 77 73 | 11034 11033 14 74 | 5075 6511 8 75 | 12609 12608 14 76 | 5368 2650 77 77 | 12555 12554 8 78 | 3132 8933 8 79 | 14583 14866 14 80 | 10622 1080 82 81 | 2077 12446 14 82 | 9127 9126 66 83 | 12031 12030 14 84 | 2649 3416 8 85 | 11071 4008 17 86 | 14377 14376 14 87 | 12314 7816 14 88 | 5236 5235 8 89 | 14879 1194 126 90 | 7332 9542 14 91 | 270 4229 26 92 | 7551 7655 14 93 | 1083 1082 14 94 | 13398 9475 14 95 | 538 1683 77 96 | 12626 12646 14 97 | 14012 3490 85 98 | 2428 2358 14 99 | 271 9026 
26 100 | 10326 10325 66 101 | 14866 14583 8 102 | 5818 12713 14 103 | 1427 14175 14 104 | 123 12471 17 105 | 5460 6401 39 106 | 11964 4926 26 107 | 13397 1872 8 108 | 12059 12058 8 109 | 3553 3315 8 110 | 8803 11183 76 111 | 9722 13695 39 112 | 2912 4706 39 113 | 7570 7569 8 114 | 9756 9755 14 115 | 4541 2861 8 116 | 14897 14898 8 117 | 2081 2692 8 118 | 11258 11257 66 119 | 7618 1356 87 120 | 14343 14342 14 121 | 5613 5612 14 122 | 6340 7835 8 123 | 13046 13045 14 124 | 13852 12392 14 125 | 14195 9167 143 126 | 9674 8879 8 127 | 19 88 84 128 | 11868 11867 8 129 | 11264 11263 8 130 | 8138 14375 8 131 | 3269 19 52 132 | 11226 11225 8 133 | 12043 12042 8 134 | 11848 11847 14 135 | 1569 1568 8 136 | 4137 11371 14 137 | 7312 7875 8 138 | 3898 8870 14 139 | 14831 14830 14 140 | 7701 605 17 141 | 11192 13914 8 142 | 14651 7275 14 143 | 12817 14117 42 144 | 11224 2848 8 145 | 125 124 84 146 | 2898 2897 14 147 | 12570 13643 14 148 | 6043 6042 8 149 | 14418 14417 8 150 | 5011 11940 8 151 | 14790 14940 8 152 | 3983 6296 8 153 | 12849 13130 8 154 | 14940 14790 14 155 | 12727 11579 8 156 | 8764 11122 66 157 | 14949 14950 14 158 | 2464 6591 76 159 | 12235 1966 14 160 | 3107 8756 14 161 | 7151 14476 14 162 | 14656 13035 8 163 | 12274 13642 8 164 | 7120 7119 8 165 | 9000 11843 14 166 | 6084 12684 24 167 | 1059 1058 74 168 | 8445 8444 14 169 | 11843 9000 8 170 | 3634 4434 14 171 | 14553 528 109 172 | 3206 127 8 173 | 10379 8206 8 174 | 12570 6166 8 175 | 11371 4137 8 176 | 14056 3079 14 177 | 14898 14897 14 178 | 11833 14229 42 179 | 13544 7720 14 180 | 12600 10988 8 181 | 7843 12852 14 182 | 4802 13562 14 183 | 1112 11021 14 184 | 8606 7036 14 185 | 11760 11759 8 186 | 1252 1251 14 187 | 2681 2680 14 188 | 14727 14726 14 189 | 2400 5807 14 190 | 9256 9255 14 191 | 2149 14605 8 192 | 6714 6713 8 193 | 9158 3391 119 194 | 13881 13880 14 195 | 9668 10401 82 196 | 13914 11192 14 197 | 7995 3323 14 198 | 1463 6914 14 199 | 9475 13398 8 200 | 9333 9166 8 201 | 2580 3506 14 202 | 12647 
11873 14 203 | 8755 1587 14 204 | 10671 10670 42 205 | 5982 5981 8 206 | 7862 1062 14 207 | 11043 11655 8 208 | 762 2664 39 209 | 4441 9147 8 210 | 233 232 14 211 | 11800 11799 8 212 | 9211 7438 14 213 | 13445 13444 14 214 | 3079 14056 8 215 | 7605 2906 39 216 | 11714 11266 14 217 | 14559 14243 8 218 | 2695 377 77 219 | 10197 10196 39 220 | 1136 14606 76 221 | 14950 14949 8 222 | 8308 2701 77 223 | 2024 1023 14 224 | 8994 10843 14 225 | 14243 14559 14 226 | 10500 10839 8 227 | 2105 1056 74 228 | 14380 14379 8 229 | 8113 7265 17 230 | 9495 11518 14 231 | 5037 10765 14 232 | 13146 9245 8 233 | 3315 3553 14 234 | 3457 7147 39 235 | 3697 10290 76 236 | 12672 12671 14 237 | 9130 6966 24 238 | 9751 9750 14 239 | 10759 10758 14 240 | 3206 13529 14 241 | 12552 10142 8 242 | 14747 14746 14 243 | 6966 9130 8 244 | 13994 13993 8 245 | 14335 14334 14 246 | 7772 14484 8 247 | 7728 17 77 248 | 13820 7903 14 249 | -------------------------------------------------------------------------------- /benchmarks/MKG-W/1-n.txt: -------------------------------------------------------------------------------- 1 | 18 2 | 8471 962 62 3 | 1238 9964 62 4 | 6401 4795 23 5 | 1466 1467 23 6 | 10454 8038 62 7 | 2092 918 81 8 | 11794 13327 23 9 | 3837 829 102 10 | 7147 13462 23 11 | 2232 1094 62 12 | 3837 2596 102 13 | 141 10917 23 14 | 4977 5074 62 15 | 141 1120 23 16 | 4165 2617 62 17 | 4706 10695 23 18 | 141 6599 23 19 | 13963 9258 101 20 | -------------------------------------------------------------------------------- /benchmarks/MKG-W/n-n.py: -------------------------------------------------------------------------------- 1 | lef = {} 2 | rig = {} 3 | rellef = {} 4 | relrig = {} 5 | 6 | triple = open("train2id.txt", "r") 7 | valid = open("valid2id.txt", "r") 8 | test = open("test2id.txt", "r") 9 | 10 | tot = (int)(triple.readline()) 11 | for i in range(tot): 12 | content = triple.readline() 13 | h,t,r = content.strip().split() 14 | if not (h,r) in lef: 15 | lef[(h,r)] = [] 16 | if not (r,t) 
in rig: 17 | rig[(r,t)] = [] 18 | lef[(h,r)].append(t) 19 | rig[(r,t)].append(h) 20 | if not r in rellef: 21 | rellef[r] = {} 22 | if not r in relrig: 23 | relrig[r] = {} 24 | rellef[r][h] = 1 25 | relrig[r][t] = 1 26 | 27 | tot = (int)(valid.readline()) 28 | for i in range(tot): 29 | content = valid.readline() 30 | h,t,r = content.strip().split() 31 | if not (h,r) in lef: 32 | lef[(h,r)] = [] 33 | if not (r,t) in rig: 34 | rig[(r,t)] = [] 35 | lef[(h,r)].append(t) 36 | rig[(r,t)].append(h) 37 | if not r in rellef: 38 | rellef[r] = {} 39 | if not r in relrig: 40 | relrig[r] = {} 41 | rellef[r][h] = 1 42 | relrig[r][t] = 1 43 | 44 | tot = (int)(test.readline()) 45 | for i in range(tot): 46 | content = test.readline() 47 | h,t,r = content.strip().split() 48 | if not (h,r) in lef: 49 | lef[(h,r)] = [] 50 | if not (r,t) in rig: 51 | rig[(r,t)] = [] 52 | lef[(h,r)].append(t) 53 | rig[(r,t)].append(h) 54 | if not r in rellef: 55 | rellef[r] = {} 56 | if not r in relrig: 57 | relrig[r] = {} 58 | rellef[r][h] = 1 59 | relrig[r][t] = 1 60 | 61 | test.close() 62 | valid.close() 63 | triple.close() 64 | 65 | f = open("type_constrain.txt", "w") 66 | f.write("%d\n"%(len(rellef))) 67 | for i in rellef: 68 | f.write("%s\t%d"%(i,len(rellef[i]))) 69 | for j in rellef[i]: 70 | f.write("\t%s"%(j)) 71 | f.write("\n") 72 | f.write("%s\t%d"%(i,len(relrig[i]))) 73 | for j in relrig[i]: 74 | f.write("\t%s"%(j)) 75 | f.write("\n") 76 | f.close() 77 | 78 | rellef = {} 79 | totlef = {} 80 | relrig = {} 81 | totrig = {} 82 | # lef: (h, r) 83 | # rig: (r, t) 84 | for i in lef: 85 | if not i[1] in rellef: 86 | rellef[i[1]] = 0 87 | totlef[i[1]] = 0 88 | rellef[i[1]] += len(lef[i]) 89 | totlef[i[1]] += 1.0 90 | 91 | for i in rig: 92 | if not i[0] in relrig: 93 | relrig[i[0]] = 0 94 | totrig[i[0]] = 0 95 | relrig[i[0]] += len(rig[i]) 96 | totrig[i[0]] += 1.0 97 | 98 | s11=0 99 | s1n=0 100 | sn1=0 101 | snn=0 102 | f = open("test2id.txt", "r") 103 | tot = (int)(f.readline()) 104 | for i in 
range(tot): 105 | content = f.readline() 106 | h,t,r = content.strip().split() 107 | rign = rellef[r] / totlef[r] 108 | lefn = relrig[r] / totrig[r] 109 | if (rign < 1.5 and lefn < 1.5): 110 | s11+=1 111 | if (rign >= 1.5 and lefn < 1.5): 112 | s1n+=1 113 | if (rign < 1.5 and lefn >= 1.5): 114 | sn1+=1 115 | if (rign >= 1.5 and lefn >= 1.5): 116 | snn+=1 117 | f.close() 118 | 119 | 120 | f = open("test2id.txt", "r") 121 | f11 = open("1-1.txt", "w") 122 | f1n = open("1-n.txt", "w") 123 | fn1 = open("n-1.txt", "w") 124 | fnn = open("n-n.txt", "w") 125 | fall = open("test2id_all.txt", "w") 126 | tot = (int)(f.readline()) 127 | fall.write("%d\n"%(tot)) 128 | f11.write("%d\n"%(s11)) 129 | f1n.write("%d\n"%(s1n)) 130 | fn1.write("%d\n"%(sn1)) 131 | fnn.write("%d\n"%(snn)) 132 | for i in range(tot): 133 | content = f.readline() 134 | h,t,r = content.strip().split() 135 | rign = rellef[r] / totlef[r] 136 | lefn = relrig[r] / totrig[r] 137 | if (rign < 1.5 and lefn < 1.5): 138 | f11.write(content) 139 | fall.write("0"+"\t"+content) 140 | if (rign >= 1.5 and lefn < 1.5): 141 | f1n.write(content) 142 | fall.write("1"+"\t"+content) 143 | if (rign < 1.5 and lefn >= 1.5): 144 | fn1.write(content) 145 | fall.write("2"+"\t"+content) 146 | if (rign >= 1.5 and lefn >= 1.5): 147 | fnn.write(content) 148 | fall.write("3"+"\t"+content) 149 | fall.close() 150 | f.close() 151 | f11.close() 152 | f1n.close() 153 | fn1.close() 154 | fnn.close() 155 | -------------------------------------------------------------------------------- /benchmarks/MKG-W/relation2id.txt: -------------------------------------------------------------------------------- 1 | 169 2 | http://www.wikidata.org/entity/P161 0 3 | http://www.wikidata.org/entity/P58 1 4 | http://www.wikidata.org/entity/P106 2 5 | http://www.wikidata.org/entity/P175 3 6 | http://www.wikidata.org/entity/P840 4 7 | http://www.wikidata.org/entity/P162 5 8 | http://www.wikidata.org/entity/P20 6 9 | http://www.wikidata.org/entity/P344 7 10 | 
http://www.wikidata.org/entity/P156 8 11 | http://www.wikidata.org/entity/P17 9 12 | http://www.wikidata.org/entity/P750 10 13 | http://www.wikidata.org/entity/P86 11 14 | http://www.wikidata.org/entity/P1303 12 15 | http://www.wikidata.org/entity/P31 13 16 | http://www.wikidata.org/entity/P155 14 17 | http://www.wikidata.org/entity/P364 15 18 | http://www.wikidata.org/entity/P264 16 19 | http://www.wikidata.org/entity/P279 17 20 | http://www.wikidata.org/entity/P54 18 21 | http://www.wikidata.org/entity/P449 19 22 | http://www.wikidata.org/entity/P30 20 23 | http://www.wikidata.org/entity/P495 21 24 | http://www.wikidata.org/entity/P57 22 25 | http://www.wikidata.org/entity/P150 23 26 | http://www.wikidata.org/entity/P22 24 27 | http://www.wikidata.org/entity/P140 25 28 | http://www.wikidata.org/entity/P36 26 29 | http://www.wikidata.org/entity/P921 27 30 | http://www.wikidata.org/entity/P7 28 31 | http://www.wikidata.org/entity/P190 29 32 | http://www.wikidata.org/entity/P19 30 33 | http://www.wikidata.org/entity/P69 31 34 | http://www.wikidata.org/entity/P136 32 35 | http://www.wikidata.org/entity/P102 33 36 | http://www.wikidata.org/entity/P27 34 37 | http://www.wikidata.org/entity/P172 35 38 | http://www.wikidata.org/entity/P135 36 39 | http://www.wikidata.org/entity/P272 37 40 | http://www.wikidata.org/entity/P361 38 41 | http://www.wikidata.org/entity/P47 39 42 | http://www.wikidata.org/entity/P108 40 43 | http://www.wikidata.org/entity/P937 41 44 | http://www.wikidata.org/entity/P197 42 45 | http://www.wikidata.org/entity/P126 43 46 | http://www.wikidata.org/entity/P509 44 47 | http://www.wikidata.org/entity/P437 45 48 | http://www.wikidata.org/entity/P50 46 49 | http://www.wikidata.org/entity/P674 47 50 | http://www.wikidata.org/entity/P131 48 51 | http://www.wikidata.org/entity/P413 49 52 | http://www.wikidata.org/entity/P179 50 53 | http://www.wikidata.org/entity/P1142 51 54 | http://www.wikidata.org/entity/P551 52 55 | 
http://www.wikidata.org/entity/P166 53 56 | http://www.wikidata.org/entity/P118 54 57 | http://www.wikidata.org/entity/P171 55 58 | http://www.wikidata.org/entity/P1040 56 59 | http://www.wikidata.org/entity/P607 57 60 | http://www.wikidata.org/entity/P123 58 61 | http://www.wikidata.org/entity/P647 59 62 | http://www.wikidata.org/entity/P2348 60 63 | http://www.wikidata.org/entity/P400 61 64 | http://www.wikidata.org/entity/P737 62 65 | http://www.wikidata.org/entity/P749 63 66 | http://www.wikidata.org/entity/P676 64 67 | http://www.wikidata.org/entity/P127 65 68 | http://www.wikidata.org/entity/P40 66 69 | http://www.wikidata.org/entity/P159 67 70 | http://www.wikidata.org/entity/P206 68 71 | http://www.wikidata.org/entity/P421 69 72 | http://www.wikidata.org/entity/P185 70 73 | http://www.wikidata.org/entity/P178 71 74 | http://www.wikidata.org/entity/P915 72 75 | http://www.wikidata.org/entity/P119 73 76 | http://www.wikidata.org/entity/P355 74 77 | http://www.wikidata.org/entity/P26 75 78 | http://www.wikidata.org/entity/P144 76 79 | http://www.wikidata.org/entity/P1431 77 80 | http://www.wikidata.org/entity/P275 78 81 | http://www.wikidata.org/entity/P1412 79 82 | http://www.wikidata.org/entity/P463 80 83 | http://www.wikidata.org/entity/P530 81 84 | http://www.wikidata.org/entity/P800 82 85 | http://www.wikidata.org/entity/P2438 83 86 | http://www.wikidata.org/entity/P1376 84 87 | http://www.wikidata.org/entity/P287 85 88 | http://www.wikidata.org/entity/P452 86 89 | http://www.wikidata.org/entity/P112 87 90 | http://www.wikidata.org/entity/P469 88 91 | http://www.wikidata.org/entity/P170 89 92 | http://www.wikidata.org/entity/P527 90 93 | http://www.wikidata.org/entity/P25 91 94 | http://www.wikidata.org/entity/P37 92 95 | http://www.wikidata.org/entity/P241 93 96 | http://www.wikidata.org/entity/P611 94 97 | http://www.wikidata.org/entity/P641 95 98 | http://www.wikidata.org/entity/P1336 96 99 | http://www.wikidata.org/entity/P2541 97 100 | 
http://www.wikidata.org/entity/P276 98 101 | http://www.wikidata.org/entity/P1344 99 102 | http://www.wikidata.org/entity/P84 100 103 | http://www.wikidata.org/entity/P1441 101 104 | http://www.wikidata.org/entity/P2512 102 105 | http://www.wikidata.org/entity/P408 103 106 | http://www.wikidata.org/entity/P137 104 107 | http://www.wikidata.org/entity/P149 105 108 | http://www.wikidata.org/entity/P176 106 109 | http://www.wikidata.org/entity/P138 107 110 | http://www.wikidata.org/entity/P740 108 111 | http://www.wikidata.org/entity/P101 109 112 | http://www.wikidata.org/entity/P2554 110 113 | http://www.wikidata.org/entity/P706 111 114 | http://www.wikidata.org/entity/P103 112 115 | http://www.wikidata.org/entity/P1387 113 116 | http://www.wikidata.org/entity/P451 114 117 | http://www.wikidata.org/entity/P9 115 118 | http://www.wikidata.org/entity/P710 116 119 | http://www.wikidata.org/entity/P767 117 120 | http://www.wikidata.org/entity/P945 118 121 | http://www.wikidata.org/entity/P184 119 122 | http://www.wikidata.org/entity/P1056 120 123 | http://www.wikidata.org/entity/P1408 121 124 | http://www.wikidata.org/entity/P735 122 125 | http://www.wikidata.org/entity/P87 123 126 | http://www.wikidata.org/entity/P277 124 127 | http://www.wikidata.org/entity/P1716 125 128 | http://www.wikidata.org/entity/P403 126 129 | http://www.wikidata.org/entity/P412 127 130 | http://www.wikidata.org/entity/P1366 128 131 | http://www.wikidata.org/entity/P286 129 132 | http://www.wikidata.org/entity/P306 130 133 | http://www.wikidata.org/entity/P1811 131 134 | http://www.wikidata.org/entity/P1365 132 135 | http://www.wikidata.org/entity/P466 133 136 | http://www.wikidata.org/entity/P180 134 137 | http://www.wikidata.org/entity/P485 135 138 | http://www.wikidata.org/entity/P425 136 139 | http://www.wikidata.org/entity/P1071 137 140 | http://www.wikidata.org/entity/P407 138 141 | http://www.wikidata.org/entity/P61 139 142 | http://www.wikidata.org/entity/P195 140 143 | 
http://www.wikidata.org/entity/P1383 141 144 | http://www.wikidata.org/entity/P501 142 145 | http://www.wikidata.org/entity/P504 143 146 | http://www.wikidata.org/entity/P462 144 147 | http://www.wikidata.org/entity/P941 145 148 | http://www.wikidata.org/entity/P410 146 149 | http://www.wikidata.org/entity/P81 147 150 | http://www.wikidata.org/entity/P658 148 151 | http://www.wikidata.org/entity/P1889 149 152 | http://www.wikidata.org/entity/P406 150 153 | http://www.wikidata.org/entity/P201 151 154 | http://www.wikidata.org/entity/P2936 152 155 | http://www.wikidata.org/entity/P461 153 156 | http://www.wikidata.org/entity/P205 154 157 | http://www.wikidata.org/entity/P991 155 158 | http://www.wikidata.org/entity/P725 156 159 | http://www.wikidata.org/entity/P289 157 160 | http://www.wikidata.org/entity/P726 158 161 | http://www.wikidata.org/entity/P488 159 162 | http://www.wikidata.org/entity/P1532 160 163 | http://www.wikidata.org/entity/P122 161 164 | http://www.wikidata.org/entity/P1018 162 165 | http://www.wikidata.org/entity/P134 163 166 | http://www.wikidata.org/entity/P802 164 167 | http://www.wikidata.org/entity/P1066 165 168 | http://www.wikidata.org/entity/P200 166 169 | http://www.wikidata.org/entity/P263 167 170 | http://www.wikidata.org/entity/P1427 168 171 | -------------------------------------------------------------------------------- /benchmarks/MKG-Y/1-1.txt: -------------------------------------------------------------------------------- 1 | 90 2 | 3528 3397 8 3 | 12356 2851 13 4 | 4690 14759 15 5 | 630 11244 23 6 | 13610 1917 15 7 | 1988 13294 13 8 | 8435 1515 8 9 | 8679 3200 24 10 | 13050 13049 8 11 | 12133 12134 15 12 | 3736 13694 8 13 | 65 3003 8 14 | 14789 2115 15 15 | 11384 9938 11 16 | 12318 12317 8 17 | 11787 14805 8 18 | 10251 10250 8 19 | 11811 8610 20 20 | 14813 8187 16 21 | 6296 8148 11 22 | 14820 2730 15 23 | 6959 14827 23 24 | 772 13933 8 25 | 12515 12514 8 26 | 8971 2474 8 27 | 10926 9178 8 28 | 12481 4118 8 29 | 5348 13272 24 30 
| 5755 11338 8 31 | 3906 3905 8 32 | 2478 5304 8 33 | 14853 14209 15 34 | 12983 4341 13 35 | 12382 12381 8 36 | 14729 14730 16 37 | 2700 6748 23 38 | 7084 2877 8 39 | 6641 3534 8 40 | 813 812 8 41 | 3619 3618 8 42 | 2353 13256 23 43 | 6177 10046 23 44 | 6795 4196 8 45 | 10994 10993 8 46 | 278 9081 11 47 | 14894 14895 20 48 | 11573 11572 8 49 | 2017 14901 16 50 | 3587 3586 8 51 | 2765 2539 8 52 | 7117 1090 15 53 | 11983 10835 15 54 | 7124 758 24 55 | 14917 5552 13 56 | 9432 8247 16 57 | 14787 14788 15 58 | 6301 6300 8 59 | 9089 2505 15 60 | 14927 14928 20 61 | 11626 4341 15 62 | 1912 1911 8 63 | 618 6296 8 64 | 6413 6414 15 65 | 6323 12424 8 66 | 14357 14356 8 67 | 12591 12592 15 68 | 14805 11787 8 69 | 268 3314 15 70 | 5469 5468 8 71 | 13341 13342 15 72 | 3916 6178 23 73 | 1409 11266 8 74 | 8679 3200 13 75 | 2785 11760 11 76 | 14695 14694 8 77 | 2780 6146 14 78 | 10723 2123 8 79 | 13202 13201 8 80 | 8092 1710 8 81 | 13282 628 8 82 | 7245 1070 23 83 | 2535 2323 8 84 | 4013 11244 23 85 | 11338 5755 8 86 | 3401 12546 23 87 | 14149 14148 8 88 | 4730 10181 13 89 | 10222 585 15 90 | 8147 8967 8 91 | 2146 11123 8 92 | -------------------------------------------------------------------------------- /benchmarks/MKG-Y/n-1.txt: -------------------------------------------------------------------------------- 1 | 246 2 | 5553 2971 22 3 | 4846 7769 5 4 | 14760 7480 5 5 | 14656 4334 0 6 | 12639 5704 0 7 | 12853 68 10 8 | 11177 1585 17 9 | 14764 11162 0 10 | 11148 758 10 11 | 11786 1938 5 12 | 1602 8190 0 13 | 4779 2900 5 14 | 14601 8938 7 15 | 10774 68 10 16 | 13327 13328 0 17 | 1940 4780 5 18 | 14552 2730 5 19 | 3042 1917 0 20 | 2054 1917 0 21 | 13717 387 0 22 | 6722 6701 5 23 | 743 14782 5 24 | 14783 2094 5 25 | 9222 11431 0 26 | 12958 4341 0 27 | 3246 3200 0 28 | 14784 1585 17 29 | 14785 1251 5 30 | 4593 21 7 31 | 11662 241 0 32 | 14787 14788 0 33 | 7557 68 10 34 | 4598 4341 0 35 | 13733 5692 5 36 | 14731 4565 5 37 | 5983 9817 0 38 | 14524 11815 5 39 | 1893 21 7 40 | 11322 
14792 5 41 | 14793 7812 17 42 | 10805 10639 0 43 | 14799 9653 5 44 | 6677 21 7 45 | 11566 6753 0 46 | 3384 241 10 47 | 1436 14806 0 48 | 8269 8730 5 49 | 6291 14809 0 50 | 4556 2302 0 51 | 9205 4780 5 52 | 14675 1938 5 53 | 11768 13226 0 54 | 521 6376 10 55 | 12137 2780 10 56 | 6688 971 10 57 | 7460 92 0 58 | 12052 12736 5 59 | 14819 7480 5 60 | 11318 4971 5 61 | 11196 4462 7 62 | 14617 14616 5 63 | 14824 14545 5 64 | 10508 2393 0 65 | 4800 450 5 66 | 10328 2900 0 67 | 1632 971 10 68 | 14824 1906 5 69 | 9850 9817 0 70 | 8922 1419 0 71 | 11210 3659 5 72 | 6947 23 5 73 | 7605 241 10 74 | 12849 3552 5 75 | 5680 1055 0 76 | 1052 7446 0 77 | 2946 883 0 78 | 6045 1917 0 79 | 8844 5146 5 80 | 14844 14544 5 81 | 6426 971 10 82 | 11713 11714 0 83 | 14588 5971 5 84 | 10041 4929 0 85 | 14285 2851 5 86 | 7584 998 0 87 | 14369 14370 5 88 | 4564 12761 5 89 | 8004 13534 0 90 | 14849 13470 5 91 | 6657 12290 5 92 | 12313 3073 0 93 | 4188 10919 5 94 | 6343 6755 17 95 | 7159 7160 0 96 | 9281 4125 0 97 | 2701 6755 17 98 | 9367 971 10 99 | 10478 68 10 100 | 2194 971 10 101 | 11222 5877 5 102 | 13113 1938 5 103 | 13786 68 10 104 | 6357 21 7 105 | 14862 895 5 106 | 14865 5704 0 107 | 3689 349 10 108 | 3022 68 10 109 | 14871 2971 17 110 | 1214 1215 0 111 | 11232 3073 0 112 | 13262 2986 5 113 | 8180 1906 5 114 | 13012 241 5 115 | 8323 2780 5 116 | 11557 8370 22 117 | 2119 971 10 118 | 12329 68 10 119 | 11190 7904 5 120 | 14879 4718 0 121 | 11566 6376 10 122 | 14882 14474 17 123 | 7483 13551 0 124 | 13241 4505 5 125 | 4525 2819 5 126 | 7705 7991 5 127 | 9144 4551 17 128 | 1193 6376 5 129 | 14888 14889 5 130 | 14871 12394 0 131 | 5270 68 10 132 | 14891 14892 5 133 | 3934 961 0 134 | 14809 6497 5 135 | 14134 6863 5 136 | 12516 2382 5 137 | 14898 790 5 138 | 13994 7416 0 139 | 14899 4551 17 140 | 4943 14009 5 141 | 5134 12887 5 142 | 14302 3431 10 143 | 12887 5012 5 144 | 14903 3248 17 145 | 10983 450 5 146 | 8361 8360 5 147 | 7715 5483 5 148 | 13103 1737 0 149 | 14905 13518 5 150 | 12690 
13163 0 151 | 10937 10428 5 152 | 9812 9811 5 153 | 9536 1874 5 154 | 9330 790 5 155 | 6491 6492 0 156 | 14918 7291 7 157 | 303 6080 5 158 | 14919 2900 5 159 | 11369 5770 0 160 | 12822 9454 0 161 | 11545 9514 0 162 | 14923 4850 5 163 | 2292 21 7 164 | 9750 3503 5 165 | 14929 2816 5 166 | 3540 4742 0 167 | 14445 1938 5 168 | 3312 758 5 169 | 14930 1002 0 170 | 8834 68 10 171 | 9669 9668 5 172 | 14931 1585 17 173 | 5214 8894 0 174 | 13259 6391 17 175 | 10501 2439 5 176 | 9675 8938 7 177 | 6997 13454 0 178 | 5927 5926 5 179 | 2873 8227 5 180 | 10775 241 10 181 | 5749 3502 5 182 | 14938 14939 5 183 | 1984 971 10 184 | 8949 9499 22 185 | 4448 9181 5 186 | 11397 6376 0 187 | 12536 5927 5 188 | 11599 3431 10 189 | 14948 387 0 190 | 14950 14951 5 191 | 11227 1629 5 192 | 11409 1002 0 193 | 14923 6701 5 194 | 3903 2016 17 195 | 11543 4125 0 196 | 14958 971 10 197 | 1816 241 5 198 | 2116 971 10 199 | 7811 7812 22 200 | 13644 14969 5 201 | 11793 1251 5 202 | 11148 758 19 203 | 11803 790 5 204 | 12030 55 19 205 | 2334 998 0 206 | 13468 13470 5 207 | 7438 11106 5 208 | 1605 1604 5 209 | 14587 5971 5 210 | 6346 10491 0 211 | 1612 3878 17 212 | 7044 241 10 213 | 107 4553 0 214 | 1039 971 10 215 | 13426 10607 7 216 | 11227 5438 5 217 | 13337 13272 0 218 | 1782 5706 5 219 | 11015 11014 5 220 | 1815 241 10 221 | 8350 2780 10 222 | 12150 1476 0 223 | 14974 14806 0 224 | 14984 14985 5 225 | 12510 1874 5 226 | 14329 1938 5 227 | 3297 12321 0 228 | 4851 13168 5 229 | 9611 585 0 230 | 11223 12568 5 231 | 14991 7709 0 232 | 14202 8179 5 233 | 3991 3990 5 234 | 14992 8624 5 235 | 2042 21 7 236 | 14423 6009 0 237 | 13743 11793 5 238 | 2968 2780 10 239 | 3899 7436 0 240 | 14995 9594 5 241 | 9909 971 10 242 | 2270 971 10 243 | 5421 13636 0 244 | 9829 750 5 245 | 14552 11893 5 246 | 8502 4504 0 247 | 13595 10882 0 248 | -------------------------------------------------------------------------------- /benchmarks/MKG-Y/n-n.py: 
-------------------------------------------------------------------------------- 1 | lef = {} 2 | rig = {} 3 | rellef = {} 4 | relrig = {} 5 | 6 | triple = open("train2id.txt", "r") 7 | valid = open("valid2id.txt", "r") 8 | test = open("test2id.txt", "r") 9 | 10 | tot = (int)(triple.readline()) 11 | for i in range(tot): 12 | content = triple.readline() 13 | h,t,r = content.strip().split() 14 | if not (h,r) in lef: 15 | lef[(h,r)] = [] 16 | if not (r,t) in rig: 17 | rig[(r,t)] = [] 18 | lef[(h,r)].append(t) 19 | rig[(r,t)].append(h) 20 | if not r in rellef: 21 | rellef[r] = {} 22 | if not r in relrig: 23 | relrig[r] = {} 24 | rellef[r][h] = 1 25 | relrig[r][t] = 1 26 | 27 | tot = (int)(valid.readline()) 28 | for i in range(tot): 29 | content = valid.readline() 30 | h,t,r = content.strip().split() 31 | if not (h,r) in lef: 32 | lef[(h,r)] = [] 33 | if not (r,t) in rig: 34 | rig[(r,t)] = [] 35 | lef[(h,r)].append(t) 36 | rig[(r,t)].append(h) 37 | if not r in rellef: 38 | rellef[r] = {} 39 | if not r in relrig: 40 | relrig[r] = {} 41 | rellef[r][h] = 1 42 | relrig[r][t] = 1 43 | 44 | tot = (int)(test.readline()) 45 | for i in range(tot): 46 | content = test.readline() 47 | h,t,r = content.strip().split() 48 | if not (h,r) in lef: 49 | lef[(h,r)] = [] 50 | if not (r,t) in rig: 51 | rig[(r,t)] = [] 52 | lef[(h,r)].append(t) 53 | rig[(r,t)].append(h) 54 | if not r in rellef: 55 | rellef[r] = {} 56 | if not r in relrig: 57 | relrig[r] = {} 58 | rellef[r][h] = 1 59 | relrig[r][t] = 1 60 | 61 | test.close() 62 | valid.close() 63 | triple.close() 64 | 65 | f = open("type_constrain.txt", "w") 66 | f.write("%d\n"%(len(rellef))) 67 | for i in rellef: 68 | f.write("%s\t%d"%(i,len(rellef[i]))) 69 | for j in rellef[i]: 70 | f.write("\t%s"%(j)) 71 | f.write("\n") 72 | f.write("%s\t%d"%(i,len(relrig[i]))) 73 | for j in relrig[i]: 74 | f.write("\t%s"%(j)) 75 | f.write("\n") 76 | f.close() 77 | 78 | rellef = {} 79 | totlef = {} 80 | relrig = {} 81 | totrig = {} 82 | # lef: (h, r) 83 | 
# rig: (r, t) 84 | for i in lef: 85 | if not i[1] in rellef: 86 | rellef[i[1]] = 0 87 | totlef[i[1]] = 0 88 | rellef[i[1]] += len(lef[i]) 89 | totlef[i[1]] += 1.0 90 | 91 | for i in rig: 92 | if not i[0] in relrig: 93 | relrig[i[0]] = 0 94 | totrig[i[0]] = 0 95 | relrig[i[0]] += len(rig[i]) 96 | totrig[i[0]] += 1.0 97 | 98 | s11=0 99 | s1n=0 100 | sn1=0 101 | snn=0 102 | f = open("test2id.txt", "r") 103 | tot = (int)(f.readline()) 104 | for i in range(tot): 105 | content = f.readline() 106 | h,t,r = content.strip().split() 107 | rign = rellef[r] / totlef[r] 108 | lefn = relrig[r] / totrig[r] 109 | if (rign < 1.5 and lefn < 1.5): 110 | s11+=1 111 | if (rign >= 1.5 and lefn < 1.5): 112 | s1n+=1 113 | if (rign < 1.5 and lefn >= 1.5): 114 | sn1+=1 115 | if (rign >= 1.5 and lefn >= 1.5): 116 | snn+=1 117 | f.close() 118 | 119 | 120 | f = open("test2id.txt", "r") 121 | f11 = open("1-1.txt", "w") 122 | f1n = open("1-n.txt", "w") 123 | fn1 = open("n-1.txt", "w") 124 | fnn = open("n-n.txt", "w") 125 | fall = open("test2id_all.txt", "w") 126 | tot = (int)(f.readline()) 127 | fall.write("%d\n"%(tot)) 128 | f11.write("%d\n"%(s11)) 129 | f1n.write("%d\n"%(s1n)) 130 | fn1.write("%d\n"%(sn1)) 131 | fnn.write("%d\n"%(snn)) 132 | for i in range(tot): 133 | content = f.readline() 134 | h,t,r = content.strip().split() 135 | rign = rellef[r] / totlef[r] 136 | lefn = relrig[r] / totrig[r] 137 | if (rign < 1.5 and lefn < 1.5): 138 | f11.write(content) 139 | fall.write("0"+"\t"+content) 140 | if (rign >= 1.5 and lefn < 1.5): 141 | f1n.write(content) 142 | fall.write("1"+"\t"+content) 143 | if (rign < 1.5 and lefn >= 1.5): 144 | fn1.write(content) 145 | fall.write("2"+"\t"+content) 146 | if (rign >= 1.5 and lefn >= 1.5): 147 | fnn.write(content) 148 | fall.write("3"+"\t"+content) 149 | fall.close() 150 | f.close() 151 | f11.close() 152 | f1n.close() 153 | fn1.close() 154 | fnn.close() 155 | -------------------------------------------------------------------------------- 
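The `n-n.py` preprocessing script above classifies each relation by its average fan-out on both sides, with 1.5 as the cutoff (`rign` is the mean number of tails per `(h, r)` pair, `lefn` the mean number of heads per `(r, t)` pair). A minimal sketch of that categorization on toy triples — `categorize` and the sample data are hypothetical, and it deduplicates entities with sets, a simplification of the script's raw counts over train/valid/test:

```python
from collections import defaultdict

def categorize(triples, threshold=1.5):
    """Label each relation 1-1, 1-n, n-1, or n-n from (h, t, r) triples."""
    tails_of = defaultdict(set)   # (h, r) -> tails, like `lef` in the script
    heads_of = defaultdict(set)   # (r, t) -> heads, like `rig` in the script
    for h, t, r in triples:
        tails_of[(h, r)].add(t)
        heads_of[(r, t)].add(h)

    tail_counts = defaultdict(list)   # r -> tail fan-out of each head
    head_counts = defaultdict(list)   # r -> head fan-in of each tail
    for (h, r), ts in tails_of.items():
        tail_counts[r].append(len(ts))
    for (r, t), hs in heads_of.items():
        head_counts[r].append(len(hs))

    labels = {}
    for r in tail_counts:
        rign = sum(tail_counts[r]) / len(tail_counts[r])  # avg tails per (h, r)
        lefn = sum(head_counts[r]) / len(head_counts[r])  # avg heads per (r, t)
        labels[r] = ("1" if lefn < threshold else "n") + "-" + \
                    ("1" if rign < threshold else "n")
    return labels

# toy graph: relation 0 is one-to-one, relation 1 fans out on the tail side
triples = [(0, 1, 0), (2, 3, 0), (4, 5, 1), (4, 6, 1), (4, 7, 1)]
print(categorize(triples))  # {0: '1-1', 1: '1-n'}
```

The same threshold logic then splits `test2id.txt` into `1-1.txt`, `1-n.txt`, `n-1.txt`, and `n-n.txt`, and tags each test triple with its category id (0-3) in `test2id_all.txt`.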
/benchmarks/MKG-Y/relation2id.txt: -------------------------------------------------------------------------------- 1 | 28 2 | wasBornIn 0 3 | playsFor 1 4 | wroteMusicFor 2 5 | created 3 6 | isAffiliatedTo 4 7 | isLocatedIn 5 8 | directed 6 9 | hasWonPrize 7 10 | isMarriedTo 8 11 | actedIn 9 12 | isCitizenOf 10 13 | hasChild 11 14 | edited 12 15 | livesIn 13 16 | hasOfficialLanguage 14 17 | diedIn 15 18 | happenedIn 16 19 | graduatedFrom 17 20 | owns 18 21 | isPoliticianOf 19 22 | hasAcademicAdvisor 20 23 | hasCapital 21 24 | worksAt 22 25 | influences 23 26 | isLeaderOf 24 27 | participatedIn 25 28 | hasNeighbor 26 29 | dealsWith 27 30 | -------------------------------------------------------------------------------- /embeddings/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjukg/AdaMF-MAT/50930e9b28aed57133bedfc281391185ac9b68a4/embeddings/.DS_Store -------------------------------------------------------------------------------- /embeddings/.keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjukg/AdaMF-MAT/50930e9b28aed57133bedfc281391185ac9b68a4/embeddings/.keep -------------------------------------------------------------------------------- /figure/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjukg/AdaMF-MAT/50930e9b28aed57133bedfc281391185ac9b68a4/figure/model.png -------------------------------------------------------------------------------- /mmkgc/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjukg/AdaMF-MAT/50930e9b28aed57133bedfc281391185ac9b68a4/mmkgc/.DS_Store -------------------------------------------------------------------------------- /mmkgc/__init__.py: -------------------------------------------------------------------------------- 1 | 
from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function -------------------------------------------------------------------------------- /mmkgc/adv/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class BaseGenerator(nn.Module): 6 | def __init__( 7 | self, 8 | noise_dim, 9 | structure_dim, 10 | img_dim 11 | ): 12 | super(BaseGenerator, self).__init__() 13 | self.proj_dim = 512 14 | self.noise_dim = noise_dim 15 | self.generator_model = nn.Sequential( 16 | nn.Linear(noise_dim + structure_dim, self.proj_dim), 17 | nn.LeakyReLU(), 18 | nn.Linear(self.proj_dim, img_dim) 19 | ) 20 | 21 | 22 | def forward(self, batch_ent_emb): 23 | random_noise = torch.randn((batch_ent_emb.shape[0], self.noise_dim)).cuda() 24 | batch_data = torch.cat((random_noise, batch_ent_emb), dim=-1) 25 | out = self.generator_model(batch_data) 26 | return out 27 | 28 | 29 | class RandomGenerator(nn.Module): 30 | def __init__( 31 | self, 32 | noise_dim, 33 | img_dim 34 | ): 35 | super(RandomGenerator, self).__init__() 36 | self.proj_dim = 256 37 | self.noise_dim = noise_dim 38 | self.generator_model = nn.Sequential( 39 | nn.Linear(noise_dim, self.proj_dim), 40 | nn.LeakyReLU(), 41 | nn.Linear(self.proj_dim, img_dim) 42 | ) 43 | 44 | 45 | def forward(self, batch_ent_emb): 46 | random_noise = torch.randn((batch_ent_emb.shape[0], self.noise_dim)).cuda() 47 | out = self.generator_model(random_noise) 48 | return out 49 | 50 | 51 | class MultiGenerator(nn.Module): 52 | def __init__( 53 | self, 54 | noise_dim, 55 | structure_dim, 56 | img_dim 57 | ): 58 | super(MultiGenerator, self).__init__() 59 | self.img_generator = BaseGenerator(noise_dim, structure_dim, img_dim) 60 | self.text_generator = BaseGenerator(noise_dim, structure_dim, img_dim) 61 | 62 | def forward(self, batch_ent_emb, modal): 63 | if modal == 1: 64 | return 
self.img_generator(batch_ent_emb) 65 | elif modal == 2: 66 | return self.text_generator(batch_ent_emb) 67 | else: 68 | raise NotImplementedError 69 | 70 | 71 | 72 | class Similarity(nn.Module): 73 | """ 74 | Dot product or cosine similarity 75 | """ 76 | 77 | def __init__(self, temp): 78 | super().__init__() 79 | self.temp = temp 80 | self.cos = nn.CosineSimilarity(dim=-1) 81 | 82 | def forward(self, x, y): 83 | return self.cos(x, y) / self.temp 84 | 85 | 86 | 87 | class ContrastiveLoss(nn.Module): 88 | def __init__(self, temp=0.5): 89 | super().__init__() 90 | self.loss = nn.CrossEntropyLoss() 91 | self.sim_func = Similarity(temp=temp) 92 | 93 | def forward(self, node_emb, img_emb): 94 | batch_sim = self.sim_func(node_emb.unsqueeze(1), img_emb.unsqueeze(0)) 95 | labels = torch.arange(batch_sim.size(0)).long().to('cuda') 96 | return self.loss(batch_sim, labels) -------------------------------------------------------------------------------- /mmkgc/base/Base.cpp: -------------------------------------------------------------------------------- 1 | #include "Setting.h" 2 | #include "Random.h" 3 | #include "Reader.h" 4 | #include "Corrupt.h" 5 | #include "Test.h" 6 | #include <cstdlib> 7 | #include <pthread.h> 8 | 9 | extern "C" 10 | void setInPath(char *path); 11 | 12 | extern "C" 13 | void setTrainPath(char *path); 14 | 15 | extern "C" 16 | void setValidPath(char *path); 17 | 18 | extern "C" 19 | void setTestPath(char *path); 20 | 21 | extern "C" 22 | void setEntPath(char *path); 23 | 24 | extern "C" 25 | void setRelPath(char *path); 26 | 27 | extern "C" 28 | void setOutPath(char *path); 29 | 30 | extern "C" 31 | void setWorkThreads(INT threads); 32 | 33 | extern "C" 34 | void setBern(INT con); 35 | 36 | extern "C" 37 | INT getWorkThreads(); 38 | 39 | extern "C" 40 | INT getEntityTotal(); 41 | 42 | extern "C" 43 | INT getRelationTotal(); 44 | 45 | extern "C" 46 | INT getTripleTotal(); 47 | 48 | extern "C" 49 | INT getTrainTotal(); 50 | 51 | extern "C" 52 | INT getTestTotal(); 53 | 54 | 
extern "C" 55 | INT getValidTotal(); 56 | 57 | extern "C" 58 | void randReset(); 59 | 60 | extern "C" 61 | void importTrainFiles(); 62 | 63 | struct Parameter { 64 | INT id; 65 | INT *batch_h; 66 | INT *batch_t; 67 | INT *batch_r; 68 | REAL *batch_y; 69 | INT batchSize; 70 | INT negRate; 71 | INT negRelRate; 72 | bool p; 73 | bool val_loss; 74 | INT mode; 75 | bool filter_flag; 76 | }; 77 | 78 | void* getBatch(void* con) { 79 | Parameter *para = (Parameter *)(con); 80 | INT id = para -> id; 81 | INT *batch_h = para -> batch_h; 82 | INT *batch_t = para -> batch_t; 83 | INT *batch_r = para -> batch_r; 84 | REAL *batch_y = para -> batch_y; 85 | INT batchSize = para -> batchSize; 86 | INT negRate = para -> negRate; 87 | INT negRelRate = para -> negRelRate; 88 | bool p = para -> p; 89 | bool val_loss = para -> val_loss; 90 | INT mode = para -> mode; 91 | bool filter_flag = para -> filter_flag; 92 | INT lef, rig; 93 | if (batchSize % workThreads == 0) { 94 | lef = id * (batchSize / workThreads); 95 | rig = (id + 1) * (batchSize / workThreads); 96 | } else { 97 | lef = id * (batchSize / workThreads + 1); 98 | rig = (id + 1) * (batchSize / workThreads + 1); 99 | if (rig > batchSize) rig = batchSize; 100 | } 101 | REAL prob = 500; 102 | if (val_loss == false) { 103 | for (INT batch = lef; batch < rig; batch++) { 104 | INT i = rand_max(id, trainTotal); 105 | batch_h[batch] = trainList[i].h; 106 | batch_t[batch] = trainList[i].t; 107 | batch_r[batch] = trainList[i].r; 108 | batch_y[batch] = 1; 109 | INT last = batchSize; 110 | for (INT times = 0; times < negRate; times ++) { 111 | if (mode == 0){ 112 | if (bernFlag) 113 | prob = 1000 * right_mean[trainList[i].r] / (right_mean[trainList[i].r] + left_mean[trainList[i].r]); 114 | if (randd(id) % 1000 < prob) { 115 | batch_h[batch + last] = trainList[i].h; 116 | batch_t[batch + last] = corrupt_head(id, trainList[i].h, trainList[i].r); 117 | batch_r[batch + last] = trainList[i].r; 118 | } else { 119 | batch_h[batch + last] = 
corrupt_tail(id, trainList[i].t, trainList[i].r); 120 | batch_t[batch + last] = trainList[i].t; 121 | batch_r[batch + last] = trainList[i].r; 122 | } 123 | batch_y[batch + last] = -1; 124 | last += batchSize; 125 | } else { 126 | if(mode == -1){ 127 | batch_h[batch + last] = corrupt_tail(id, trainList[i].t, trainList[i].r); 128 | batch_t[batch + last] = trainList[i].t; 129 | batch_r[batch + last] = trainList[i].r; 130 | } else { 131 | batch_h[batch + last] = trainList[i].h; 132 | batch_t[batch + last] = corrupt_head(id, trainList[i].h, trainList[i].r); 133 | batch_r[batch + last] = trainList[i].r; 134 | } 135 | batch_y[batch + last] = -1; 136 | last += batchSize; 137 | } 138 | } 139 | for (INT times = 0; times < negRelRate; times++) { 140 | batch_h[batch + last] = trainList[i].h; 141 | batch_t[batch + last] = trainList[i].t; 142 | batch_r[batch + last] = corrupt_rel(id, trainList[i].h, trainList[i].t, trainList[i].r, p); 143 | batch_y[batch + last] = -1; 144 | last += batchSize; 145 | } 146 | } 147 | } 148 | else 149 | { 150 | for (INT batch = lef; batch < rig; batch++) 151 | { 152 | batch_h[batch] = validList[batch].h; 153 | batch_t[batch] = validList[batch].t; 154 | batch_r[batch] = validList[batch].r; 155 | batch_y[batch] = 1; 156 | } 157 | } 158 | pthread_exit(NULL); 159 | } 160 | 161 | extern "C" 162 | void sampling( 163 | INT *batch_h, 164 | INT *batch_t, 165 | INT *batch_r, 166 | REAL *batch_y, 167 | INT batchSize, 168 | INT negRate = 1, 169 | INT negRelRate = 0, 170 | INT mode = 0, 171 | bool filter_flag = true, 172 | bool p = false, 173 | bool val_loss = false 174 | ) { 175 | pthread_t *pt = (pthread_t *)malloc(workThreads * sizeof(pthread_t)); 176 | Parameter *para = (Parameter *)malloc(workThreads * sizeof(Parameter)); 177 | for (INT threads = 0; threads < workThreads; threads++) { 178 | para[threads].id = threads; 179 | para[threads].batch_h = batch_h; 180 | para[threads].batch_t = batch_t; 181 | para[threads].batch_r = batch_r; 182 | 
para[threads].batch_y = batch_y; 183 | para[threads].batchSize = batchSize; 184 | para[threads].negRate = negRate; 185 | para[threads].negRelRate = negRelRate; 186 | para[threads].p = p; 187 | para[threads].val_loss = val_loss; 188 | para[threads].mode = mode; 189 | para[threads].filter_flag = filter_flag; 190 | pthread_create(&pt[threads], NULL, getBatch, (void*)(para+threads)); 191 | } 192 | for (INT threads = 0; threads < workThreads; threads++) 193 | pthread_join(pt[threads], NULL); 194 | 195 | free(pt); 196 | free(para); 197 | } 198 | 199 | int main() { 200 | importTrainFiles(); 201 | return 0; 202 | } -------------------------------------------------------------------------------- /mmkgc/base/Corrupt.h: -------------------------------------------------------------------------------- 1 | #ifndef CORRUPT_H 2 | #define CORRUPT_H 3 | #include "Random.h" 4 | #include "Triple.h" 5 | #include "Reader.h" 6 | 7 | INT corrupt_head(INT id, INT h, INT r, bool filter_flag = true) { 8 | INT lef, rig, mid, ll, rr; 9 | if (not filter_flag) { 10 | INT tmp = rand_max(id, entityTotal - 1); 11 | if (tmp < h) 12 | return tmp; 13 | else 14 | return tmp + 1; 15 | } 16 | lef = lefHead[h] - 1; 17 | rig = rigHead[h]; 18 | while (lef + 1 < rig) { 19 | mid = (lef + rig) >> 1; 20 | if (trainHead[mid].r >= r) rig = mid; else 21 | lef = mid; 22 | } 23 | ll = rig; 24 | lef = lefHead[h]; 25 | rig = rigHead[h] + 1; 26 | while (lef + 1 < rig) { 27 | mid = (lef + rig) >> 1; 28 | if (trainHead[mid].r <= r) lef = mid; else 29 | rig = mid; 30 | } 31 | rr = lef; 32 | INT tmp = rand_max(id, entityTotal - (rr - ll + 1)); 33 | if (tmp < trainHead[ll].t) return tmp; 34 | if (tmp > trainHead[rr].t - rr + ll - 1) return tmp + rr - ll + 1; 35 | lef = ll, rig = rr + 1; 36 | while (lef + 1 < rig) { 37 | mid = (lef + rig) >> 1; 38 | if (trainHead[mid].t - mid + ll - 1 < tmp) 39 | lef = mid; 40 | else 41 | rig = mid; 42 | } 43 | return tmp + lef - ll + 1; 44 | } 45 | 46 | INT corrupt_tail(INT id, INT t, INT 
r, bool filter_flag = true) { 47 | INT lef, rig, mid, ll, rr; 48 | if (not filter_flag) { 49 | INT tmp = rand_max(id, entityTotal - 1); 50 | if (tmp < t) 51 | return tmp; 52 | else 53 | return tmp + 1; 54 | } 55 | lef = lefTail[t] - 1; 56 | rig = rigTail[t]; 57 | while (lef + 1 < rig) { 58 | mid = (lef + rig) >> 1; 59 | if (trainTail[mid].r >= r) rig = mid; else 60 | lef = mid; 61 | } 62 | ll = rig; 63 | lef = lefTail[t]; 64 | rig = rigTail[t] + 1; 65 | while (lef + 1 < rig) { 66 | mid = (lef + rig) >> 1; 67 | if (trainTail[mid].r <= r) lef = mid; else 68 | rig = mid; 69 | } 70 | rr = lef; 71 | INT tmp = rand_max(id, entityTotal - (rr - ll + 1)); 72 | if (tmp < trainTail[ll].h) return tmp; 73 | if (tmp > trainTail[rr].h - rr + ll - 1) return tmp + rr - ll + 1; 74 | lef = ll, rig = rr + 1; 75 | while (lef + 1 < rig) { 76 | mid = (lef + rig) >> 1; 77 | if (trainTail[mid].h - mid + ll - 1 < tmp) 78 | lef = mid; 79 | else 80 | rig = mid; 81 | } 82 | return tmp + lef - ll + 1; 83 | } 84 | 85 | 86 | INT corrupt_rel(INT id, INT h, INT t, INT r, bool p = false, bool filter_flag = true) { 87 | INT lef, rig, mid, ll, rr; 88 | if (not filter_flag) { 89 | INT tmp = rand_max(id, relationTotal - 1); 90 | if (tmp < r) 91 | return tmp; 92 | else 93 | return tmp + 1; 94 | } 95 | lef = lefRel[h] - 1; 96 | rig = rigRel[h]; 97 | while (lef + 1 < rig) { 98 | mid = (lef + rig) >> 1; 99 | if (trainRel[mid].t >= t) rig = mid; else 100 | lef = mid; 101 | } 102 | ll = rig; 103 | lef = lefRel[h]; 104 | rig = rigRel[h] + 1; 105 | while (lef + 1 < rig) { 106 | mid = (lef + rig) >> 1; 107 | if (trainRel[mid].t <= t) lef = mid; else 108 | rig = mid; 109 | } 110 | rr = lef; 111 | INT tmp; 112 | if(p == false) { 113 | tmp = rand_max(id, relationTotal - (rr - ll + 1)); 114 | } 115 | else { 116 | INT start = r * (relationTotal - 1); 117 | REAL sum = 1; 118 | bool *record = (bool *)calloc(relationTotal - 1, sizeof(bool)); 119 | for (INT i = ll; i <= rr; ++i){ 120 | if (trainRel[i].r > r){ 121 | sum 
-= prob[start + trainRel[i].r-1]; 122 | record[trainRel[i].r-1] = true; 123 | } 124 | else if (trainRel[i].r < r){ 125 | sum -= prob[start + trainRel[i].r]; 126 | record[trainRel[i].r] = true; 127 | } 128 | } 129 | REAL *prob_tmp = (REAL *)calloc(relationTotal-(rr-ll+1), sizeof(REAL)); 130 | INT cnt = 0; 131 | REAL rec = 0; 132 | for (INT i = start; i < start + relationTotal - 1; ++i) { 133 | if (record[i-start]) 134 | continue; 135 | rec += prob[i] / sum; 136 | prob_tmp[cnt++] = rec; 137 | } 138 | REAL m = rand_max(id, 10000) / 10000.0; 139 | lef = 0; 140 | rig = cnt - 1; 141 | while (lef < rig) { 142 | mid = (lef + rig) >> 1; 143 | if (prob_tmp[mid] < m) 144 | lef = mid + 1; 145 | else 146 | rig = mid; 147 | } 148 | tmp = rig; 149 | free(prob_tmp); 150 | free(record); 151 | } 152 | if (tmp < trainRel[ll].r) return tmp; 153 | if (tmp > trainRel[rr].r - rr + ll - 1) return tmp + rr - ll + 1; 154 | lef = ll, rig = rr + 1; 155 | while (lef + 1 < rig) { 156 | mid = (lef + rig) >> 1; 157 | if (trainRel[mid].r - mid + ll - 1 < tmp) 158 | lef = mid; 159 | else 160 | rig = mid; 161 | } 162 | return tmp + lef - ll + 1; 163 | } 164 | 165 | 166 | bool _find(INT h, INT t, INT r) { 167 | INT lef = 0; 168 | INT rig = tripleTotal - 1; 169 | INT mid; 170 | while (lef + 1 < rig) { 171 | INT mid = (lef + rig) >> 1; 172 | if ((tripleList[mid]. h < h) || (tripleList[mid]. h == h && tripleList[mid]. r < r) || (tripleList[mid]. h == h && tripleList[mid]. r == r && tripleList[mid]. 
t < t)) lef = mid; else rig = mid; 173 | } 174 | if (tripleList[lef].h == h && tripleList[lef].r == r && tripleList[lef].t == t) return true; 175 | if (tripleList[rig].h == h && tripleList[rig].r == r && tripleList[rig].t == t) return true; 176 | return false; 177 | } 178 | 179 | INT corrupt(INT h, INT r){ 180 | INT ll = tail_lef[r]; 181 | INT rr = tail_rig[r]; 182 | INT loop = 0; 183 | INT t; 184 | while(true) { 185 | t = tail_type[rand(ll, rr)]; 186 | if (not _find(h, t, r)) { 187 | return t; 188 | } else { 189 | loop ++; 190 | if (loop >= 1000) { 191 | return corrupt_head(0, h, r); 192 | } 193 | } 194 | } 195 | } 196 | #endif 197 | -------------------------------------------------------------------------------- /mmkgc/base/Random.h: -------------------------------------------------------------------------------- 1 | #ifndef RANDOM_H 2 | #define RANDOM_H 3 | #include "Setting.h" 4 | #include <cstdlib> 5 | 6 | // the random seeds for all threads. 7 | unsigned long long *next_random; 8 | 9 | // reset the random seeds for all threads 10 | extern "C" 11 | void randReset() { 12 | next_random = (unsigned long long *)calloc(workThreads, sizeof(unsigned long long)); 13 | for (INT i = 0; i < workThreads; i++) 14 | next_random[i] = rand(); 15 | } 16 | 17 | // get a random integer for the id-th thread with the corresponding random seed. 18 | unsigned long long randd(INT id) { 19 | next_random[id] = next_random[id] * (unsigned long long)(25214903917) + 11; 20 | return next_random[id]; 21 | } 22 | 23 | // get a random integer from the range [0,x) for the id-th thread. 24 | INT rand_max(INT id, INT x) { 25 | INT res = randd(id) % x; 26 | while (res < 0) 27 | res += x; 28 | return res; 29 | } 30 | 31 | // get a random integer from the range [a,b) for the id-th thread.
32 | INT rand(INT a, INT b){ 33 | return (rand() % (b-a))+ a; 34 | } 35 | #endif 36 | -------------------------------------------------------------------------------- /mmkgc/base/Setting.h: -------------------------------------------------------------------------------- 1 | #ifndef SETTING_H 2 | #define SETTING_H 3 | #define INT long 4 | #define REAL float 5 | #include <cstring> 6 | #include <cstdio> 7 | #include <string> 8 | 9 | std::string inPath = "../data/FB15K/"; 10 | std::string outPath = "../data/FB15K/"; 11 | std::string ent_file = ""; 12 | std::string rel_file = ""; 13 | std::string train_file = ""; 14 | std::string valid_file = ""; 15 | std::string test_file = ""; 16 | 17 | extern "C" 18 | void setInPath(char *path) { 19 | INT len = strlen(path); 20 | inPath = ""; 21 | for (INT i = 0; i < len; i++) 22 | inPath = inPath + path[i]; 23 | printf("Input Files Path : %s\n", inPath.c_str()); 24 | } 25 | 26 | extern "C" 27 | void setOutPath(char *path) { 28 | INT len = strlen(path); 29 | outPath = ""; 30 | for (INT i = 0; i < len; i++) 31 | outPath = outPath + path[i]; 32 | printf("Output Files Path : %s\n", outPath.c_str()); 33 | } 34 | 35 | extern "C" 36 | void setTrainPath(char *path) { 37 | INT len = strlen(path); 38 | train_file = ""; 39 | for (INT i = 0; i < len; i++) 40 | train_file = train_file + path[i]; 41 | printf("Training Files Path : %s\n", train_file.c_str()); 42 | } 43 | 44 | extern "C" 45 | void setValidPath(char *path) { 46 | INT len = strlen(path); 47 | valid_file = ""; 48 | for (INT i = 0; i < len; i++) 49 | valid_file = valid_file + path[i]; 50 | printf("Valid Files Path : %s\n", valid_file.c_str()); 51 | } 52 | 53 | extern "C" 54 | void setTestPath(char *path) { 55 | INT len = strlen(path); 56 | test_file = ""; 57 | for (INT i = 0; i < len; i++) 58 | test_file = test_file + path[i]; 59 | printf("Test Files Path : %s\n", test_file.c_str()); 60 | } 61 | 62 | extern "C" 63 | void setEntPath(char *path) { 64 | INT len = strlen(path); 65 | ent_file = ""; 66 | for (INT i =
0; i < len; i++) 67 | ent_file = ent_file + path[i]; 68 | printf("Entity Files Path : %s\n", ent_file.c_str()); 69 | } 70 | 71 | extern "C" 72 | void setRelPath(char *path) { 73 | INT len = strlen(path); 74 | rel_file = ""; 75 | for (INT i = 0; i < len; i++) 76 | rel_file = rel_file + path[i]; 77 | printf("Relation Files Path : %s\n", rel_file.c_str()); 78 | } 79 | 80 | /* 81 | ============================================================ 82 | */ 83 | 84 | INT workThreads = 1; 85 | 86 | extern "C" 87 | void setWorkThreads(INT threads) { 88 | workThreads = threads; 89 | } 90 | 91 | extern "C" 92 | INT getWorkThreads() { 93 | return workThreads; 94 | } 95 | 96 | /* 97 | ============================================================ 98 | */ 99 | 100 | INT relationTotal = 0; 101 | INT entityTotal = 0; 102 | INT tripleTotal = 0; 103 | INT testTotal = 0; 104 | INT trainTotal = 0; 105 | INT validTotal = 0; 106 | 107 | extern "C" 108 | INT getEntityTotal() { 109 | return entityTotal; 110 | } 111 | 112 | extern "C" 113 | INT getRelationTotal() { 114 | return relationTotal; 115 | } 116 | 117 | extern "C" 118 | INT getTripleTotal() { 119 | return tripleTotal; 120 | } 121 | 122 | extern "C" 123 | INT getTrainTotal() { 124 | return trainTotal; 125 | } 126 | 127 | extern "C" 128 | INT getTestTotal() { 129 | return testTotal; 130 | } 131 | 132 | extern "C" 133 | INT getValidTotal() { 134 | return validTotal; 135 | } 136 | /* 137 | ============================================================ 138 | */ 139 | 140 | INT bernFlag = 0; 141 | 142 | extern "C" 143 | void setBern(INT con) { 144 | bernFlag = con; 145 | } 146 | 147 | #endif 148 | -------------------------------------------------------------------------------- /mmkgc/base/Triple.h: -------------------------------------------------------------------------------- 1 | #ifndef TRIPLE_H 2 | #define TRIPLE_H 3 | #include "Setting.h" 4 | 5 | struct Triple { 6 | 7 | INT h, r, t; 8 | 9 | static bool cmp_head(const Triple &a, const Triple 
&b) { 10 | return (a.h < b.h)||(a.h == b.h && a.r < b.r)||(a.h == b.h && a.r == b.r && a.t < b.t); 11 | } 12 | 13 | static bool cmp_tail(const Triple &a, const Triple &b) { 14 | return (a.t < b.t)||(a.t == b.t && a.r < b.r)||(a.t == b.t && a.r == b.r && a.h < b.h); 15 | } 16 | 17 | static bool cmp_rel(const Triple &a, const Triple &b) { 18 | return (a.h < b.h)||(a.h == b.h && a.t < b.t)||(a.h == b.h && a.t == b.t && a.r < b.r); 19 | } 20 | 21 | static bool cmp_rel2(const Triple &a, const Triple &b) { 22 | return (a.r < b.r)||(a.r == b.r && a.h < b.h)||(a.r == b.r && a.h == b.h && a.t < b.t); 23 | } 24 | 25 | }; 26 | 27 | #endif 28 | -------------------------------------------------------------------------------- /mmkgc/config/AdvConTrainer.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | from calendar import c 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Variable 6 | import torch.optim as optim 7 | import os 8 | import time 9 | import sys 10 | import datetime 11 | import ctypes 12 | import json 13 | import numpy as np 14 | import copy 15 | from tqdm import tqdm 16 | from ..adv.modules import ContrastiveLoss 17 | 18 | 19 | class AdvConTrainer(object): 20 | 21 | def __init__(self, 22 | model=None, 23 | data_loader=None, 24 | train_times=1000, 25 | alpha=0.5, 26 | use_gpu=True, 27 | opt_method="sgd", 28 | save_steps=None, 29 | checkpoint_dir=None, 30 | generator=None, 31 | lrg=None, 32 | temp=None, 33 | lamda=None): 34 | 35 | self.work_threads = 8 36 | self.train_times = train_times 37 | 38 | self.opt_method = opt_method 39 | self.optimizer = None 40 | self.lr_decay = 0 41 | self.weight_decay = 0 42 | self.alpha = alpha 43 | # learning rate of the generator 44 | assert lrg is not None 45 | self.alpha_g = lrg 46 | 47 | self.model = model 48 | self.data_loader = data_loader 49 | self.use_gpu = use_gpu 50 | self.save_steps = save_steps 51 | self.checkpoint_dir = checkpoint_dir 52 | 53 
| # the generator part 54 | assert generator is not None 55 | assert temp is not None 56 | assert lamda is not None 57 | self.optimizer_g = None 58 | self.generator = generator 59 | self.batch_size = self.model.batch_size 60 | self.generator.cuda() 61 | # add contrastive loss 62 | self.contrastive_loss = ContrastiveLoss(temp=temp) 63 | self.lamda = lamda 64 | 65 | 66 | def train_one_step(self, data): 67 | # training D 68 | self.optimizer.zero_grad() 69 | loss, p_score = self.model({ 70 | 'batch_h': self.to_var(data['batch_h'], self.use_gpu), 71 | 'batch_t': self.to_var(data['batch_t'], self.use_gpu), 72 | 'batch_r': self.to_var(data['batch_r'], self.use_gpu), 73 | 'batch_y': self.to_var(data['batch_y'], self.use_gpu), 74 | 'mode': data['mode'] 75 | }) 76 | # generate fake multimodal feature 77 | batch_h_gen = self.to_var(data['batch_h'][0: self.batch_size], self.use_gpu) 78 | batch_t_gen = self.to_var(data['batch_t'][0: self.batch_size], self.use_gpu) 79 | batch_r = self.to_var(data['batch_r'][0: self.batch_size], self.use_gpu) 80 | batch_hs = self.model.model.get_batch_ent_embs(batch_h_gen) 81 | batch_ts = self.model.model.get_batch_ent_embs(batch_t_gen) 82 | batch_gen_hv = self.generator(batch_hs) 83 | batch_gen_tv = self.generator(batch_ts) 84 | scores, real_img_embs = self.model.model.get_fake_score( 85 | batch_h=batch_h_gen, 86 | batch_r=batch_r, 87 | batch_t=batch_t_gen, 88 | mode=data['mode'], 89 | fake_hv=batch_gen_hv, 90 | fake_tv=batch_gen_tv 91 | ) 92 | h_img, t_img = real_img_embs 93 | # when training D: positive_score > fake_score 94 | for score in scores: 95 | loss += self.model.loss(p_score, score) / 3 96 | loss_con = self.contrastive_loss(h_img, batch_gen_hv) + self.contrastive_loss(t_img, batch_gen_tv) + self.contrastive_loss(batch_gen_hv, h_img) + self.contrastive_loss(batch_gen_tv, t_img) 97 | loss += loss_con * self.lamda 98 | loss.backward() 99 | self.optimizer.step() 100 | # training G 101 | self.optimizer_g.zero_grad() 102 | batch_hs =
self.model.model.get_batch_ent_embs(batch_h_gen) 103 | batch_ts = self.model.model.get_batch_ent_embs(batch_t_gen) 104 | _, p_score = self.model({ 105 | 'batch_h': batch_h_gen, 106 | 'batch_t': batch_t_gen, 107 | 'batch_r': batch_r, 108 | 'batch_y': self.to_var(data['batch_y'], self.use_gpu), 109 | 'mode': data['mode'] 110 | }) 111 | batch_gen_hv = self.generator(batch_hs) 112 | batch_gen_tv = self.generator(batch_ts) 113 | scores, real_img_embs = self.model.model.get_fake_score( 114 | batch_h=batch_h_gen, 115 | batch_r=batch_r, 116 | batch_t=batch_t_gen, 117 | mode=data['mode'], 118 | fake_hv=batch_gen_hv, 119 | fake_tv=batch_gen_tv 120 | ) 121 | h_img, t_img = real_img_embs 122 | loss_g = 0.0 123 | for score in scores: 124 | loss_g += self.model.loss(score, p_score) / 3 125 | loss_con = self.contrastive_loss(h_img, batch_gen_hv) + self.contrastive_loss(t_img, batch_gen_tv) + self.contrastive_loss(batch_gen_hv, h_img) + self.contrastive_loss(batch_gen_tv, t_img) 126 | loss_g += loss_con * self.lamda 127 | loss_g.backward() 128 | self.optimizer_g.step() 129 | return loss.item(), loss_g.item() 130 | 131 | def run(self): 132 | if self.use_gpu: 133 | self.model.cuda() 134 | 135 | if self.optimizer is not None: 136 | pass 137 | elif self.opt_method == "Adam" or self.opt_method == "adam": 138 | self.optimizer = optim.Adam( 139 | self.model.parameters(), 140 | lr=self.alpha, 141 | weight_decay=self.weight_decay, 142 | ) 143 | self.optimizer_g = optim.Adam( 144 | self.generator.parameters(), 145 | lr=self.alpha_g, 146 | weight_decay=self.weight_decay, 147 | ) 148 | print( 149 | "Learning Rate of D: {}\nLearning Rate of G: {}".format( 150 | self.alpha, self.alpha_g) 151 | ) 152 | else: 153 | raise NotImplementedError 154 | print("Finish initializing...") 155 | 156 | training_range = tqdm(range(self.train_times)) 157 | for epoch in training_range: 158 | res = 0.0 159 | res_g = 0.0 160 | for data in self.data_loader: 161 | loss, loss_g = self.train_one_step(data) 162 | res 
+= loss 163 | res_g += loss_g 164 | training_range.set_description("Epoch %d | D loss: %f, G loss %f" % (epoch, res, res_g)) 165 | 166 | if self.save_steps and self.checkpoint_dir and (epoch + 1) % self.save_steps == 0: 167 | print("Epoch %d has finished, saving..." % (epoch)) 168 | self.model.save_checkpoint(os.path.join(self.checkpoint_dir + "-" + str(epoch) + ".ckpt")) 169 | 170 | def set_model(self, model): 171 | self.model = model 172 | 173 | def to_var(self, x, use_gpu): 174 | if use_gpu: 175 | return Variable(torch.from_numpy(x).cuda()) 176 | else: 177 | return Variable(torch.from_numpy(x)) 178 | 179 | def set_use_gpu(self, use_gpu): 180 | self.use_gpu = use_gpu 181 | 182 | def set_alpha(self, alpha): 183 | self.alpha = alpha 184 | 185 | def set_lr_decay(self, lr_decay): 186 | self.lr_decay = lr_decay 187 | 188 | def set_weight_decay(self, weight_decay): 189 | self.weight_decay = weight_decay 190 | 191 | def set_opt_method(self, opt_method): 192 | self.opt_method = opt_method 193 | 194 | def set_train_times(self, train_times): 195 | self.train_times = train_times 196 | 197 | def set_save_steps(self, save_steps, checkpoint_dir=None): 198 | self.save_steps = save_steps 199 | if not self.checkpoint_dir: 200 | self.set_checkpoint_dir(checkpoint_dir) 201 | 202 | def set_checkpoint_dir(self, checkpoint_dir): 203 | self.checkpoint_dir = checkpoint_dir 204 | 205 | 206 | 207 | 208 | -------------------------------------------------------------------------------- /mmkgc/config/AdvMixTrainer.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | from calendar import c 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Variable 6 | import torch.optim as optim 7 | import os 8 | import time 9 | import sys 10 | import datetime 11 | import ctypes 12 | import json 13 | import numpy as np 14 | import copy 15 | from tqdm import tqdm 16 | 17 | 18 | class AdvMixTrainer(object): 19 | 20 | def 
__init__(self, 21 | model=None, 22 | data_loader=None, 23 | train_times=1000, 24 | alpha=0.5, 25 | use_gpu=True, 26 | opt_method="sgd", 27 | save_steps=None, 28 | checkpoint_dir=None, 29 | generator=None, 30 | lrg=None, 31 | mu=None): 32 | 33 | self.work_threads = 8 34 | self.train_times = train_times 35 | 36 | self.opt_method = opt_method 37 | self.optimizer = None 38 | self.lr_decay = 0 39 | self.weight_decay = 0 40 | self.alpha = alpha 41 | # learning rate of the generator 42 | assert lrg is not None 43 | self.alpha_g = lrg 44 | 45 | self.model = model 46 | self.data_loader = data_loader 47 | self.use_gpu = use_gpu 48 | self.save_steps = save_steps 49 | self.checkpoint_dir = checkpoint_dir 50 | 51 | # the generator part 52 | assert generator is not None 53 | assert mu is not None 54 | self.optimizer_g = None 55 | self.generator = generator 56 | self.batch_size = self.model.batch_size 57 | self.generator.cuda() 58 | self.mu = mu 59 | 60 | def train_one_step(self, data): 61 | # training D 62 | self.optimizer.zero_grad() 63 | loss, p_score = self.model({ 64 | 'batch_h': self.to_var(data['batch_h'], self.use_gpu), 65 | 'batch_t': self.to_var(data['batch_t'], self.use_gpu), 66 | 'batch_r': self.to_var(data['batch_r'], self.use_gpu), 67 | 'batch_y': self.to_var(data['batch_y'], self.use_gpu), 68 | 'mode': data['mode'] 69 | }) 70 | # generate fake multimodal feature 71 | batch_h_gen = self.to_var(data['batch_h'][0: self.batch_size], self.use_gpu) 72 | batch_t_gen = self.to_var(data['batch_t'][0: self.batch_size], self.use_gpu) 73 | batch_r = self.to_var(data['batch_r'][0: self.batch_size], self.use_gpu) 74 | batch_hs = self.model.model.get_batch_ent_embs(batch_h_gen) 75 | batch_ts = self.model.model.get_batch_ent_embs(batch_t_gen) 76 | batch_gen_hv = self.generator(batch_hs, 1) 77 | batch_gen_tv = self.generator(batch_ts, 1) 78 | batch_gen_ht = self.generator(batch_hs, 2) 79 | batch_gen_tt = self.generator(batch_ts, 2) 80 | scores, _ = self.model.model.get_fake_score( 
81 | batch_h=batch_h_gen, 82 | batch_r=batch_r, 83 | batch_t=batch_t_gen, 84 | mode=data['mode'], 85 | fake_hv=batch_gen_hv, 86 | fake_tv=batch_gen_tv, 87 | fake_ht=batch_gen_ht, 88 | fake_tt=batch_gen_tt 89 | ) 90 | # when training D: positive_score > fake_score 91 | for score in scores: 92 | # print(p_score.shape, score.shape) 93 | loss += self.model.loss(p_score, score) * self.mu 94 | loss.backward() 95 | self.optimizer.step() 96 | # training G 97 | self.optimizer_g.zero_grad() 98 | batch_hs = self.model.model.get_batch_ent_embs(batch_h_gen) 99 | batch_ts = self.model.model.get_batch_ent_embs(batch_t_gen) 100 | p_score = self.model({ 101 | 'batch_h': batch_h_gen, 102 | 'batch_t': batch_t_gen, 103 | 'batch_r': batch_r, 104 | 'batch_y': self.to_var(data['batch_y'], self.use_gpu), 105 | 'mode': data['mode'] 106 | }, fast_return=True) 107 | batch_gen_hv = self.generator(batch_hs, 1) 108 | batch_gen_tv = self.generator(batch_ts, 1) 109 | batch_gen_ht = self.generator(batch_hs, 2) 110 | batch_gen_tt = self.generator(batch_ts, 2) 111 | scores, _ = self.model.model.get_fake_score( 112 | batch_h=batch_h_gen, 113 | batch_r=batch_r, 114 | batch_t=batch_t_gen, 115 | mode=data['mode'], 116 | fake_hv=batch_gen_hv, 117 | fake_tv=batch_gen_tv, 118 | fake_ht=batch_gen_ht, 119 | fake_tt=batch_gen_tt 120 | ) 121 | loss_g = 0.0 122 | for score in scores: 123 | loss_g += self.model.loss(score, p_score) 124 | loss_g.backward() 125 | self.optimizer_g.step() 126 | return loss.item(), loss_g.item() 127 | 128 | def run(self): 129 | if self.use_gpu: 130 | self.model.cuda() 131 | 132 | if self.optimizer is not None: 133 | pass 134 | elif self.opt_method == "Adam" or self.opt_method == "adam": 135 | self.optimizer = optim.Adam( 136 | self.model.parameters(), 137 | lr=self.alpha, 138 | weight_decay=self.weight_decay, 139 | ) 140 | self.optimizer_g = optim.Adam( 141 | self.generator.parameters(), 142 | lr=self.alpha_g, 143 | weight_decay=self.weight_decay, 144 | ) 145 | print( 146 | "Learning 
Rate of D: {}\nLearning Rate of G: {}".format( 147 | self.alpha, self.alpha_g) 148 | ) 149 | else: 150 | raise NotImplementedError 151 | print("Finish initializing...") 152 | 153 | training_range = tqdm(range(self.train_times)) 154 | for epoch in training_range: 155 | res = 0.0 156 | res_g = 0.0 157 | for data in self.data_loader: 158 | loss, loss_g = self.train_one_step(data) 159 | res += loss 160 | res_g += loss_g 161 | training_range.set_description("Epoch %d | D loss: %f, G loss %f" % (epoch, res, res_g)) 162 | 163 | if self.save_steps and self.checkpoint_dir and (epoch + 1) % self.save_steps == 0: 164 | print("Epoch %d has finished, saving..." % (epoch)) 165 | self.model.save_checkpoint(os.path.join(self.checkpoint_dir + "-" + str(epoch) + ".ckpt")) 166 | 167 | def set_model(self, model): 168 | self.model = model 169 | 170 | def to_var(self, x, use_gpu): 171 | if use_gpu: 172 | return Variable(torch.from_numpy(x).cuda()) 173 | else: 174 | return Variable(torch.from_numpy(x)) 175 | 176 | def set_use_gpu(self, use_gpu): 177 | self.use_gpu = use_gpu 178 | 179 | def set_alpha(self, alpha): 180 | self.alpha = alpha 181 | 182 | def set_lr_decay(self, lr_decay): 183 | self.lr_decay = lr_decay 184 | 185 | def set_weight_decay(self, weight_decay): 186 | self.weight_decay = weight_decay 187 | 188 | def set_opt_method(self, opt_method): 189 | self.opt_method = opt_method 190 | 191 | def set_train_times(self, train_times): 192 | self.train_times = train_times 193 | 194 | def set_save_steps(self, save_steps, checkpoint_dir=None): 195 | self.save_steps = save_steps 196 | if not self.checkpoint_dir: 197 | self.set_checkpoint_dir(checkpoint_dir) 198 | 199 | def set_checkpoint_dir(self, checkpoint_dir): 200 | self.checkpoint_dir = checkpoint_dir 201 | 202 | 203 | -------------------------------------------------------------------------------- /mmkgc/config/AdvTrainer.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | 
from calendar import c 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Variable 6 | import torch.optim as optim 7 | import os 8 | import time 9 | import sys 10 | import datetime 11 | import ctypes 12 | import json 13 | import numpy as np 14 | import copy 15 | from tqdm import tqdm 16 | 17 | 18 | class AdvTrainer(object): 19 | 20 | def __init__(self, 21 | model=None, 22 | data_loader=None, 23 | train_times=1000, 24 | alpha=0.5, 25 | use_gpu=True, 26 | opt_method="sgd", 27 | save_steps=None, 28 | checkpoint_dir=None, 29 | generator=None, 30 | lrg=None, 31 | mu=None): 32 | 33 | self.work_threads = 8 34 | self.train_times = train_times 35 | 36 | self.opt_method = opt_method 37 | self.optimizer = None 38 | self.lr_decay = 0 39 | self.weight_decay = 0 40 | self.alpha = alpha 41 | # learning rate of the generator 42 | assert lrg is not None 43 | self.alpha_g = lrg 44 | 45 | self.model = model 46 | self.data_loader = data_loader 47 | self.use_gpu = use_gpu 48 | self.save_steps = save_steps 49 | self.checkpoint_dir = checkpoint_dir 50 | 51 | # the generator part 52 | assert generator is not None 53 | assert mu is not None 54 | self.optimizer_g = None 55 | self.generator = generator 56 | self.batch_size = self.model.batch_size 57 | self.generator.cuda() 58 | self.mu = mu 59 | 60 | def train_one_step(self, data): 61 | # training D 62 | self.optimizer.zero_grad() 63 | loss, p_score = self.model({ 64 | 'batch_h': self.to_var(data['batch_h'], self.use_gpu), 65 | 'batch_t': self.to_var(data['batch_t'], self.use_gpu), 66 | 'batch_r': self.to_var(data['batch_r'], self.use_gpu), 67 | 'batch_y': self.to_var(data['batch_y'], self.use_gpu), 68 | 'mode': data['mode'] 69 | }) 70 | # generate fake multimodal feature 71 | batch_h_gen = self.to_var(data['batch_h'][0: self.batch_size], self.use_gpu) 72 | batch_t_gen = self.to_var(data['batch_t'][0: self.batch_size], self.use_gpu) 73 | batch_r = self.to_var(data['batch_r'][0: self.batch_size], self.use_gpu) 74 | 
batch_hs = self.model.model.get_batch_ent_embs(batch_h_gen) 75 | batch_ts = self.model.model.get_batch_ent_embs(batch_t_gen) 76 | batch_gen_hv = self.generator(batch_hs) 77 | batch_gen_tv = self.generator(batch_ts) 78 | scores, _ = self.model.model.get_fake_score( 79 | batch_h=batch_h_gen, 80 | batch_r=batch_r, 81 | batch_t=batch_t_gen, 82 | mode=data['mode'], 83 | fake_hv=batch_gen_hv, 84 | fake_tv=batch_gen_tv 85 | ) 86 | # when training D: positive_score > fake_score 87 | for score in scores: 88 | # print(p_score.shape, score.shape) 89 | loss += self.model.loss(p_score, score) * self.mu 90 | loss.backward() 91 | self.optimizer.step() 92 | # training G 93 | self.optimizer_g.zero_grad() 94 | batch_hs = self.model.model.get_batch_ent_embs(batch_h_gen) 95 | batch_ts = self.model.model.get_batch_ent_embs(batch_t_gen) 96 | p_score = self.model({ 97 | 'batch_h': batch_h_gen, 98 | 'batch_t': batch_t_gen, 99 | 'batch_r': batch_r, 100 | 'batch_y': self.to_var(data['batch_y'], self.use_gpu), 101 | 'mode': data['mode'] 102 | }, fast_return=True) 103 | batch_gen_hv = self.generator(batch_hs) 104 | batch_gen_tv = self.generator(batch_ts) 105 | scores, _ = self.model.model.get_fake_score( 106 | batch_h=batch_h_gen, 107 | batch_r=batch_r, 108 | batch_t=batch_t_gen, 109 | mode=data['mode'], 110 | fake_hv=batch_gen_hv, 111 | fake_tv=batch_gen_tv 112 | ) 113 | loss_g = 0.0 114 | for score in scores: 115 | loss_g += self.model.loss(score, p_score) 116 | loss_g.backward() 117 | self.optimizer_g.step() 118 | return loss.item(), loss_g.item() 119 | 120 | def run(self): 121 | if self.use_gpu: 122 | self.model.cuda() 123 | 124 | if self.optimizer is not None: 125 | pass 126 | elif self.opt_method == "Adam" or self.opt_method == "adam": 127 | self.optimizer = optim.Adam( 128 | self.model.parameters(), 129 | lr=self.alpha, 130 | weight_decay=self.weight_decay, 131 | ) 132 | self.optimizer_g = optim.Adam( 133 | self.generator.parameters(), 134 | lr=self.alpha_g, 135 | 
weight_decay=self.weight_decay, 136 | ) 137 | print( 138 | "Learning Rate of D: {}\nLearning Rate of G: {}".format( 139 | self.alpha, self.alpha_g) 140 | ) 141 | else: 142 | raise NotImplementedError 143 | print("Finish initializing...") 144 | 145 | training_range = tqdm(range(self.train_times)) 146 | for epoch in training_range: 147 | res = 0.0 148 | res_g = 0.0 149 | for data in self.data_loader: 150 | loss, loss_g = self.train_one_step(data) 151 | res += loss 152 | res_g += loss_g 153 | training_range.set_description("Epoch %d | D loss: %f, G loss %f" % (epoch, res, res_g)) 154 | 155 | if self.save_steps and self.checkpoint_dir and (epoch + 1) % self.save_steps == 0: 156 | print("Epoch %d has finished, saving..." % (epoch)) 157 | self.model.save_checkpoint(os.path.join(self.checkpoint_dir + "-" + str(epoch) + ".ckpt")) 158 | 159 | def set_model(self, model): 160 | self.model = model 161 | 162 | def to_var(self, x, use_gpu): 163 | if use_gpu: 164 | return Variable(torch.from_numpy(x).cuda()) 165 | else: 166 | return Variable(torch.from_numpy(x)) 167 | 168 | def set_use_gpu(self, use_gpu): 169 | self.use_gpu = use_gpu 170 | 171 | def set_alpha(self, alpha): 172 | self.alpha = alpha 173 | 174 | def set_lr_decay(self, lr_decay): 175 | self.lr_decay = lr_decay 176 | 177 | def set_weight_decay(self, weight_decay): 178 | self.weight_decay = weight_decay 179 | 180 | def set_opt_method(self, opt_method): 181 | self.opt_method = opt_method 182 | 183 | def set_train_times(self, train_times): 184 | self.train_times = train_times 185 | 186 | def set_save_steps(self, save_steps, checkpoint_dir=None): 187 | self.save_steps = save_steps 188 | if not self.checkpoint_dir: 189 | self.set_checkpoint_dir(checkpoint_dir) 190 | 191 | def set_checkpoint_dir(self, checkpoint_dir): 192 | self.checkpoint_dir = checkpoint_dir 193 | 194 | 195 | -------------------------------------------------------------------------------- /mmkgc/config/MMKRLTrainer.py: 
-------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import copy 3 | import ctypes 4 | import datetime 5 | import json 6 | import os 7 | import sys 8 | import time 9 | from calendar import c 10 | 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | import torch.optim as optim 15 | from torch.autograd import Variable 16 | from tqdm import tqdm 17 | 18 | 19 | class MMKRLTrainer(object): 20 | 21 | def __init__(self, 22 | model=None, 23 | data_loader=None, 24 | train_times=1000, 25 | alpha=0.5, 26 | use_gpu=True, 27 | opt_method="sgd", 28 | save_steps=None, 29 | checkpoint_dir=None, 30 | generator=None, 31 | lrg=None, 32 | mu=None): 33 | 34 | self.work_threads = 8 35 | self.train_times = train_times 36 | 37 | self.opt_method = opt_method 38 | self.optimizer = None 39 | self.lr_decay = 0 40 | self.weight_decay = 0 41 | self.alpha = alpha 42 | # learning rate of the generator 43 | assert lrg is not None 44 | self.alpha_g = lrg 45 | 46 | self.model = model 47 | self.data_loader = data_loader 48 | self.use_gpu = use_gpu 49 | self.save_steps = save_steps 50 | self.checkpoint_dir = checkpoint_dir 51 | 52 | # the generator part 53 | assert generator is not None 54 | assert mu is not None 55 | self.optimizer_g = None 56 | self.generator = generator 57 | self.batch_size = self.model.batch_size 58 | self.generator.cuda() 59 | self.mu = mu 60 | 61 | def train_one_step(self, data): 62 | # training D 63 | self.optimizer.zero_grad() 64 | loss, p_score = self.model({ 65 | 'batch_h': self.to_var(data['batch_h'], self.use_gpu), 66 | 'batch_t': self.to_var(data['batch_t'], self.use_gpu), 67 | 'batch_r': self.to_var(data['batch_r'], self.use_gpu), 68 | 'batch_y': self.to_var(data['batch_y'], self.use_gpu), 69 | 'mode': data['mode'] 70 | }) 71 | # generate fake multimodal feature 72 | batch_h_gen = self.to_var( 73 | data['batch_h'][0: self.batch_size], self.use_gpu) 74 | batch_t_gen = self.to_var( 75 | 
data['batch_t'][0: self.batch_size], self.use_gpu) 76 | batch_r = self.to_var( 77 | data['batch_r'][0: self.batch_size], self.use_gpu) 78 | batch_hs = self.model.model.get_batch_ent_embs(batch_h_gen) 79 | batch_ts = self.model.model.get_batch_ent_embs(batch_t_gen) 80 | batch_gen_hv = self.generator(batch_hs) 81 | batch_gen_tv = self.generator(batch_ts) 82 | score = self.model.model.get_fake_score( 83 | batch_h=batch_h_gen, 84 | batch_r=batch_r, 85 | batch_t=batch_t_gen, 86 | mode=data['mode'], 87 | fake_hv=batch_gen_hv, 88 | fake_tv=batch_gen_tv 89 | ) 90 | # when training D: positive_score > fake_score 91 | loss += self.model.loss(p_score, score) * 0.07 92 | loss.backward() 93 | self.optimizer.step() 94 | # training G 95 | self.optimizer_g.zero_grad() 96 | batch_hs = self.model.model.get_batch_ent_embs(batch_h_gen) 97 | batch_ts = self.model.model.get_batch_ent_embs(batch_t_gen) 98 | p_score = self.model({ 99 | 'batch_h': batch_h_gen, 100 | 'batch_t': batch_t_gen, 101 | 'batch_r': batch_r, 102 | 'batch_y': self.to_var(data['batch_y'], self.use_gpu), 103 | 'mode': data['mode'] 104 | }, fast_return=True) 105 | batch_gen_hv = self.generator(batch_hs) 106 | batch_gen_tv = self.generator(batch_ts) 107 | score = self.model.model.get_fake_score( 108 | batch_h=batch_h_gen, 109 | batch_r=batch_r, 110 | batch_t=batch_t_gen, 111 | mode=data['mode'], 112 | fake_hv=batch_gen_hv, 113 | fake_tv=batch_gen_tv 114 | ) 115 | loss_g = self.model.loss(score, p_score) 116 | loss_g.backward() 117 | self.optimizer_g.step() 118 | return loss.item(), loss_g.item() 119 | 120 | def run(self): 121 | if self.use_gpu: 122 | self.model.cuda() 123 | 124 | if self.optimizer is not None: 125 | pass 126 | elif self.opt_method == "Adam" or self.opt_method == "adam": 127 | self.optimizer = optim.Adam( 128 | self.model.parameters(), 129 | lr=self.alpha, 130 | weight_decay=self.weight_decay, 131 | ) 132 | self.optimizer_g = optim.Adam( 133 | self.generator.parameters(), 134 | lr=self.alpha_g, 135 | 
weight_decay=self.weight_decay, 136 | ) 137 | print( 138 | "Learning Rate of D: {}\nLearning Rate of G: {}".format( 139 | self.alpha, self.alpha_g) 140 | ) 141 | else: 142 | raise NotImplementedError 143 | print("Finish initializing...") 144 | 145 | training_range = tqdm(range(self.train_times)) 146 | for epoch in training_range: 147 | res = 0.0 148 | res_g = 0.0 149 | for data in self.data_loader: 150 | loss, loss_g = self.train_one_step(data) 151 | res += loss 152 | res_g += loss_g 153 | training_range.set_description( 154 | "Epoch %d | D loss: %f, G loss %f" % (epoch, res, res_g)) 155 | 156 | if self.save_steps and self.checkpoint_dir and (epoch + 1) % self.save_steps == 0: 157 | print("Epoch %d has finished, saving..." % (epoch)) 158 | self.model.save_checkpoint(os.path.join( 159 | self.checkpoint_dir + "-" + str(epoch) + ".ckpt")) 160 | 161 | def set_model(self, model): 162 | self.model = model 163 | 164 | def to_var(self, x, use_gpu): 165 | if use_gpu: 166 | return Variable(torch.from_numpy(x).cuda()) 167 | else: 168 | return Variable(torch.from_numpy(x)) 169 | 170 | def set_use_gpu(self, use_gpu): 171 | self.use_gpu = use_gpu 172 | 173 | def set_alpha(self, alpha): 174 | self.alpha = alpha 175 | 176 | def set_lr_decay(self, lr_decay): 177 | self.lr_decay = lr_decay 178 | 179 | def set_weight_decay(self, weight_decay): 180 | self.weight_decay = weight_decay 181 | 182 | def set_opt_method(self, opt_method): 183 | self.opt_method = opt_method 184 | 185 | def set_train_times(self, train_times): 186 | self.train_times = train_times 187 | 188 | def set_save_steps(self, save_steps, checkpoint_dir=None): 189 | self.save_steps = save_steps 190 | if not self.checkpoint_dir: 191 | self.set_checkpoint_dir(checkpoint_dir) 192 | 193 | def set_checkpoint_dir(self, checkpoint_dir): 194 | self.checkpoint_dir = checkpoint_dir 195 | -------------------------------------------------------------------------------- /mmkgc/config/MultiAdvMixTrainer.py: 
-------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | from calendar import c 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Variable 6 | import torch.optim as optim 7 | import os 8 | import time 9 | import sys 10 | import datetime 11 | import ctypes 12 | import json 13 | import numpy as np 14 | import copy 15 | from tqdm import tqdm 16 | 17 | 18 | class MultiAdvMixTrainer(object): 19 | 20 | def __init__(self, 21 | model=None, 22 | data_loader=None, 23 | train_times=1000, 24 | alpha=0.5, 25 | use_gpu=True, 26 | opt_method="sgd", 27 | save_steps=None, 28 | checkpoint_dir=None, 29 | generator=None, 30 | lrg=None, 31 | mu=None, 32 | adv_num=1): 33 | 34 | self.work_threads = 8 35 | self.train_times = train_times 36 | 37 | self.opt_method = opt_method 38 | self.optimizer = None 39 | self.lr_decay = 0 40 | self.weight_decay = 0 41 | self.alpha = alpha 42 | # learning rate of the generator 43 | assert lrg is not None 44 | self.alpha_g = lrg 45 | 46 | self.model = model 47 | self.data_loader = data_loader 48 | self.use_gpu = use_gpu 49 | self.save_steps = save_steps 50 | self.checkpoint_dir = checkpoint_dir 51 | 52 | # the generator part 53 | assert generator is not None 54 | assert mu is not None 55 | self.optimizer_g = None 56 | self.generator = generator 57 | self.batch_size = self.model.batch_size 58 | self.generator.cuda() 59 | self.mu = mu 60 | # k is the number of adv group 61 | self.adv_num = adv_num 62 | 63 | def train_one_step(self, data): 64 | # training D 65 | self.optimizer.zero_grad() 66 | loss, p_score = self.model({ 67 | 'batch_h': self.to_var(data['batch_h'], self.use_gpu), 68 | 'batch_t': self.to_var(data['batch_t'], self.use_gpu), 69 | 'batch_r': self.to_var(data['batch_r'], self.use_gpu), 70 | 'batch_y': self.to_var(data['batch_y'], self.use_gpu), 71 | 'mode': data['mode'] 72 | }) 73 | # generate fake multimodal feature 74 | batch_h_gen = self.to_var(data['batch_h'][0: 
self.batch_size], self.use_gpu) 75 | batch_t_gen = self.to_var(data['batch_t'][0: self.batch_size], self.use_gpu) 76 | batch_r = self.to_var(data['batch_r'][0: self.batch_size], self.use_gpu) 77 | batch_hs = self.model.model.get_batch_ent_embs(batch_h_gen) 78 | batch_ts = self.model.model.get_batch_ent_embs(batch_t_gen) 79 | for _ in range(self.adv_num): 80 | batch_gen_hv = self.generator(batch_hs, 1) 81 | batch_gen_tv = self.generator(batch_ts, 1) 82 | batch_gen_ht = self.generator(batch_hs, 2) 83 | batch_gen_tt = self.generator(batch_ts, 2) 84 | scores, _ = self.model.model.get_fake_score( 85 | batch_h=batch_h_gen, 86 | batch_r=batch_r, 87 | batch_t=batch_t_gen, 88 | mode=data['mode'], 89 | fake_hv=batch_gen_hv, 90 | fake_tv=batch_gen_tv, 91 | fake_ht=batch_gen_ht, 92 | fake_tt=batch_gen_tt 93 | ) 94 | # when training D: positive_score > fake_score 95 | for score in scores: 96 | # print(p_score.shape, score.shape) 97 | loss += self.model.loss(p_score, score) * self.mu / self.adv_num 98 | loss.backward() 99 | self.optimizer.step() 100 | # training G 101 | self.optimizer_g.zero_grad() 102 | batch_hs = self.model.model.get_batch_ent_embs(batch_h_gen) 103 | batch_ts = self.model.model.get_batch_ent_embs(batch_t_gen) 104 | p_score = self.model({ 105 | 'batch_h': batch_h_gen, 106 | 'batch_t': batch_t_gen, 107 | 'batch_r': batch_r, 108 | 'batch_y': self.to_var(data['batch_y'], self.use_gpu), 109 | 'mode': data['mode'] 110 | }, fast_return=True) 111 | loss_g = 0.0 112 | for _ in range(self.adv_num): 113 | batch_gen_hv = self.generator(batch_hs, 1) 114 | batch_gen_tv = self.generator(batch_ts, 1) 115 | batch_gen_ht = self.generator(batch_hs, 2) 116 | batch_gen_tt = self.generator(batch_ts, 2) 117 | scores, _ = self.model.model.get_fake_score( 118 | batch_h=batch_h_gen, 119 | batch_r=batch_r, 120 | batch_t=batch_t_gen, 121 | mode=data['mode'], 122 | fake_hv=batch_gen_hv, 123 | fake_tv=batch_gen_tv, 124 | fake_ht=batch_gen_ht, 125 | fake_tt=batch_gen_tt 126 | ) 127 | for 
score in scores: 128 | loss_g += self.model.loss(score, p_score) / self.adv_num 129 | loss_g.backward() 130 | self.optimizer_g.step() 131 | return loss.item(), loss_g.item() 132 | 133 | def run(self): 134 | if self.use_gpu: 135 | self.model.cuda() 136 | 137 | if self.optimizer is not None: 138 | pass 139 | elif self.opt_method == "Adam" or self.opt_method == "adam": 140 | self.optimizer = optim.Adam( 141 | self.model.parameters(), 142 | lr=self.alpha, 143 | weight_decay=self.weight_decay, 144 | ) 145 | self.optimizer_g = optim.Adam( 146 | self.generator.parameters(), 147 | lr=self.alpha_g, 148 | weight_decay=self.weight_decay, 149 | ) 150 | print( 151 | "Learning Rate of D: {}\nLearning Rate of G: {}".format( 152 | self.alpha, self.alpha_g) 153 | ) 154 | else: 155 | raise NotImplementedError 156 | print("Finish initializing...") 157 | 158 | training_range = tqdm(range(self.train_times)) 159 | for epoch in training_range: 160 | res = 0.0 161 | res_g = 0.0 162 | for data in self.data_loader: 163 | loss, loss_g = self.train_one_step(data) 164 | res += loss 165 | res_g += loss_g 166 | training_range.set_description("Epoch %d | D loss: %f, G loss %f" % (epoch, res, res_g)) 167 | 168 | if self.save_steps and self.checkpoint_dir and (epoch + 1) % self.save_steps == 0: 169 | print("Epoch %d has finished, saving..." 
% (epoch)) 170 | self.model.save_checkpoint(os.path.join(self.checkpoint_dir + "-" + str(epoch) + ".ckpt")) 171 | 172 | def set_model(self, model): 173 | self.model = model 174 | 175 | def to_var(self, x, use_gpu): 176 | if use_gpu: 177 | return Variable(torch.from_numpy(x).cuda()) 178 | else: 179 | return Variable(torch.from_numpy(x)) 180 | 181 | def set_use_gpu(self, use_gpu): 182 | self.use_gpu = use_gpu 183 | 184 | def set_alpha(self, alpha): 185 | self.alpha = alpha 186 | 187 | def set_lr_decay(self, lr_decay): 188 | self.lr_decay = lr_decay 189 | 190 | def set_weight_decay(self, weight_decay): 191 | self.weight_decay = weight_decay 192 | 193 | def set_opt_method(self, opt_method): 194 | self.opt_method = opt_method 195 | 196 | def set_train_times(self, train_times): 197 | self.train_times = train_times 198 | 199 | def set_save_steps(self, save_steps, checkpoint_dir=None): 200 | self.save_steps = save_steps 201 | if not self.checkpoint_dir: 202 | self.set_checkpoint_dir(checkpoint_dir) 203 | 204 | def set_checkpoint_dir(self, checkpoint_dir): 205 | self.checkpoint_dir = checkpoint_dir 206 | 207 | 208 | -------------------------------------------------------------------------------- /mmkgc/config/RSMEAdvTrainer.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | from calendar import c 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Variable 6 | import torch.optim as optim 7 | import os 8 | import time 9 | import sys 10 | import datetime 11 | import ctypes 12 | import json 13 | import numpy as np 14 | import copy 15 | from tqdm import tqdm 16 | 17 | 18 | class RSMEAdvTrainer(object): 19 | 20 | def __init__(self, 21 | model=None, 22 | data_loader=None, 23 | train_times=1000, 24 | alpha=0.5, 25 | use_gpu=True, 26 | opt_method="sgd", 27 | save_steps=None, 28 | checkpoint_dir=None, 29 | generator=None, 30 | lrg=None): 31 | 32 | self.work_threads = 8 33 | self.train_times = train_times 
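The `train_one_step` above interleaves a discriminator update (the real triple's score is pushed above the generated-feature score via the margin loss) with a generator update in the opposite direction, using two separate optimizers. The alternation can be sketched in isolation with hypothetical `disc`/`gen` stand-ins for the KGE scorer and the modality generator (layer sizes, margin value, and learning rates here are illustrative, not the repo's settings):

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Hypothetical stand-ins: disc scores an embedding, gen synthesizes a fake feature.
disc = nn.Linear(8, 1)
gen = nn.Linear(8, 8)
opt_d = optim.Adam(disc.parameters(), lr=1e-3)
opt_g = optim.Adam(gen.parameters(), lr=1e-4)
margin = nn.MarginRankingLoss(margin=6.0)

real = torch.randn(4, 8)  # a batch of entity embeddings

def train_one_step(real):
    # --- D step: positive score should exceed the fake score ---
    opt_d.zero_grad()
    p_score = disc(real)
    f_score = disc(gen(real).detach())  # detach: G is frozen during the D step
    ones = torch.ones_like(p_score)
    loss_d = margin(p_score, f_score, ones)
    loss_d.backward()
    opt_d.step()
    # --- G step: the fake score should now beat the positive score ---
    opt_g.zero_grad()
    f_score = disc(gen(real))
    loss_g = margin(f_score, p_score.detach(), ones)
    loss_g.backward()
    opt_g.step()
    return loss_d.item(), loss_g.item()

ld, lg = train_one_step(real)
```

As in the trainers, the key detail is which tensors are detached: the generator's output is detached in the D step, and the positive score is detached in the G step, so each optimizer only updates its own parameters.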
34 | 35 | self.opt_method = opt_method 36 | self.optimizer = None 37 | self.lr_decay = 0 38 | self.weight_decay = 0 39 | self.alpha = alpha 40 | # learning rate of the generator 41 | assert lrg is not None 42 | self.alpha_g = lrg 43 | 44 | self.model = model 45 | self.data_loader = data_loader 46 | self.use_gpu = use_gpu 47 | self.save_steps = save_steps 48 | self.checkpoint_dir = checkpoint_dir 49 | 50 | # the generator part 51 | assert generator is not None 52 | self.optimizer_g = None 53 | self.generator = generator 54 | self.batch_size = self.model.batch_size 55 | self.generator.cuda() 56 | 57 | def train_one_step(self, data): 58 | # training D 59 | self.optimizer.zero_grad() 60 | loss, p_score = self.model({ 61 | 'batch_h': self.to_var(data['batch_h'], self.use_gpu), 62 | 'batch_t': self.to_var(data['batch_t'], self.use_gpu), 63 | 'batch_r': self.to_var(data['batch_r'], self.use_gpu), 64 | 'batch_y': self.to_var(data['batch_y'], self.use_gpu), 65 | 'mode': data['mode'] 66 | }) 67 | # generate fake multimodal feature 68 | batch_h_gen = self.to_var(data['batch_h'][0: self.batch_size], self.use_gpu) 69 | batch_t_gen = self.to_var(data['batch_t'][0: self.batch_size], self.use_gpu) 70 | batch_r = self.to_var(data['batch_r'][0: self.batch_size], self.use_gpu) 71 | # batch_hs = self.model.model.get_batch_ent_embs(batch_h_gen) 72 | batch_ts = self.model.model.get_batch_ent_embs(batch_t_gen) 73 | batch_gen_hv = None 74 | batch_gen_tv = self.generator(batch_ts) 75 | scores = self.model.model.get_fake_score( 76 | batch_h=batch_h_gen, 77 | batch_r=batch_r, 78 | batch_t=batch_t_gen, 79 | mode=data['mode'], 80 | fake_hv=batch_gen_hv, 81 | fake_tv=batch_gen_tv 82 | ) 83 | # when training D: positive_score > fake_score 84 | for score in scores: 85 | loss += self.model.loss(p_score, score) 86 | loss.backward() 87 | self.optimizer.step() 88 | # training G 89 | self.optimizer_g.zero_grad() 90 | # batch_hs = self.model.model.get_batch_ent_embs(batch_h_gen) 91 | batch_ts = 
self.model.model.get_batch_ent_embs(batch_t_gen) 92 | p_score = self.model({ 93 | 'batch_h': batch_h_gen, 94 | 'batch_t': batch_t_gen, 95 | 'batch_r': batch_r, 96 | 'batch_y': self.to_var(data['batch_y'], self.use_gpu), 97 | 'mode': data['mode'] 98 | }, fast_return=True) 99 | batch_gen_hv = None 100 | batch_gen_tv = self.generator(batch_ts) 101 | scores = self.model.model.get_fake_score( 102 | batch_h=batch_h_gen, 103 | batch_r=batch_r, 104 | batch_t=batch_t_gen, 105 | mode=data['mode'], 106 | fake_hv=batch_gen_hv, 107 | fake_tv=batch_gen_tv 108 | ) 109 | loss_g = 0.0 110 | for score in scores: 111 | loss_g += self.model.loss(score, p_score) 112 | loss_g.backward() 113 | self.optimizer_g.step() 114 | return loss.item(), loss_g.item() 115 | 116 | def run(self): 117 | if self.use_gpu: 118 | self.model.cuda() 119 | 120 | if self.optimizer is not None: 121 | pass 122 | elif self.opt_method == "Adam" or self.opt_method == "adam": 123 | self.optimizer = optim.Adam( 124 | self.model.parameters(), 125 | lr=self.alpha, 126 | weight_decay=self.weight_decay, 127 | ) 128 | self.optimizer_g = optim.Adam( 129 | self.generator.parameters(), 130 | lr=self.alpha_g, 131 | weight_decay=self.weight_decay, 132 | ) 133 | print( 134 | "Learning Rate of D: {}\nLearning Rate of G: {}".format( 135 | self.alpha, self.alpha_g) 136 | ) 137 | else: 138 | raise NotImplementedError 139 | print("Finish initializing...") 140 | 141 | training_range = tqdm(range(self.train_times)) 142 | for epoch in training_range: 143 | res = 0.0 144 | res_g = 0.0 145 | for data in self.data_loader: 146 | loss, loss_g = self.train_one_step(data) 147 | res += loss 148 | res_g += loss_g 149 | training_range.set_description("Epoch %d | D loss: %f, G loss %f" % (epoch, res, res_g)) 150 | 151 | if self.save_steps and self.checkpoint_dir and (epoch + 1) % self.save_steps == 0: 152 | print("Epoch %d has finished, saving..." 
% (epoch)) 153 | self.model.save_checkpoint(os.path.join(self.checkpoint_dir + "-" + str(epoch) + ".ckpt")) 154 | 155 | def set_model(self, model): 156 | self.model = model 157 | 158 | def to_var(self, x, use_gpu): 159 | if use_gpu: 160 | return Variable(torch.from_numpy(x).cuda()) 161 | else: 162 | return Variable(torch.from_numpy(x)) 163 | 164 | def set_use_gpu(self, use_gpu): 165 | self.use_gpu = use_gpu 166 | 167 | def set_alpha(self, alpha): 168 | self.alpha = alpha 169 | 170 | def set_lr_decay(self, lr_decay): 171 | self.lr_decay = lr_decay 172 | 173 | def set_weight_decay(self, weight_decay): 174 | self.weight_decay = weight_decay 175 | 176 | def set_opt_method(self, opt_method): 177 | self.opt_method = opt_method 178 | 179 | def set_train_times(self, train_times): 180 | self.train_times = train_times 181 | 182 | def set_save_steps(self, save_steps, checkpoint_dir=None): 183 | self.save_steps = save_steps 184 | if not self.checkpoint_dir: 185 | self.set_checkpoint_dir(checkpoint_dir) 186 | 187 | def set_checkpoint_dir(self, checkpoint_dir): 188 | self.checkpoint_dir = checkpoint_dir 189 | 190 | 191 | -------------------------------------------------------------------------------- /mmkgc/config/Tester.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import torch.optim as optim 6 | import torch.nn.functional as F 7 | import os 8 | import time 9 | import sys 10 | import datetime 11 | import ctypes 12 | import json 13 | import numpy as np 14 | from sklearn.metrics import roc_auc_score 15 | import copy 16 | from tqdm import tqdm 17 | 18 | class Tester(object): 19 | 20 | def __init__(self, model = None, data_loader = None, use_gpu = True, other_model=None, norm=False, mu=0.5): 21 | base_file = os.path.abspath(os.path.join(os.path.dirname(__file__), "../release/Base.so")) 22 | self.lib = ctypes.cdll.LoadLibrary(base_file) 23 
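`Tester` drives the compiled `Base.so` through `ctypes`: each C function gets explicit `argtypes`/`restype` declarations before being called, and numpy buffers are handed to C as raw data pointers via `__array_interface__["data"][0]`. The same wiring can be sketched against libc's `memset` (assuming a POSIX host, where `ctypes.CDLL(None)` exposes the C runtime's symbols):

```python
import ctypes
import numpy as np

# Load the C runtime; on POSIX, CDLL(None) resolves symbols from the running process.
libc = ctypes.CDLL(None)

# Declare the prototype before calling, exactly as Tester does for Base.so:
# void *memset(void *s, int c, size_t n);
libc.memset.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]
libc.memset.restype = ctypes.c_void_p

buf = np.zeros(8, dtype=np.uint8)
# Pass the array's raw data pointer, as Tester passes its score buffers.
addr = buf.__array_interface__["data"][0]
libc.memset(addr, 0xFF, buf.nbytes)
# buf is now filled with 255 in place: C wrote directly into numpy's memory.
```

This zero-copy handoff is why the loaders preallocate fixed numpy arrays and reuse their addresses on every batch.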
| self.lib.testHead.argtypes = [ctypes.c_void_p, ctypes.c_int64, ctypes.c_int64] 24 | self.lib.testTail.argtypes = [ctypes.c_void_p, ctypes.c_int64, ctypes.c_int64] 25 | self.lib.test_link_prediction.argtypes = [ctypes.c_int64] 26 | 27 | self.lib.getTestLinkMRR.argtypes = [ctypes.c_int64] 28 | self.lib.getTestLinkMR.argtypes = [ctypes.c_int64] 29 | self.lib.getTestLinkHit10.argtypes = [ctypes.c_int64] 30 | self.lib.getTestLinkHit3.argtypes = [ctypes.c_int64] 31 | self.lib.getTestLinkHit1.argtypes = [ctypes.c_int64] 32 | 33 | self.lib.getTestLinkMRR.restype = ctypes.c_float 34 | self.lib.getTestLinkMR.restype = ctypes.c_float 35 | self.lib.getTestLinkHit10.restype = ctypes.c_float 36 | self.lib.getTestLinkHit3.restype = ctypes.c_float 37 | self.lib.getTestLinkHit1.restype = ctypes.c_float 38 | 39 | self.model = model 40 | self.data_loader = data_loader 41 | self.use_gpu = use_gpu 42 | self.other_model = other_model 43 | self.norm = norm 44 | self.mu = mu 45 | 46 | if self.use_gpu: 47 | self.model.cuda() 48 | 49 | def set_model(self, model): 50 | self.model = model 51 | 52 | def set_data_loader(self, data_loader): 53 | self.data_loader = data_loader 54 | 55 | def set_use_gpu(self, use_gpu): 56 | self.use_gpu = use_gpu 57 | if self.use_gpu and self.model != None: 58 | self.model.cuda() 59 | 60 | def to_var(self, x, use_gpu): 61 | if use_gpu: 62 | return Variable(torch.from_numpy(x).cuda()) 63 | else: 64 | return Variable(torch.from_numpy(x)) 65 | 66 | def test_one_step(self, data): 67 | return self.model.predict({ 68 | 'batch_h': self.to_var(data['batch_h'], self.use_gpu), 69 | 'batch_t': self.to_var(data['batch_t'], self.use_gpu), 70 | 'batch_r': self.to_var(data['batch_r'], self.use_gpu), 71 | 'mode': data['mode'] 72 | }) 73 | 74 | 75 | 76 | def run_link_prediction(self, type_constrain=False): 77 | self.lib.initTest() 78 | self.data_loader.set_sampling_mode('link') 79 | if type_constrain: 80 | type_constrain = 1 81 | else: 82 | type_constrain = 0 83 | training_range 
= self.data_loader 84 | for index, [data_head, data_tail] in enumerate(training_range): 85 | score = self.test_one_step(data_head) 86 | self.lib.testHead(score.__array_interface__["data"][0], index, type_constrain) 87 | score = self.test_one_step(data_tail) 88 | self.lib.testTail(score.__array_interface__["data"][0], index, type_constrain) 89 | self.lib.test_link_prediction(type_constrain) 90 | 91 | mrr = self.lib.getTestLinkMRR(type_constrain) 92 | mr = self.lib.getTestLinkMR(type_constrain) 93 | hit10 = self.lib.getTestLinkHit10(type_constrain) 94 | hit3 = self.lib.getTestLinkHit3(type_constrain) 95 | hit1 = self.lib.getTestLinkHit1(type_constrain) 96 | return mrr, mr, hit10, hit3, hit1 97 | 98 | def get_best_threshlod(self, score, ans): 99 | res = np.concatenate([ans.reshape(-1,1), score.reshape(-1,1)], axis = -1) 100 | order = np.argsort(score) 101 | res = res[order] 102 | 103 | total_all = (float)(len(score)) 104 | total_current = 0.0 105 | total_true = np.sum(ans) 106 | total_false = total_all - total_true 107 | 108 | res_mx = 0.0 109 | threshlod = None 110 | for index, [ans, score] in enumerate(res): 111 | if ans == 1: 112 | total_current += 1.0 113 | res_current = (2 * total_current + total_false - index - 1) / total_all 114 | if res_current > res_mx: 115 | res_mx = res_current 116 | threshlod = score 117 | return threshlod, res_mx 118 | 119 | def run_triple_classification(self, threshlod = None): 120 | self.lib.initTest() 121 | self.data_loader.set_sampling_mode('classification') 122 | score = [] 123 | ans = [] 124 | training_range = tqdm(self.data_loader) 125 | for index, [pos_ins, neg_ins] in enumerate(training_range): 126 | res_pos = self.test_one_step(pos_ins) 127 | ans = ans + [1 for i in range(len(res_pos))] 128 | score.append(res_pos) 129 | 130 | res_neg = self.test_one_step(neg_ins) 131 | ans = ans + [0 for i in range(len(res_pos))] 132 | score.append(res_neg) 133 | 134 | score = np.concatenate(score, axis = -1) 135 | ans = np.array(ans) 136 | 137 
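`get_best_threshlod` sorts the candidate scores ascending and sweeps for the cut that maximizes classification accuracy, treating lower scores as positive predictions (distance-style KGE scoring: lower means more plausible). The same sweep in standalone numpy, run on a toy separable batch:

```python
import numpy as np

def best_threshold(score, ans):
    """Return (threshold, accuracy) maximizing accuracy when
    'score <= threshold' is read as a positive prediction."""
    order = np.argsort(score)
    score, ans = score[order], ans[order]
    total = len(score)
    n_false = total - ans.sum()
    best_acc, best_thr = 0.0, None
    tp = 0
    for i, (a, s) in enumerate(zip(ans, score)):
        if a == 1:
            tp += 1
        # predicting the first i+1 items positive:
        # TP = tp, TN = n_false - ((i + 1) - tp)
        acc = (2 * tp + n_false - i - 1) / total
        if acc > best_acc:
            best_acc, best_thr = acc, s
    return best_thr, best_acc

# Perfectly separable: positives score low, negatives high.
score = np.array([0.1, 0.2, 0.9, 1.0])
ans = np.array([1, 1, 0, 0])
thr, acc = best_threshold(score, ans)  # thr = 0.2, acc = 1.0
```

The accuracy expression matches the source's `(2 * total_current + total_false - index - 1) / total_all`: true positives counted so far plus the true negatives remaining above the cut.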
| if threshlod == None: 138 | threshlod, _ = self.get_best_threshlod(score, ans) 139 | 140 | res = np.concatenate([ans.reshape(-1,1), score.reshape(-1,1)], axis = -1) 141 | order = np.argsort(score) 142 | res = res[order] 143 | 144 | total_all = (float)(len(score)) 145 | total_current = 0.0 146 | total_true = np.sum(ans) 147 | total_false = total_all - total_true 148 | 149 | for index, [ans, score] in enumerate(res): 150 | if score > threshlod: 151 | acc = (2 * total_current + total_false - index) / total_all 152 | break 153 | elif ans == 1: 154 | total_current += 1.0 155 | 156 | return acc, threshlod 157 | -------------------------------------------------------------------------------- /mmkgc/config/Trainer.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import torch.optim as optim 6 | import os 7 | import time 8 | import sys 9 | import datetime 10 | import ctypes 11 | import json 12 | import numpy as np 13 | import copy 14 | from tqdm import tqdm 15 | 16 | 17 | class Trainer(object): 18 | 19 | def __init__(self, 20 | model=None, 21 | data_loader=None, 22 | train_times=1000, 23 | alpha=0.5, 24 | use_gpu=True, 25 | opt_method="sgd", 26 | save_steps=None, 27 | checkpoint_dir=None, 28 | train_mode='adp', 29 | beta=0.5): 30 | 31 | self.work_threads = 8 32 | self.train_times = train_times 33 | 34 | self.opt_method = opt_method 35 | self.optimizer = None 36 | self.lr_decay = 0 37 | self.weight_decay = 0 38 | self.alpha = alpha 39 | 40 | self.model = model 41 | self.data_loader = data_loader 42 | self.use_gpu = use_gpu 43 | self.save_steps = save_steps 44 | self.checkpoint_dir = checkpoint_dir 45 | 46 | self.train_mode = train_mode 47 | self.beta = beta 48 | 49 | def train_one_step(self, data): 50 | self.optimizer.zero_grad() 51 | loss, _ = self.model({ 52 | 'batch_h': self.to_var(data['batch_h'], self.use_gpu), 53 | 'batch_t': 
self.to_var(data['batch_t'], self.use_gpu), 54 | 'batch_r': self.to_var(data['batch_r'], self.use_gpu), 55 | 'batch_y': self.to_var(data['batch_y'], self.use_gpu), 56 | 'mode': data['mode'] 57 | }) 58 | loss.backward() 59 | self.optimizer.step() 60 | return loss.item() 61 | 62 | def run(self): 63 | if self.use_gpu: 64 | self.model.cuda() 65 | 66 | if self.optimizer is not None: 67 | pass 68 | elif self.opt_method == "Adagrad" or self.opt_method == "adagrad": 69 | self.optimizer = optim.Adagrad( 70 | self.model.parameters(), 71 | lr=self.alpha, 72 | lr_decay=self.lr_decay, 73 | weight_decay=self.weight_decay, 74 | ) 75 | elif self.opt_method == "Adadelta" or self.opt_method == "adadelta": 76 | self.optimizer = optim.Adadelta( 77 | self.model.parameters(), 78 | lr=self.alpha, 79 | weight_decay=self.weight_decay, 80 | ) 81 | elif self.opt_method == "Adam" or self.opt_method == "adam": 82 | self.optimizer = optim.Adam( 83 | self.model.parameters(), 84 | lr=self.alpha, 85 | weight_decay=self.weight_decay, 86 | ) 87 | else: 88 | self.optimizer = optim.SGD( 89 | self.model.parameters(), 90 | lr=self.alpha, 91 | weight_decay=self.weight_decay, 92 | ) 93 | print("Finish initializing...") 94 | 95 | training_range = tqdm(range(self.train_times)) 96 | for epoch in training_range: 97 | res = 0.0 98 | for data in self.data_loader: 99 | loss = self.train_one_step(data) 100 | res += loss 101 | training_range.set_description("Epoch %d | loss: %f" % (epoch, res)) 102 | 103 | if self.save_steps and self.checkpoint_dir and (epoch + 1) % self.save_steps == 0: 104 | print("Epoch %d has finished, saving..." 
% (epoch)) 105 | self.model.save_checkpoint(os.path.join(self.checkpoint_dir + "-" + str(epoch) + ".ckpt")) 106 | 107 | def set_model(self, model): 108 | self.model = model 109 | 110 | def to_var(self, x, use_gpu): 111 | if use_gpu: 112 | return Variable(torch.from_numpy(x).cuda()) 113 | else: 114 | return Variable(torch.from_numpy(x)) 115 | 116 | def set_use_gpu(self, use_gpu): 117 | self.use_gpu = use_gpu 118 | 119 | def set_alpha(self, alpha): 120 | self.alpha = alpha 121 | 122 | def set_lr_decay(self, lr_decay): 123 | self.lr_decay = lr_decay 124 | 125 | def set_weight_decay(self, weight_decay): 126 | self.weight_decay = weight_decay 127 | 128 | def set_opt_method(self, opt_method): 129 | self.opt_method = opt_method 130 | 131 | def set_train_times(self, train_times): 132 | self.train_times = train_times 133 | 134 | def set_save_steps(self, save_steps, checkpoint_dir=None): 135 | self.save_steps = save_steps 136 | if not self.checkpoint_dir: 137 | self.set_checkpoint_dir(checkpoint_dir) 138 | 139 | def set_checkpoint_dir(self, checkpoint_dir): 140 | self.checkpoint_dir = checkpoint_dir 141 | -------------------------------------------------------------------------------- /mmkgc/config/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .Trainer import Trainer 6 | from .Tester import Tester 7 | from .AdvTrainer import AdvTrainer 8 | from .AdvConTrainer import AdvConTrainer 9 | from .AdvMixTrainer import AdvMixTrainer 10 | from .AdvConMixTrainer import AdvConMixTrainer 11 | from .MMKRLTrainer import MMKRLTrainer 12 | from .MultiAdvMixTrainer import MultiAdvMixTrainer 13 | from .RSMEAdvTrainer import RSMEAdvTrainer 14 | __all__ = [ 15 | 'Trainer', 16 | 'Tester', 17 | 'AdvTrainer', 18 | 'AdvConTrainer', 19 | 'RSMEAdvTrainer', 20 | 'AdvMixTrainer', 21 | 'AdvConMixTrainer', 22 | 'MMKRLTrainer', 23 | 'MultiAdvMixTrainer' 24 | ] 25 | 
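`Trainer.run` selects the optimizer by string-matching both capitalizations of each name and falls back to SGD for anything unrecognized. A case-normalized lookup table expresses the same dispatch more compactly (a sketch, not code from the repo):

```python
import torch.nn as nn
import torch.optim as optim

# Name -> optimizer class, mirroring the branches in Trainer.run().
_OPTIMIZERS = {
    "adagrad": optim.Adagrad,
    "adadelta": optim.Adadelta,
    "adam": optim.Adam,
    "sgd": optim.SGD,
}

def build_optimizer(model, opt_method, lr, weight_decay=0.0):
    # Unknown names fall back to SGD, as Trainer.run() does.
    cls = _OPTIMIZERS.get(opt_method.lower(), optim.SGD)
    return cls(model.parameters(), lr=lr, weight_decay=weight_decay)

model = nn.Linear(4, 2)
opt = build_optimizer(model, "Adam", lr=1e-3)
```

One behavioral difference to note: `Trainer.run()` additionally passes `lr_decay` to Adagrad only; a table-based factory would need a per-optimizer kwargs entry to preserve that.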
-------------------------------------------------------------------------------- /mmkgc/data/TestDataLoader.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import os 3 | import ctypes 4 | import numpy as np 5 | 6 | class TestDataSampler(object): 7 | 8 | def __init__(self, data_total, data_sampler): 9 | self.data_total = data_total 10 | self.data_sampler = data_sampler 11 | self.total = 0 12 | 13 | def __iter__(self): 14 | return self 15 | 16 | def __next__(self): 17 | self.total += 1 18 | if self.total > self.data_total: 19 | raise StopIteration() 20 | return self.data_sampler() 21 | 22 | def __len__(self): 23 | return self.data_total 24 | 25 | class TestDataLoader(object): 26 | 27 | def __init__(self, in_path = "./", sampling_mode = 'link', type_constrain = True): 28 | base_file = os.path.abspath(os.path.join(os.path.dirname(__file__), "../release/Base.so")) 29 | self.lib = ctypes.cdll.LoadLibrary(base_file) 30 | """for link prediction""" 31 | self.lib.getHeadBatch.argtypes = [ 32 | ctypes.c_void_p, 33 | ctypes.c_void_p, 34 | ctypes.c_void_p, 35 | ] 36 | self.lib.getTailBatch.argtypes = [ 37 | ctypes.c_void_p, 38 | ctypes.c_void_p, 39 | ctypes.c_void_p, 40 | ] 41 | """for triple classification""" 42 | self.lib.getTestBatch.argtypes = [ 43 | ctypes.c_void_p, 44 | ctypes.c_void_p, 45 | ctypes.c_void_p, 46 | ctypes.c_void_p, 47 | ctypes.c_void_p, 48 | ctypes.c_void_p, 49 | ] 50 | """set essential parameters""" 51 | self.in_path = in_path 52 | self.sampling_mode = sampling_mode 53 | self.type_constrain = type_constrain 54 | self.read() 55 | 56 | def read(self): 57 | self.lib.setInPath(ctypes.create_string_buffer(self.in_path.encode(), len(self.in_path) * 2)) 58 | self.lib.randReset() 59 | self.lib.importTestFiles() 60 | 61 | if self.type_constrain: 62 | self.lib.importTypeFiles() 63 | 64 | self.relTotal = self.lib.getRelationTotal() 65 | self.entTotal = self.lib.getEntityTotal() 66 | self.testTotal = 
self.lib.getTestTotal() 67 | 68 | self.test_h = np.zeros(self.entTotal, dtype=np.int64) 69 | self.test_t = np.zeros(self.entTotal, dtype=np.int64) 70 | self.test_r = np.zeros(self.entTotal, dtype=np.int64) 71 | self.test_h_addr = self.test_h.__array_interface__["data"][0] 72 | self.test_t_addr = self.test_t.__array_interface__["data"][0] 73 | self.test_r_addr = self.test_r.__array_interface__["data"][0] 74 | 75 | self.test_pos_h = np.zeros(self.testTotal, dtype=np.int64) 76 | self.test_pos_t = np.zeros(self.testTotal, dtype=np.int64) 77 | self.test_pos_r = np.zeros(self.testTotal, dtype=np.int64) 78 | self.test_pos_h_addr = self.test_pos_h.__array_interface__["data"][0] 79 | self.test_pos_t_addr = self.test_pos_t.__array_interface__["data"][0] 80 | self.test_pos_r_addr = self.test_pos_r.__array_interface__["data"][0] 81 | self.test_neg_h = np.zeros(self.testTotal, dtype=np.int64) 82 | self.test_neg_t = np.zeros(self.testTotal, dtype=np.int64) 83 | self.test_neg_r = np.zeros(self.testTotal, dtype=np.int64) 84 | self.test_neg_h_addr = self.test_neg_h.__array_interface__["data"][0] 85 | self.test_neg_t_addr = self.test_neg_t.__array_interface__["data"][0] 86 | self.test_neg_r_addr = self.test_neg_r.__array_interface__["data"][0] 87 | 88 | def sampling_lp(self): 89 | res = [] 90 | self.lib.getHeadBatch(self.test_h_addr, self.test_t_addr, self.test_r_addr) 91 | res.append({ 92 | "batch_h": self.test_h.copy(), 93 | "batch_t": self.test_t[:1].copy(), 94 | "batch_r": self.test_r[:1].copy(), 95 | "mode": "head_batch" 96 | }) 97 | self.lib.getTailBatch(self.test_h_addr, self.test_t_addr, self.test_r_addr) 98 | res.append({ 99 | "batch_h": self.test_h[:1], 100 | "batch_t": self.test_t, 101 | "batch_r": self.test_r[:1], 102 | "mode": "tail_batch" 103 | }) 104 | return res 105 | 106 | def sampling_tc(self): 107 | self.lib.getTestBatch( 108 | self.test_pos_h_addr, 109 | self.test_pos_t_addr, 110 | self.test_pos_r_addr, 111 | self.test_neg_h_addr, 112 | self.test_neg_t_addr, 113 
| self.test_neg_r_addr, 114 | ) 115 | return [ 116 | { 117 | 'batch_h': self.test_pos_h, 118 | 'batch_t': self.test_pos_t, 119 | 'batch_r': self.test_pos_r , 120 | "mode": "normal" 121 | }, 122 | { 123 | 'batch_h': self.test_neg_h, 124 | 'batch_t': self.test_neg_t, 125 | 'batch_r': self.test_neg_r, 126 | "mode": "normal" 127 | } 128 | ] 129 | 130 | """interfaces to get essential parameters""" 131 | 132 | def get_ent_tot(self): 133 | return self.entTotal 134 | 135 | def get_rel_tot(self): 136 | return self.relTotal 137 | 138 | def get_triple_tot(self): 139 | return self.testTotal 140 | 141 | def set_sampling_mode(self, sampling_mode): 142 | self.sampling_mode = sampling_mode 143 | 144 | def __len__(self): 145 | return self.testTotal 146 | 147 | def __iter__(self): 148 | if self.sampling_mode == "link": 149 | self.lib.initTest() 150 | return TestDataSampler(self.testTotal, self.sampling_lp) 151 | else: 152 | self.lib.initTest() 153 | return TestDataSampler(1, self.sampling_tc) -------------------------------------------------------------------------------- /mmkgc/data/TrainDataLoader.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import os 3 | import ctypes 4 | import numpy as np 5 | 6 | 7 | class TrainDataSampler(object): 8 | 9 | def __init__(self, nbatches, datasampler): 10 | self.nbatches = nbatches 11 | self.datasampler = datasampler 12 | self.batch = 0 13 | 14 | def __iter__(self): 15 | return self 16 | 17 | def __next__(self): 18 | self.batch += 1 19 | if self.batch > self.nbatches: 20 | raise StopIteration() 21 | return self.datasampler() 22 | 23 | def __len__(self): 24 | return self.nbatches 25 | 26 | 27 | class TrainDataLoader(object): 28 | 29 | def __init__(self, 30 | in_path="./", 31 | tri_file=None, 32 | ent_file=None, 33 | rel_file=None, 34 | batch_size=None, 35 | nbatches=None, 36 | threads=8, 37 | sampling_mode="normal", 38 | bern_flag=False, 39 | filter_flag=True, 40 | neg_ent=1, 41 | 
neg_rel=0): 42 | 43 | base_file = os.path.abspath(os.path.join(os.path.dirname(__file__), "../release/Base.so")) 44 | self.lib = ctypes.cdll.LoadLibrary(base_file) 45 | """argtypes""" 46 | self.lib.sampling.argtypes = [ 47 | ctypes.c_void_p, 48 | ctypes.c_void_p, 49 | ctypes.c_void_p, 50 | ctypes.c_void_p, 51 | ctypes.c_int64, 52 | ctypes.c_int64, 53 | ctypes.c_int64, 54 | ctypes.c_int64, 55 | ctypes.c_int64, 56 | ctypes.c_int64, 57 | ctypes.c_int64 58 | ] 59 | self.in_path = in_path 60 | self.tri_file = tri_file 61 | self.ent_file = ent_file 62 | self.rel_file = rel_file 63 | if in_path != None: 64 | self.tri_file = in_path + "train2id.txt" 65 | self.ent_file = in_path + "entity2id.txt" 66 | self.rel_file = in_path + "relation2id.txt" 67 | """set essential parameters""" 68 | self.work_threads = threads 69 | self.nbatches = nbatches 70 | self.batch_size = batch_size 71 | self.bern = bern_flag 72 | self.filter = filter_flag 73 | self.negative_ent = neg_ent 74 | self.negative_rel = neg_rel 75 | self.sampling_mode = sampling_mode 76 | self.cross_sampling_flag = 0 77 | self.read() 78 | 79 | def read(self): 80 | if self.in_path != None: 81 | self.lib.setInPath(ctypes.create_string_buffer(self.in_path.encode(), len(self.in_path) * 2)) 82 | else: 83 | self.lib.setTrainPath(ctypes.create_string_buffer(self.tri_file.encode(), len(self.tri_file) * 2)) 84 | self.lib.setEntPath(ctypes.create_string_buffer(self.ent_file.encode(), len(self.ent_file) * 2)) 85 | self.lib.setRelPath(ctypes.create_string_buffer(self.rel_file.encode(), len(self.rel_file) * 2)) 86 | 87 | self.lib.setBern(self.bern) 88 | self.lib.setWorkThreads(self.work_threads) 89 | self.lib.randReset() 90 | self.lib.importTrainFiles() 91 | self.relTotal = self.lib.getRelationTotal() 92 | self.entTotal = self.lib.getEntityTotal() 93 | self.tripleTotal = self.lib.getTrainTotal() 94 | 95 | if self.batch_size is None: 96 | self.batch_size = self.tripleTotal // self.nbatches 97 | if self.nbatches is None: 98 | 
self.nbatches = self.tripleTotal // self.batch_size 99 | self.batch_seq_size = self.batch_size * (1 + self.negative_ent + self.negative_rel) 100 | 101 | self.batch_h = np.zeros(self.batch_seq_size, dtype=np.int64) 102 | self.batch_t = np.zeros(self.batch_seq_size, dtype=np.int64) 103 | self.batch_r = np.zeros(self.batch_seq_size, dtype=np.int64) 104 | self.batch_y = np.zeros(self.batch_seq_size, dtype=np.float32) 105 | self.batch_h_addr = self.batch_h.__array_interface__["data"][0] 106 | self.batch_t_addr = self.batch_t.__array_interface__["data"][0] 107 | self.batch_r_addr = self.batch_r.__array_interface__["data"][0] 108 | self.batch_y_addr = self.batch_y.__array_interface__["data"][0] 109 | 110 | def sampling(self): 111 | self.lib.sampling( 112 | self.batch_h_addr, 113 | self.batch_t_addr, 114 | self.batch_r_addr, 115 | self.batch_y_addr, 116 | self.batch_size, 117 | self.negative_ent, 118 | self.negative_rel, 119 | 0, 120 | self.filter, 121 | 0, 122 | 0 123 | ) 124 | return { 125 | "batch_h": self.batch_h, 126 | "batch_t": self.batch_t, 127 | "batch_r": self.batch_r, 128 | "batch_y": self.batch_y, 129 | "mode": "normal" 130 | } 131 | 132 | def sampling_head(self): 133 | self.lib.sampling( 134 | self.batch_h_addr, 135 | self.batch_t_addr, 136 | self.batch_r_addr, 137 | self.batch_y_addr, 138 | self.batch_size, 139 | self.negative_ent, 140 | self.negative_rel, 141 | -1, 142 | self.filter, 143 | 0, 144 | 0 145 | ) 146 | return { 147 | "batch_h": self.batch_h, 148 | "batch_t": self.batch_t[:self.batch_size], 149 | "batch_r": self.batch_r[:self.batch_size], 150 | "batch_y": self.batch_y, 151 | "mode": "head_batch" 152 | } 153 | 154 | def sampling_tail(self): 155 | self.lib.sampling( 156 | self.batch_h_addr, 157 | self.batch_t_addr, 158 | self.batch_r_addr, 159 | self.batch_y_addr, 160 | self.batch_size, 161 | self.negative_ent, 162 | self.negative_rel, 163 | 1, 164 | self.filter, 165 | 0, 166 | 0 167 | ) 168 | return { 169 | "batch_h": 
self.batch_h[:self.batch_size], 170 | "batch_t": self.batch_t, 171 | "batch_r": self.batch_r[:self.batch_size], 172 | "batch_y": self.batch_y, 173 | "mode": "tail_batch" 174 | } 175 | 176 | def cross_sampling(self): 177 | self.cross_sampling_flag = 1 - self.cross_sampling_flag 178 | if self.cross_sampling_flag == 0: 179 | return self.sampling_head() 180 | else: 181 | return self.sampling_tail() 182 | 183 | """interfaces to set essential parameters""" 184 | 185 | def set_work_threads(self, work_threads): 186 | self.work_threads = work_threads 187 | 188 | def set_in_path(self, in_path): 189 | self.in_path = in_path 190 | 191 | def set_nbatches(self, nbatches): 192 | self.nbatches = nbatches 193 | 194 | def set_batch_size(self, batch_size): 195 | self.batch_size = batch_size 196 | self.nbatches = self.tripleTotal // self.batch_size 197 | 198 | def set_ent_neg_rate(self, rate): 199 | self.negative_ent = rate 200 | 201 | def set_rel_neg_rate(self, rate): 202 | self.negative_rel = rate 203 | 204 | def set_bern_flag(self, bern): 205 | self.bern = bern 206 | 207 | def set_filter_flag(self, filter): 208 | self.filter = filter 209 | 210 | """interfaces to get essential parameters""" 211 | 212 | def get_batch_size(self): 213 | return self.batch_size 214 | 215 | def get_ent_tot(self): 216 | return self.entTotal 217 | 218 | def get_rel_tot(self): 219 | return self.relTotal 220 | 221 | def get_triple_tot(self): 222 | return self.tripleTotal 223 | 224 | def __iter__(self): 225 | if self.sampling_mode == "normal": 226 | return TrainDataSampler(self.nbatches, self.sampling) 227 | else: 228 | return TrainDataSampler(self.nbatches, self.cross_sampling) 229 | 230 | def __len__(self): 231 | return self.nbatches 232 | -------------------------------------------------------------------------------- /mmkgc/data/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 
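`cross_sampling` alternates head- and tail-corrupted batches by flipping a 0/1 flag on every call; because the flag flips before the branch, the first batch produced is tail-corrupted. The alternation in isolation, with stub samplers standing in for the C-backed `sampling_head`/`sampling_tail`:

```python
class CrossSampler:
    """Alternate between two sampling strategies per batch,
    mirroring TrainDataLoader.cross_sampling."""

    def __init__(self):
        self.cross_sampling_flag = 0

    def sampling_head(self):
        return "head_batch"  # stub for the head-corruption call into Base.so

    def sampling_tail(self):
        return "tail_batch"  # stub for the tail-corruption call into Base.so

    def cross_sampling(self):
        # Flip first, then branch: flag == 0 after the flip means head.
        self.cross_sampling_flag = 1 - self.cross_sampling_flag
        if self.cross_sampling_flag == 0:
            return self.sampling_head()
        return self.sampling_tail()

s = CrossSampler()
modes = [s.cross_sampling() for _ in range(4)]
# modes == ["tail_batch", "head_batch", "tail_batch", "head_batch"]
```

Over an epoch this yields a strict 50/50 split of head- and tail-corrupted negative batches without any random draw.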
from __future__ import print_function 4 | 5 | from .TrainDataLoader import TrainDataLoader 6 | from .TestDataLoader import TestDataLoader 7 | 8 | __all__ = [ 9 | 'TrainDataLoader', 10 | 'TestDataLoader' 11 | ] -------------------------------------------------------------------------------- /mmkgc/make.sh: -------------------------------------------------------------------------------- 1 | mkdir -p release 2 | g++ ./base/Base.cpp -fPIC -shared -o ./release/Base.so -pthread -O3 -march=native -------------------------------------------------------------------------------- /mmkgc/module/BaseModule.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | import json 5 | import numpy as np 6 | 7 | class BaseModule(nn.Module): 8 | 9 | def __init__(self): 10 | super(BaseModule, self).__init__() 11 | self.zero_const = nn.Parameter(torch.Tensor([0])) 12 | self.zero_const.requires_grad = False 13 | self.pi_const = nn.Parameter(torch.Tensor([3.14159265358979323846])) 14 | self.pi_const.requires_grad = False 15 | 16 | def load_checkpoint(self, path): 17 | self.load_state_dict(torch.load(os.path.join(path))) 18 | self.eval() 19 | 20 | def save_checkpoint(self, path): 21 | torch.save(self.state_dict(), path) 22 | 23 | def load_parameters(self, path): 24 | f = open(path, "r") 25 | parameters = json.loads(f.read()) 26 | f.close() 27 | for i in parameters: 28 | parameters[i] = torch.Tensor(parameters[i]) 29 | self.load_state_dict(parameters, strict = False) 30 | self.eval() 31 | 32 | def save_parameters(self, path): 33 | f = open(path, "w") 34 | f.write(json.dumps(self.get_parameters("list"))) 35 | f.close() 36 | 37 | def get_parameters(self, mode = "numpy", param_dict = None): 38 | all_param_dict = self.state_dict() 39 | if param_dict is None: 40 | param_dict = all_param_dict.keys() 41 | res = {} 42 | for param in param_dict: 43 | if mode == "numpy": 44 | res[param] =
all_param_dict[param].cpu().numpy() 45 | elif mode == "list": 46 | res[param] = all_param_dict[param].cpu().numpy().tolist() 47 | else: 48 | res[param] = all_param_dict[param] 49 | return res 50 | 51 | def set_parameters(self, parameters): 52 | for i in parameters: 53 | parameters[i] = torch.Tensor(parameters[i]) 54 | self.load_state_dict(parameters, strict = False) 55 | self.eval() -------------------------------------------------------------------------------- /mmkgc/module/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .BaseModule import BaseModule 6 | 7 | __all__ = [ 8 | 'BaseModule' 9 | ] -------------------------------------------------------------------------------- /mmkgc/module/loss/Loss.py: -------------------------------------------------------------------------------- 1 | from ..BaseModule import BaseModule 2 | 3 | class Loss(BaseModule): 4 | 5 | def __init__(self): 6 | super(Loss, self).__init__() -------------------------------------------------------------------------------- /mmkgc/module/loss/MarginLoss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import numpy as np 6 | from .Loss import Loss 7 | 8 | class MarginLoss(Loss): 9 | 10 | def __init__(self, adv_temperature = None, margin = 6.0): 11 | super(MarginLoss, self).__init__() 12 | self.margin = nn.Parameter(torch.Tensor([margin])) 13 | self.margin.requires_grad = False 14 | if adv_temperature != None: 15 | self.adv_temperature = nn.Parameter(torch.Tensor([adv_temperature])) 16 | self.adv_temperature.requires_grad = False 17 | self.adv_flag = True 18 | else: 19 | self.adv_flag = False 20 | 21 | def get_weights(self, n_score): 22 | return F.softmax(-n_score *
self.adv_temperature, dim = -1).detach() 23 | 24 | def forward(self, p_score, n_score): 25 | if self.adv_flag: 26 | return (self.get_weights(n_score) * torch.max(p_score - n_score, -self.margin)).sum(dim = -1).mean() + self.margin 27 | else: 28 | return (torch.max(p_score - n_score, -self.margin)).mean() + self.margin 29 | 30 | 31 | def predict(self, p_score, n_score): 32 | score = self.forward(p_score, n_score) 33 | return score.cpu().data.numpy() -------------------------------------------------------------------------------- /mmkgc/module/loss/SigmoidLoss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from .Loss import Loss 6 | 7 | class SigmoidLoss(Loss): 8 | 9 | def __init__(self, adv_temperature = None): 10 | super(SigmoidLoss, self).__init__() 11 | self.criterion = nn.LogSigmoid() 12 | if adv_temperature != None: 13 | self.adv_temperature = nn.Parameter(torch.Tensor([adv_temperature])) 14 | self.adv_temperature.requires_grad = False 15 | self.adv_flag = True 16 | else: 17 | self.adv_flag = False 18 | 19 | def get_weights(self, n_score): 20 | return F.softmax(n_score * self.adv_temperature, dim = -1).detach() 21 | 22 | def forward(self, p_score, n_score): 23 | if self.adv_flag: 24 | return -(self.criterion(p_score).mean() + (self.get_weights(n_score) * self.criterion(-n_score)).sum(dim = -1).mean()) / 2 25 | else: 26 | return -(self.criterion(p_score).mean() + self.criterion(-n_score).mean()) / 2 27 | 28 | def predict(self, p_score, n_score): 29 | score = self.forward(p_score, n_score) 30 | return score.cpu().data.numpy() -------------------------------------------------------------------------------- /mmkgc/module/loss/SoftplusLoss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 
| from .Loss import Loss 6 | 7 | class SoftplusLoss(Loss): 8 | 9 | def __init__(self, adv_temperature = None): 10 | super(SoftplusLoss, self).__init__() 11 | self.criterion = nn.Softplus() 12 | if adv_temperature != None: 13 | self.adv_temperature = nn.Parameter(torch.Tensor([adv_temperature])) 14 | self.adv_temperature.requires_grad = False 15 | self.adv_flag = True 16 | else: 17 | self.adv_flag = False 18 | 19 | def get_weights(self, n_score): 20 | return F.softmax(n_score * self.adv_temperature, dim = -1).detach() 21 | 22 | def forward(self, p_score, n_score): 23 | if self.adv_flag: 24 | return (self.criterion(-p_score).mean() + (self.get_weights(n_score) * self.criterion(n_score)).sum(dim = -1).mean()) / 2 25 | else: 26 | return (self.criterion(-p_score).mean() + self.criterion(n_score).mean()) / 2 27 | 28 | 29 | def predict(self, p_score, n_score): 30 | score = self.forward(p_score, n_score) 31 | return score.cpu().data.numpy() -------------------------------------------------------------------------------- /mmkgc/module/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .Loss import Loss 6 | from .MarginLoss import MarginLoss 7 | from .SoftplusLoss import SoftplusLoss 8 | from .SigmoidLoss import SigmoidLoss 9 | 10 | __all__ = [ 11 | 'Loss', 12 | 'MarginLoss', 13 | 'SoftplusLoss', 14 | 'SigmoidLoss', 15 | ] -------------------------------------------------------------------------------- /mmkgc/module/model/EnsembleComplEx.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .Model import Model 5 | 6 | 7 | class EnsembleComplEx(Model): 8 | 9 | def __init__( 10 | self, 11 | ent_tot, 12 | rel_tot, 13 | dim=100, 14 | margin=6.0, 15 | epsilon=2.0, 16 | 
visual_embs=None, 17 | textual_embs=None 18 | ): 19 | super(EnsembleComplEx, self).__init__(ent_tot, rel_tot) 20 | 21 | self.margin = margin 22 | self.epsilon = epsilon 23 | 24 | self.dim_e = dim * 2 25 | self.dim_r = dim * 2 26 | 27 | self.ent_emb_s = nn.Embedding(self.ent_tot, self.dim_e) 28 | self.rel_emb_s = nn.Embedding(self.rel_tot, self.dim_r) 29 | self.dim = dim 30 | 31 | self.ent_emb_v = nn.Embedding.from_pretrained(visual_embs) 32 | self.ent_emb_t = nn.Embedding.from_pretrained(textual_embs) 33 | self.ent_emb_v.requires_grad_(True) 34 | self.ent_emb_t.requires_grad_(True) 35 | visual_dim = self.ent_emb_v.weight.shape[1] 36 | textual_dim = self.ent_emb_t.weight.shape[1] 37 | self.visual_proj = nn.Linear(visual_dim, self.dim_e) 38 | self.textual_proj = nn.Linear(textual_dim, self.dim_e) 39 | self.rel_emb_v = nn.Embedding(self.rel_tot, 2 * self.dim_r) 40 | self.rel_emb_t = nn.Embedding(self.rel_tot, self.dim_r) 41 | self.rel_emb_j = nn.Embedding(self.rel_tot, self.dim_r) 42 | self.init_emb() 43 | self.predict_mode = "all" 44 | self.ent_attn = nn.Parameter(torch.zeros((self.dim_e, ))) 45 | self.ent_attn.requires_grad_(True) 46 | 47 | def init_emb(self): 48 | self.ent_embedding_range = nn.Parameter( 49 | torch.Tensor([(self.margin + self.epsilon) / self.dim_e]), 50 | requires_grad=False 51 | ) 52 | nn.init.uniform_( 53 | tensor=self.ent_emb_s.weight.data, 54 | a=-self.ent_embedding_range.item(), 55 | b=self.ent_embedding_range.item() 56 | ) 57 | self.rel_embedding_range = nn.Parameter( 58 | torch.Tensor([(self.margin + self.epsilon) / self.dim_r]), 59 | requires_grad=False 60 | ) 61 | nn.init.uniform_( 62 | tensor=self.rel_emb_s.weight.data, 63 | a=-self.rel_embedding_range.item(), 64 | b=self.rel_embedding_range.item() 65 | ) 66 | nn.init.uniform_( 67 | tensor=self.rel_emb_v.weight.data, 68 | a=-self.rel_embedding_range.item(), 69 | b=self.rel_embedding_range.item() 70 | ) 71 | nn.init.uniform_( 72 | tensor=self.rel_emb_t.weight.data, 73 | 
a=-self.rel_embedding_range.item(), 74 | b=self.rel_embedding_range.item() 75 | ) 76 | nn.init.uniform_( 77 | tensor=self.rel_emb_j.weight.data, 78 | a=-self.rel_embedding_range.item(), 79 | b=self.rel_embedding_range.item() 80 | ) 81 | self.margin = nn.Parameter(torch.Tensor([self.margin])) 82 | self.margin.requires_grad = False 83 | 84 | def score_function_complex(self, h, t, r, mode): 85 | h_re = h[:, 0: self.dim] 86 | h_im = h[:, self.dim: self.dim_e] 87 | t_re = t[:, 0: self.dim] 88 | t_im = t[:, self.dim: self.dim_e] 89 | r_re = r[:, 0: self.dim] 90 | r_im = r[:, self.dim: self.dim_e] 91 | return torch.sum( 92 | h_re * t_re * r_re 93 | + h_im * t_im * r_re 94 | + h_re * t_im * r_im 95 | - h_im * t_re * r_im, 96 | -1 97 | ) 98 | 99 | def forward(self, data, require_att=False): 100 | batch_h = data['batch_h'] 101 | batch_t = data['batch_t'] 102 | batch_r = data['batch_r'] 103 | mode = data['mode'] 104 | # structural embeddings 105 | h = self.ent_emb_s(batch_h) 106 | t = self.ent_emb_s(batch_t) 107 | r = self.rel_emb_s(batch_r) 108 | # visual embeddings 109 | hv = self.visual_proj(self.ent_emb_v(batch_h)) 110 | tv = self.visual_proj(self.ent_emb_v(batch_t)) 111 | rv = self.rel_emb_v(batch_r) 112 | # textual embeddings 113 | ht = self.textual_proj(self.ent_emb_t(batch_h)) 114 | tt = self.textual_proj(self.ent_emb_t(batch_t)) 115 | rt = self.rel_emb_t(batch_r) 116 | # joint embeddings 117 | hj, att_h = self.get_joint_embeddings(h, hv, ht) 118 | tj, att_t = self.get_joint_embeddings(t, tv, tt) 119 | rj = self.rel_emb_j(batch_r) 120 | # scores 121 | score_s = self.score_function_complex(h, t, r, mode) 122 | score_v = self.score_function_complex(hv, tv, rv, mode) 123 | score_t = self.score_function_complex(ht, tt, rt, mode) 124 | score_j = self.score_function_complex(hj, tj, rj, mode) 125 | if require_att: 126 | return [score_s, score_v, score_t, score_j], (att_h + att_t) 127 | return [score_s, score_v, score_t, score_j] 128 | 129 | def get_joint_embeddings(self, es, 
ev, et): 130 | e = torch.stack((es, ev, et), dim=1) 131 | dot = torch.exp(e @ self.ent_attn) 132 | att_w = dot / torch.sum(dot, dim=1).reshape(-1, 1) 133 | w1, w2, w3 = att_w[:, 0].reshape(-1, 1), att_w[:, 134 | 1].reshape(-1, 1), att_w[:, 2].reshape(-1, 1) 135 | ej = w1 * es + w2 * ev + w3 * et 136 | return ej, att_w 137 | 138 | def predict(self, data): 139 | pred_result, att = self.forward(data, require_att=True) 140 | if self.predict_mode == "s": 141 | score = -pred_result[0] 142 | elif self.predict_mode == "v": 143 | score = -pred_result[1] 144 | elif self.predict_mode == "t": 145 | score = -pred_result[2] 146 | elif self.predict_mode == "j": 147 | score = -pred_result[3] 148 | elif self.predict_mode == "all": 149 | att /= 2 150 | w1, w2, w3 = att[:, 0].reshape(-1, 1), att[:, 151 | 1].reshape(-1, 1), att[:, 2].reshape(-1, 1) 152 | score = - \ 153 | (w1 * pred_result[0] + w2 * pred_result[1] + 154 | w3 * pred_result[2] + pred_result[3]) 155 | else: 156 | raise NotImplementedError("No such prediction setting!") 157 | return score.cpu().data.numpy() 158 | 159 | def regularization(self, data): 160 | batch_h = data['batch_h'] 161 | batch_t = data['batch_t'] 162 | batch_r = data['batch_r'] 163 | h = self.ent_emb_s(batch_h) 164 | t = self.ent_emb_s(batch_t) 165 | r = self.rel_emb_s(batch_r) 166 | h_re = h[:, 0: self.dim] 167 | h_im = h[:, self.dim: self.dim_e] 168 | t_re = t[:, 0: self.dim] 169 | t_im = t[:, self.dim: self.dim_e] 170 | r_re = r[:, 0: self.dim] 171 | r_im = r[:, self.dim: self.dim_e] 172 | regul = (torch.mean(h_re ** 2) + 173 | torch.mean(h_im ** 2) + 174 | torch.mean(t_re ** 2) + 175 | torch.mean(t_im ** 2) + 176 | torch.mean(r_re ** 2) + 177 | torch.mean(r_im ** 2)) / 6 178 | return regul 179 | -------------------------------------------------------------------------------- /mmkgc/module/model/EnsembleMMKGE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import 
torch.nn.functional as F 4 | from .Model import Model 5 | 6 | class EnsembleMMKGE(Model): 7 | 8 | def __init__( 9 | self, 10 | ent_tot, 11 | rel_tot, 12 | dim = 100, 13 | margin = 6.0, 14 | epsilon = 2.0, 15 | visual_embs=None, 16 | textual_embs=None 17 | ): 18 | super(EnsembleMMKGE, self).__init__(ent_tot, rel_tot) 19 | 20 | self.margin = margin 21 | self.epsilon = epsilon 22 | 23 | self.dim_e = dim * 2 24 | self.dim_r = dim 25 | 26 | self.ent_emb_s = nn.Embedding(self.ent_tot, self.dim_e) 27 | self.rel_emb_s = nn.Embedding(self.rel_tot, self.dim_r) 28 | 29 | self.ent_emb_v = nn.Embedding.from_pretrained(visual_embs) 30 | self.ent_emb_t = nn.Embedding.from_pretrained(textual_embs) 31 | self.ent_emb_v.requires_grad_(True) 32 | self.ent_emb_t.requires_grad_(True) 33 | visual_dim = self.ent_emb_v.weight.shape[1] 34 | textual_dim = self.ent_emb_t.weight.shape[1] 35 | self.visual_proj = nn.Linear(visual_dim, self.dim_e) 36 | self.textual_proj = nn.Linear(textual_dim, self.dim_e) 37 | self.rel_emb_v = nn.Embedding(self.rel_tot, self.dim_r) 38 | self.rel_emb_t = nn.Embedding(self.rel_tot, self.dim_r) 39 | self.rel_emb_j = nn.Embedding(self.rel_tot, self.dim_r) 40 | self.init_emb() 41 | self.predict_mode = "all" 42 | self.ent_attn = nn.Parameter(torch.zeros((self.dim_e, ))) 43 | self.ent_attn.requires_grad_(True) 44 | 45 | 46 | def init_emb(self): 47 | self.ent_embedding_range = nn.Parameter( 48 | torch.Tensor([(self.margin + self.epsilon) / self.dim_e]), 49 | requires_grad=False 50 | ) 51 | nn.init.uniform_( 52 | tensor = self.ent_emb_s.weight.data, 53 | a=-self.ent_embedding_range.item(), 54 | b=self.ent_embedding_range.item() 55 | ) 56 | self.rel_embedding_range = nn.Parameter( 57 | torch.Tensor([(self.margin + self.epsilon) / self.dim_r]), 58 | requires_grad=False 59 | ) 60 | nn.init.uniform_( 61 | tensor = self.rel_emb_s.weight.data, 62 | a=-self.rel_embedding_range.item(), 63 | b=self.rel_embedding_range.item() 64 | ) 65 | nn.init.uniform_( 66 | tensor = 
self.rel_emb_v.weight.data, 67 | a=-self.rel_embedding_range.item(), 68 | b=self.rel_embedding_range.item() 69 | ) 70 | nn.init.uniform_( 71 | tensor = self.rel_emb_t.weight.data, 72 | a=-self.rel_embedding_range.item(), 73 | b=self.rel_embedding_range.item() 74 | ) 75 | nn.init.uniform_( 76 | tensor = self.rel_emb_j.weight.data, 77 | a=-self.rel_embedding_range.item(), 78 | b=self.rel_embedding_range.item() 79 | ) 80 | self.margin = nn.Parameter(torch.Tensor([self.margin])) 81 | self.margin.requires_grad = False 82 | 83 | 84 | def score_function_transe(self, h, r, t, mode): 85 | h = F.normalize(h, 2, -1) 86 | r = F.normalize(r, 2, -1) 87 | t = F.normalize(t, 2, -1) 88 | if mode == 'head_batch': 89 | score = h + (r - t) 90 | else: 91 | score = (h + r) - t 92 | score = torch.norm(score, 1, -1).flatten() 93 | return score 94 | 95 | 96 | def score_function_rotate(self, h, t, r, mode): 97 | pi = self.pi_const 98 | 99 | re_head, im_head = torch.chunk(h, 2, dim=-1) 100 | re_tail, im_tail = torch.chunk(t, 2, dim=-1) 101 | 102 | phase_relation = r / (self.rel_embedding_range.item() / pi) 103 | 104 | re_relation = torch.cos(phase_relation) 105 | im_relation = torch.sin(phase_relation) 106 | 107 | re_head = re_head.view(-1, re_relation.shape[0], re_head.shape[-1]).permute(1, 0, 2) 108 | re_tail = re_tail.view(-1, re_relation.shape[0], re_tail.shape[-1]).permute(1, 0, 2) 109 | im_head = im_head.view(-1, re_relation.shape[0], im_head.shape[-1]).permute(1, 0, 2) 110 | im_tail = im_tail.view(-1, re_relation.shape[0], im_tail.shape[-1]).permute(1, 0, 2) 111 | im_relation = im_relation.view(-1, re_relation.shape[0], im_relation.shape[-1]).permute(1, 0, 2) 112 | re_relation = re_relation.view(-1, re_relation.shape[0], re_relation.shape[-1]).permute(1, 0, 2) 113 | 114 | if mode == "head_batch": 115 | re_score = re_relation * re_tail + im_relation * im_tail 116 | im_score = re_relation * im_tail - im_relation * re_tail 117 | re_score = re_score - re_head 118 | im_score = im_score - 
im_head 119 | else: 120 | re_score = re_head * re_relation - im_head * im_relation 121 | im_score = re_head * im_relation + im_head * re_relation 122 | re_score = re_score - re_tail 123 | im_score = im_score - im_tail 124 | 125 | score = torch.stack([re_score, im_score], dim = 0) 126 | score = score.norm(dim = 0).sum(dim = -1) 127 | return score.permute(1, 0).flatten() 128 | 129 | def forward(self, data, require_att=False): 130 | batch_h = data['batch_h'] 131 | batch_t = data['batch_t'] 132 | batch_r = data['batch_r'] 133 | mode = data['mode'] 134 | # structural embeddings 135 | h = self.ent_emb_s(batch_h) 136 | t = self.ent_emb_s(batch_t) 137 | r = self.rel_emb_s(batch_r) 138 | # visual embeddings 139 | hv = self.visual_proj(self.ent_emb_v(batch_h)) 140 | tv = self.visual_proj(self.ent_emb_v(batch_t)) 141 | rv = self.rel_emb_v(batch_r) 142 | # textual embeddings 143 | ht = self.textual_proj(self.ent_emb_t(batch_h)) 144 | tt = self.textual_proj(self.ent_emb_t(batch_t)) 145 | rt = self.rel_emb_t(batch_r) 146 | # joint embeddings 147 | hj, att_h = self.get_joint_embeddings(h, hv, ht) 148 | tj, att_t = self.get_joint_embeddings(t, tv, tt) 149 | rj = self.rel_emb_j(batch_r) 150 | # scores 151 | score_s = self.margin - self.score_function_rotate(h, t, r, mode) 152 | score_v = self.margin - self.score_function_rotate(hv, tv, rv, mode) 153 | score_t = self.margin - self.score_function_rotate(ht, tt, rt, mode) 154 | score_j = self.margin - self.score_function_rotate(hj, tj, rj, mode) 155 | if require_att: 156 | return [score_s, score_v, score_t, score_j], (att_h + att_t) 157 | return [score_s, score_v, score_t, score_j] 158 | 159 | def get_joint_embeddings(self, es, ev, et): 160 | e = torch.stack((es, ev, et), dim=1) 161 | dot = torch.exp(e @ self.ent_attn) 162 | att_w = dot / torch.sum(dot, dim=1).reshape(-1, 1) 163 | w1, w2, w3 = att_w[:, 0].reshape(-1, 1), att_w[:, 1].reshape(-1, 1), att_w[:, 2].reshape(-1, 1) 164 | ej = w1 * es + w2 * ev + w3 * et 165 | return ej, 
att_w 166 | 167 | 168 | def predict(self, data): 169 | pred_result, att = self.forward(data, require_att=True) 170 | if self.predict_mode == "s": 171 | score = -pred_result[0] 172 | elif self.predict_mode == "v": 173 | score = -pred_result[1] 174 | elif self.predict_mode == "t": 175 | score = -pred_result[2] 176 | elif self.predict_mode == "j": 177 | score = -pred_result[3] 178 | elif self.predict_mode == "all": 179 | att /= 2 180 | w1, w2, w3 = att[:, 0].reshape(-1, 1), att[:, 1].reshape(-1, 1), att[:, 2].reshape(-1, 1) 181 | score = -(w1 * pred_result[0] + w2 * pred_result[1] + w3 * pred_result[2] + pred_result[3]) 182 | else: 183 | raise NotImplementedError("No such prediction setting!") 184 | return score.cpu().data.numpy() 185 | 186 | def regularization(self, data): 187 | batch_h = data['batch_h'] 188 | batch_t = data['batch_t'] 189 | batch_r = data['batch_r'] 190 | h = self.ent_emb_s(batch_h) 191 | t = self.ent_emb_s(batch_t) 192 | r = self.rel_emb_s(batch_r) 193 | regul = (torch.mean(h ** 2) + 194 | torch.mean(t ** 2) + 195 | torch.mean(r ** 2)) / 3 196 | return regul 197 | -------------------------------------------------------------------------------- /mmkgc/module/model/IKRL.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import time 5 | from .Model import Model 6 | 7 | 8 | class IKRL(Model): 9 | 10 | def __init__(self, ent_tot, rel_tot, dim=100, p_norm=1, img_emb=None, 11 | img_dim=4096, norm_flag=True, margin=None, epsilon=None, 12 | test_mode='lp', beta=None): 13 | super(IKRL, self).__init__(ent_tot, rel_tot) 14 | 15 | self.dim = dim 16 | self.margin = margin 17 | self.epsilon = epsilon 18 | self.norm_flag = norm_flag 19 | self.p_norm = p_norm 20 | self.img_dim = img_dim 21 | self.test_mode = test_mode 22 | 23 | self.ent_embeddings = nn.Embedding(self.ent_tot, self.dim) 24 | self.rel_embeddings = nn.Embedding(self.rel_tot, self.dim) 
25 | # projection matrix and image embeddings for the visual modality 26 | self.img_proj = nn.Linear(self.img_dim, self.dim) 27 | self.img_embeddings = nn.Embedding.from_pretrained(img_emb).requires_grad_(True) 28 | self.beta = beta 29 | 30 | if margin is None or epsilon is None: 31 | nn.init.xavier_uniform_(self.ent_embeddings.weight.data) 32 | nn.init.xavier_uniform_(self.rel_embeddings.weight.data) 33 | else: 34 | self.embedding_range = nn.Parameter( 35 | torch.Tensor([(self.margin + self.epsilon) / self.dim]), requires_grad=False 36 | ) 37 | nn.init.uniform_( 38 | tensor=self.ent_embeddings.weight.data, 39 | a=-self.embedding_range.item(), 40 | b=self.embedding_range.item() 41 | ) 42 | nn.init.uniform_( 43 | tensor=self.rel_embeddings.weight.data, 44 | a=-self.embedding_range.item(), 45 | b=self.embedding_range.item() 46 | ) 47 | 48 | if margin is not None: 49 | self.margin = nn.Parameter(torch.Tensor([margin])) 50 | self.margin.requires_grad = False 51 | self.margin_flag = True 52 | else: 53 | self.margin_flag = False 54 | 55 | def _calc(self, h, t, r, mode): 56 | if self.norm_flag: 57 | h = F.normalize(h, 2, -1) 58 | r = F.normalize(r, 2, -1) 59 | t = F.normalize(t, 2, -1) 60 | if mode != 'normal': 61 | h = h.view(-1, r.shape[0], h.shape[-1]) 62 | t = t.view(-1, r.shape[0], t.shape[-1]) 63 | r = r.view(-1, r.shape[0], r.shape[-1]) 64 | if mode == 'head_batch': 65 | score = h + (r - t) 66 | else: 67 | score = (h + r) - t 68 | score = torch.norm(score, self.p_norm, -1).flatten() 69 | return score 70 | 71 | def get_batch_ent_embs(self, data): 72 | return self.ent_embeddings(data) 73 | 74 | def get_fake_score( 75 | self, 76 | batch_h, 77 | batch_r, 78 | batch_t, 79 | mode, 80 | fake_hv=None, 81 | fake_tv=None 82 | ): 83 | if fake_hv is None or fake_tv is None: 84 | raise NotImplementedError 85 | h = self.ent_embeddings(batch_h) 86 | t = self.ent_embeddings(batch_t) 87 | r = self.rel_embeddings(batch_r) 88 | h_img_emb = self.img_proj(self.img_embeddings(batch_h)) 89 | t_img_emb =
self.img_proj(self.img_embeddings(batch_t)) 90 | # three kinds of fake score 91 | score_hv = ( 92 | self._calc(h, t, r, mode) 93 | + self._calc(fake_hv, t_img_emb, r, mode) 94 | + self._calc(fake_hv, t, r, mode) 95 | + self._calc(h, t_img_emb, r, mode) 96 | ) 97 | score_tv = ( 98 | self._calc(h, t, r, mode) 99 | + self._calc(h_img_emb, fake_tv, r, mode) 100 | + self._calc(h_img_emb, t, r, mode) 101 | + self._calc(h, fake_tv, r, mode) 102 | ) 103 | score_htv = ( 104 | self._calc(h, t, r, mode) 105 | + self._calc(fake_hv, fake_tv, r, mode) 106 | + self._calc(fake_hv, t, r, mode) 107 | + self._calc(h, fake_tv, r, mode) 108 | ) 109 | return [self.margin - score_hv, self.margin - score_tv, self.margin - score_htv], [h_img_emb, t_img_emb] 110 | 111 | 112 | def forward(self, data): 113 | batch_h = data['batch_h'] 114 | batch_t = data['batch_t'] 115 | batch_r = data['batch_r'] 116 | h_ent, h_img, t_ent, t_img = batch_h, batch_h, batch_t, batch_t 117 | mode = data['mode'] 118 | h = self.ent_embeddings(h_ent) 119 | t = self.ent_embeddings(t_ent) 120 | r = self.rel_embeddings(batch_r) 121 | h_img_emb = self.img_proj(self.img_embeddings(h_img)) 122 | t_img_emb = self.img_proj(self.img_embeddings(t_img)) 123 | # print(self._calc(h, t, r, mode)) 124 | score = ( 125 | self._calc(h, t, r, mode) 126 | + self._calc(h_img_emb, t_img_emb, r, mode) 127 | + self._calc(h_img_emb, t, r, mode) 128 | + self._calc(h, t_img_emb, r, mode) 129 | # + self._calc(h + h_img_emb, t + t_img_emb, r, mode) 130 | ) 131 | if self.margin_flag: 132 | return self.margin - score 133 | else: 134 | return score 135 | 136 | def regularization(self, data): 137 | batch_h = data['batch_h'] 138 | batch_t = data['batch_t'] 139 | batch_r = data['batch_r'] 140 | h = self.ent_embeddings(batch_h) 141 | t = self.ent_embeddings(batch_t) 142 | r = self.rel_embeddings(batch_r) 143 | regul = (torch.mean(h ** 2) + 144 | torch.mean(t ** 2) + 145 | torch.mean(r ** 2)) / 3 146 | return regul 147 | 148 | def predict(self, data): 
149 | score = self.forward(data) 150 | if self.margin_flag: 151 | score = self.margin - score 152 | return score.cpu().data.numpy() 153 | else: 154 | return score.cpu().data.numpy() 155 | 156 | def set_test_mode(self, new_mode): 157 | self.test_mode = new_mode 158 | 159 | def get_rel_rank(self, data): 160 | head, tail, rel = data 161 | h_img_emb = self.img_proj(self.img_embeddings(head)) 162 | t_img_emb = self.img_proj(self.img_embeddings(tail)) 163 | relations = self.rel_embeddings.weight 164 | h = h_img_emb.reshape(-1, h_img_emb.shape[0]).expand((relations.shape[0], h_img_emb.shape[0])) 165 | t = t_img_emb.reshape(-1, t_img_emb.shape[0]).expand((relations.shape[0], t_img_emb.shape[0])) 166 | scores = self._calc(h, t, relations, mode='normal') 167 | ranks = torch.argsort(scores) 168 | rank = 0 169 | for (index, val) in enumerate(ranks): 170 | if val.item() == rel.item(): 171 | rank = index 172 | break 173 | return rank + 1 174 | -------------------------------------------------------------------------------- /mmkgc/module/model/MMKRL.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import time 5 | from .Model import Model 6 | 7 | 8 | class MMKRL(Model): 9 | 10 | def __init__(self, ent_tot, rel_tot, dim=100, p_norm=1, img_emb=None, 11 | img_dim=4096, norm_flag=True, margin=None, epsilon=None, 12 | text_emb=None): 13 | super(MMKRL, self).__init__(ent_tot, rel_tot) 14 | assert img_emb is not None 15 | assert text_emb is not None 16 | self.dim = dim 17 | self.margin = margin 18 | self.epsilon = epsilon 19 | self.norm_flag = norm_flag 20 | self.p_norm = p_norm 21 | self.img_dim = img_emb.shape[1] 22 | self.text_dim = text_emb.shape[1] 23 | self.ent_embeddings = nn.Embedding(self.ent_tot, self.dim) 24 | self.rel_embeddings = nn.Embedding(self.rel_tot, self.dim) 25 | # projection matrix and multi-modal (image + text) embeddings 26 | # print(img_emb.shape, text_emb.shape) 27 | self.mm_proj =
nn.Linear(self.img_dim + self.text_dim, self.dim) 28 | self.mm_embeddings = nn.Embedding.from_pretrained(torch.cat((img_emb, text_emb), dim=1)).requires_grad_(True) 29 | self.s_proj = nn.Linear(self.dim, self.dim, bias=False) 30 | self.h_bias = nn.Parameter(torch.randn(self.dim, ), requires_grad=True) 31 | self.r_bias = nn.Parameter(torch.randn(self.dim, ), requires_grad=True) 32 | self.t_bias = nn.Parameter(torch.randn(self.dim, ), requires_grad=True) 33 | self.sm_proj = nn.Linear(self.dim, self.dim, bias=False) 34 | self.ka_loss = nn.MSELoss() 35 | 36 | if margin is None or epsilon is None: 37 | nn.init.xavier_uniform_(self.ent_embeddings.weight.data) 38 | nn.init.xavier_uniform_(self.rel_embeddings.weight.data) 39 | else: 40 | self.embedding_range = nn.Parameter( 41 | torch.Tensor([(self.margin + self.epsilon) / self.dim]), requires_grad=False 42 | ) 43 | nn.init.uniform_( 44 | tensor=self.ent_embeddings.weight.data, 45 | a=-self.embedding_range.item(), 46 | b=self.embedding_range.item() 47 | ) 48 | nn.init.uniform_( 49 | tensor=self.rel_embeddings.weight.data, 50 | a=-self.embedding_range.item(), 51 | b=self.embedding_range.item() 52 | ) 53 | 54 | if margin is not None: 55 | self.margin = nn.Parameter(torch.Tensor([margin])) 56 | self.margin.requires_grad = False 57 | self.margin_flag = True 58 | else: 59 | self.margin_flag = False 60 | 61 | def _calc(self, h, t, r, mode): 62 | if self.norm_flag: 63 | h = F.normalize(h, 2, -1) 64 | r = F.normalize(r, 2, -1) 65 | t = F.normalize(t, 2, -1) 66 | if mode != 'normal': 67 | h = h.view(-1, r.shape[0], h.shape[-1]) 68 | t = t.view(-1, r.shape[0], t.shape[-1]) 69 | r = r.view(-1, r.shape[0], r.shape[-1]) 70 | if mode == 'head_batch': 71 | score = h + (r - t) 72 | else: 73 | score = (h + r) - t 74 | score = torch.norm(score, self.p_norm, -1).flatten() 75 | return score 76 | 77 | def get_batch_ent_embs(self, data): 78 | return self.ent_embeddings(data) 79 | 80 | def get_fake_score( 81 | self, 82 | batch_h, 83 | batch_r, 
84 | batch_t, 85 | mode, 86 | fake_hv=None, 87 | fake_tv=None 88 | ): 89 | if fake_hv is None or fake_tv is None: 90 | raise NotImplementedError 91 | h = self.ent_embeddings(batch_h) 92 | t = self.ent_embeddings(batch_t) 93 | r = self.rel_embeddings(batch_r) 94 | h_proj = self.s_proj(h) + self.h_bias + fake_hv 95 | r_proj = self.s_proj(r) + self.r_bias 96 | t_proj = self.s_proj(t) + self.t_bias + fake_tv 97 | h_mm_emb = self.mm_proj(self.mm_embeddings(batch_h)) + fake_hv 98 | t_mm_emb = self.mm_proj(self.mm_embeddings(batch_t)) + fake_tv 99 | score = ( 100 | self._calc(h_proj, t_proj, r_proj, mode) 101 | + self._calc(h_mm_emb, t_mm_emb, r_proj, mode) 102 | + self._calc(h_mm_emb, t, r_proj, mode) 103 | + self._calc(h, t_mm_emb, r_proj, mode) 104 | ) / 4 105 | 106 | return score 107 | 108 | 109 | def forward(self, data, mse=False): 110 | batch_h = data['batch_h'] 111 | batch_t = data['batch_t'] 112 | batch_r = data['batch_r'] 113 | h_ent, h_img, t_ent, t_img = batch_h, batch_h, batch_t, batch_t 114 | mode = data['mode'] 115 | h = self.ent_embeddings(h_ent) 116 | t = self.ent_embeddings(t_ent) 117 | r = self.rel_embeddings(batch_r) 118 | h_proj = self.s_proj(h) + self.h_bias 119 | r_proj = self.s_proj(r) + self.r_bias 120 | t_proj = self.s_proj(t) + self.t_bias 121 | h_mm_emb = self.mm_proj(self.mm_embeddings(batch_h)) 122 | t_mm_emb = self.mm_proj(self.mm_embeddings(batch_t)) 123 | score = ( 124 | self._calc(h_proj, t_proj, r_proj, mode) 125 | + self._calc(h_mm_emb, t_mm_emb, r_proj, mode) 126 | + self._calc(h_mm_emb, t, r_proj, mode) 127 | + self._calc(h, t_mm_emb, r_proj, mode) 128 | ) / 4 129 | if not mse: 130 | return score 131 | else: 132 | loss_kas = self.ka_loss(h, h_proj) + self.ka_loss(t, t_proj) + self.ka_loss(r, r_proj) 133 | loss_kam = self.ka_loss(self.sm_proj(h), h_mm_emb) + self.ka_loss(self.sm_proj(t), t_mm_emb) 134 | loss_ka = loss_kas + loss_kam 135 | return score, loss_ka 136 | 137 | 138 | def regularization(self, data): 139 | batch_h = 
data['batch_h'] 140 | batch_t = data['batch_t'] 141 | batch_r = data['batch_r'] 142 | h = self.ent_embeddings(batch_h) 143 | t = self.ent_embeddings(batch_t) 144 | r = self.rel_embeddings(batch_r) 145 | regul = (torch.mean(h ** 2) + 146 | torch.mean(t ** 2) + 147 | torch.mean(r ** 2)) / 3 148 | return regul 149 | 150 | def predict(self, data): 151 | score = self.forward(data) 152 | if self.margin_flag: 153 | score = self.margin - score 154 | return score.cpu().data.numpy() 155 | else: 156 | return score.cpu().data.numpy() 157 | 158 | def set_test_mode(self, new_mode): 159 | self.test_mode = new_mode 160 | 161 | -------------------------------------------------------------------------------- /mmkgc/module/model/MMRotatE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.autograd as autograd 3 | import torch.nn as nn 4 | import time 5 | from .Model import Model 6 | 7 | class MMRotatE(Model): 8 | 9 | def __init__(self, ent_tot, rel_tot, dim=100, margin=6.0, epsilon=2.0, img_emb=None, img_dim=4096, test_mode='lp', beta=0.5): 10 | super(MMRotatE, self).__init__(ent_tot, rel_tot) 11 | 12 | self.margin = margin 13 | self.epsilon = epsilon 14 | 15 | self.dim_e = dim * 2 16 | self.dim_r = dim 17 | 18 | self.ent_embeddings = nn.Embedding(self.ent_tot, self.dim_e) 19 | self.rel_embeddings = nn.Embedding(self.rel_tot, self.dim_r) 20 | self.img_dim = img_dim 21 | self.img_proj = nn.Linear(self.img_dim, self.dim_e) 22 | self.img_embeddings = nn.Embedding.from_pretrained(img_emb).requires_grad_(False) 23 | self.test_mode = test_mode 24 | self.beta = beta 25 | # self.log_file = open('{}.txt'.format(time.time()), 'w') 26 | 27 | self.ent_embedding_range = nn.Parameter( 28 | torch.Tensor([(self.margin + self.epsilon) / self.dim_e]), 29 | requires_grad=False 30 | ) 31 | 32 | 33 | nn.init.uniform_( 34 | tensor = self.ent_embeddings.weight.data, 35 | a=-self.ent_embedding_range.item(), 36 | 
b=self.ent_embedding_range.item() 37 | ) 38 | 39 | self.rel_embedding_range = nn.Parameter( 40 | torch.Tensor([(self.margin + self.epsilon) / self.dim_r]), 41 | requires_grad=False 42 | ) 43 | 44 | nn.init.uniform_( 45 | tensor = self.rel_embeddings.weight.data, 46 | a=-self.rel_embedding_range.item(), 47 | b=self.rel_embedding_range.item() 48 | ) 49 | 50 | self.margin = nn.Parameter(torch.Tensor([margin])) 51 | self.margin.requires_grad = False 52 | 53 | def _calc(self, h, t, r, mode): 54 | pi = self.pi_const 55 | 56 | re_head, im_head = torch.chunk(h, 2, dim=-1) 57 | re_tail, im_tail = torch.chunk(t, 2, dim=-1) 58 | 59 | phase_relation = r / (self.rel_embedding_range.item() / pi) 60 | 61 | re_relation = torch.cos(phase_relation) 62 | im_relation = torch.sin(phase_relation) 63 | 64 | re_head = re_head.view(-1, re_relation.shape[0], re_head.shape[-1]).permute(1, 0, 2) 65 | re_tail = re_tail.view(-1, re_relation.shape[0], re_tail.shape[-1]).permute(1, 0, 2) 66 | im_head = im_head.view(-1, re_relation.shape[0], im_head.shape[-1]).permute(1, 0, 2) 67 | im_tail = im_tail.view(-1, re_relation.shape[0], im_tail.shape[-1]).permute(1, 0, 2) 68 | im_relation = im_relation.view(-1, re_relation.shape[0], im_relation.shape[-1]).permute(1, 0, 2) 69 | re_relation = re_relation.view(-1, re_relation.shape[0], re_relation.shape[-1]).permute(1, 0, 2) 70 | 71 | if mode == "head_batch": 72 | re_score = re_relation * re_tail + im_relation * im_tail 73 | im_score = re_relation * im_tail - im_relation * re_tail 74 | re_score = re_score - re_head 75 | im_score = im_score - im_head 76 | else: 77 | re_score = re_head * re_relation - im_head * im_relation 78 | im_score = re_head * im_relation + im_head * re_relation 79 | re_score = re_score - re_tail 80 | im_score = im_score - im_tail 81 | 82 | score = torch.stack([re_score, im_score], dim = 0) 83 | score = score.norm(dim = 0).sum(dim = -1) 84 | return score.permute(1, 0).flatten() 85 | 86 | def forward(self, data): 87 | batch_h = 
data['batch_h'] 88 | batch_t = data['batch_t'] 89 | batch_r = data['batch_r'] 90 | mode = data['mode'] 91 | h_ent, h_img, t_ent, t_img = batch_h, batch_h, batch_t, batch_t 92 | h = self.ent_embeddings(h_ent) 93 | t = self.ent_embeddings(t_ent) 94 | # h_img, t_img = batch_h, batch_t 95 | r = self.rel_embeddings(batch_r) 96 | h_img_emb = self.img_proj(self.img_embeddings(h_img)) 97 | t_img_emb = self.img_proj(self.img_embeddings(t_img)) 98 | 99 | score = ( 100 | self._calc(h, t, r, mode) 101 | + self._calc(h_img_emb, t_img_emb, r, mode) 102 | + self._calc(h_img_emb, t, r, mode) 103 | + self._calc(h, t_img_emb, r, mode) 104 | ) 105 | 106 | # score = self._calc(h_img_emb, t, r, mode) + self._calc(h, t_img_emb, r, mode) 107 | score = self.margin - score 108 | return score 109 | 110 | def cross_modal_score_ent2img(self, data): 111 | batch_h = data['batch_h'] 112 | batch_t = data['batch_t'] 113 | batch_r = data['batch_r'] 114 | h_ent, h_img, t_ent, t_img = batch_h, batch_h, batch_t, batch_t 115 | mode = data['mode'] 116 | h = self.ent_embeddings(h_ent) 117 | r = self.rel_embeddings(batch_r) 118 | t_img_emb = self.img_proj(self.img_embeddings(t_img)) 119 | # in cross-modal link prediction, only the match between h + r and the tail image embedding is considered 120 | score = self._calc(h, t_img_emb, r, mode) 121 | score = self.margin - score 122 | return score 123 | 124 | def predict(self, data): 125 | if self.test_mode == "cmlp": 126 | score = -self.cross_modal_score_ent2img(data) 127 | else: 128 | score = -self.forward(data) 129 | return score.cpu().data.numpy() 130 | 131 | def regularization(self, data): 132 | batch_h = data['batch_h'] 133 | batch_t = data['batch_t'] 134 | batch_r = data['batch_r'] 135 | h = self.ent_embeddings(batch_h) 136 | t = self.ent_embeddings(batch_t) 137 | r = self.rel_embeddings(batch_r) 138 | regul = (torch.mean(h ** 2) + 139 | torch.mean(t ** 2) + 140 | torch.mean(r ** 2)) / 3 141 | return regul 142 | -------------------------------------------------------------------------------- /mmkgc/module/model/Model.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from ..BaseModule import BaseModule 4 | 5 | 6 | class Model(BaseModule): 7 | 8 | def __init__(self, ent_tot, rel_tot): 9 | super(Model, self).__init__() 10 | self.ent_tot = ent_tot 11 | self.rel_tot = rel_tot 12 | 13 | def forward(self): 14 | raise NotImplementedError 15 | 16 | def predict(self): 17 | raise NotImplementedError 18 | -------------------------------------------------------------------------------- /mmkgc/module/model/RSME.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .Model import Model 5 | 6 | class RSME(Model): 7 | def __init__(self, ent_tot, rel_tot, dim=128, img_dim=768, img_emb=None): 8 | super(RSME, self).__init__(ent_tot, rel_tot) 9 | 10 | self.dim = dim 11 | self.img_dim = img_dim 12 | self.ent_re_embeddings = nn.Embedding(self.ent_tot, self.dim) 13 | self.ent_im_embeddings = nn.Embedding(self.ent_tot, self.dim) 14 | self.rel_re_embeddings = nn.Embedding(self.rel_tot, 2 * self.dim) 15 | self.rel_im_embeddings = nn.Embedding(self.rel_tot, 2 * self.dim) 16 | self.img_embeddings = nn.Embedding.from_pretrained(img_emb).requires_grad_(False) 17 | self.img_proj = nn.Linear(img_dim, 2 * dim) 18 | self.beta = 0.95 19 | 20 | nn.init.xavier_uniform_(self.ent_re_embeddings.weight.data) 21 | nn.init.xavier_uniform_(self.ent_im_embeddings.weight.data) 22 | nn.init.xavier_uniform_(self.rel_re_embeddings.weight.data) 23 | nn.init.xavier_uniform_(self.rel_im_embeddings.weight.data) 24 | 25 | def _calc(self, h_re, h_im, t_re, t_im, r_re, r_im): 26 | return torch.sum( 27 | h_re * t_re * r_re 28 | + h_im * t_im * r_re 29 | + h_re * t_im * r_im 30 | - h_im * t_re * r_im, 31 | -1 32 | ) 33 | 34 | def get_batch_ent_embs(self, data): 35 | e_re = self.ent_re_embeddings(data) 36 | e_im = self.ent_im_embeddings(data)  # was ent_re_embeddings, which duplicated the real part
37 | return torch.cat((e_re, e_im), dim=-1) 38 | 39 | def get_fake_score( 40 | self, 41 | batch_h, 42 | batch_r, 43 | batch_t, 44 | mode, 45 | fake_hv=None, 46 | fake_tv=None 47 | ): 48 | if fake_tv is None: 49 | raise NotImplementedError 50 | h_re = self.ent_re_embeddings(batch_h) 51 | h_im = self.ent_im_embeddings(batch_h) 52 | t_re = self.ent_re_embeddings(batch_t) 53 | t_im = self.ent_im_embeddings(batch_t) 54 | r_re = self.rel_re_embeddings(batch_r) 55 | r_im = self.rel_im_embeddings(batch_r) 56 | h_img = self.img_proj(self.img_embeddings(batch_h)) 57 | t_img = fake_tv 58 | h_re = torch.cat((h_re, h_img[:, 0: self.dim]), dim=-1) 59 | h_im = torch.cat((h_im, h_img[:, self.dim:]), dim=-1) 60 | t_re = torch.cat((t_re, t_img[:, 0: self.dim]), dim=-1) 61 | t_im = torch.cat((t_im, t_img[:, self.dim:]), dim=-1) 62 | score1 = self._calc(h_re, h_im, t_re, t_im, r_re, r_im) 63 | score2 = F.cosine_similarity(h_img, t_img, dim=-1) 64 | score = self.beta * score1 + (1 - self.beta) * score2 65 | return [score] 66 | 67 | 68 | def forward(self, data): 69 | batch_h = data['batch_h'] 70 | batch_t = data['batch_t'] 71 | batch_r = data['batch_r'] 72 | h_re = self.ent_re_embeddings(batch_h) 73 | h_im = self.ent_im_embeddings(batch_h) 74 | t_re = self.ent_re_embeddings(batch_t) 75 | t_im = self.ent_im_embeddings(batch_t) 76 | r_re = self.rel_re_embeddings(batch_r) 77 | r_im = self.rel_im_embeddings(batch_r) 78 | 79 | h_img = self.img_proj(self.img_embeddings(batch_h)) 80 | t_img = self.img_proj(self.img_embeddings(batch_t)) 81 | h_re = torch.cat((h_re, h_img[:, 0: self.dim]), dim=-1) 82 | h_im = torch.cat((h_im, h_img[:, self.dim:]), dim=-1) 83 | t_re = torch.cat((t_re, t_img[:, 0: self.dim]), dim=-1) 84 | t_im = torch.cat((t_im, t_img[:, self.dim:]), dim=-1) 85 | score1 = self._calc(h_re, h_im, t_re, t_im, r_re, r_im) 86 | score2 = F.cosine_similarity(h_img, t_img, dim=-1) 87 | score = self.beta * score1 + (1 - self.beta) * score2 88 | return score 89 | 90 | def 
regularization(self, data): 91 | batch_h = data['batch_h'] 92 | batch_t = data['batch_t'] 93 | batch_r = data['batch_r'] 94 | h_re = self.ent_re_embeddings(batch_h) 95 | h_im = self.ent_im_embeddings(batch_h) 96 | t_re = self.ent_re_embeddings(batch_t) 97 | t_im = self.ent_im_embeddings(batch_t) 98 | r_re = self.rel_re_embeddings(batch_r) 99 | r_im = self.rel_im_embeddings(batch_r) 100 | h_img = self.img_proj(self.img_embeddings(batch_h)) 101 | t_img = self.img_proj(self.img_embeddings(batch_t)) 102 | regul = (torch.mean(h_re ** 2) + 103 | torch.mean(h_im ** 2) + 104 | torch.mean(t_re ** 2) + 105 | torch.mean(t_im ** 2) + 106 | torch.mean(r_re ** 2) + 107 | torch.mean(r_im ** 2) + 108 | torch.mean(h_img ** 2) + 109 | torch.mean(t_img ** 2) 110 | ) / 8 111 | return regul 112 | 113 | def predict(self, data): 114 | score = -self.forward(data) 115 | return score.cpu().data.numpy() -------------------------------------------------------------------------------- /mmkgc/module/model/RotatE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.autograd as autograd 3 | import torch.nn as nn 4 | from .Model import Model 5 | 6 | class RotatE(Model): 7 | 8 | def __init__(self, ent_tot, rel_tot, dim = 100, margin = 6.0, epsilon = 2.0): 9 | super(RotatE, self).__init__(ent_tot, rel_tot) 10 | 11 | self.margin = margin 12 | self.epsilon = epsilon 13 | 14 | self.dim_e = dim * 2 15 | self.dim_r = dim 16 | 17 | self.ent_embeddings = nn.Embedding(self.ent_tot, self.dim_e) 18 | self.rel_embeddings = nn.Embedding(self.rel_tot, self.dim_r) 19 | 20 | self.ent_embedding_range = nn.Parameter( 21 | torch.Tensor([(self.margin + self.epsilon) / self.dim_e]), 22 | requires_grad=False 23 | ) 24 | 25 | nn.init.uniform_( 26 | tensor = self.ent_embeddings.weight.data, 27 | a=-self.ent_embedding_range.item(), 28 | b=self.ent_embedding_range.item() 29 | ) 30 | 31 | self.rel_embedding_range = nn.Parameter( 32 | torch.Tensor([(self.margin 
+ self.epsilon) / self.dim_r]), 33 | requires_grad=False 34 | ) 35 | 36 | nn.init.uniform_( 37 | tensor = self.rel_embeddings.weight.data, 38 | a=-self.rel_embedding_range.item(), 39 | b=self.rel_embedding_range.item() 40 | ) 41 | 42 | self.margin = nn.Parameter(torch.Tensor([margin])) 43 | self.margin.requires_grad = False 44 | 45 | def _calc(self, h, t, r, mode): 46 | pi = self.pi_const 47 | 48 | re_head, im_head = torch.chunk(h, 2, dim=-1) 49 | re_tail, im_tail = torch.chunk(t, 2, dim=-1) 50 | 51 | phase_relation = r / (self.rel_embedding_range.item() / pi) 52 | 53 | re_relation = torch.cos(phase_relation) 54 | im_relation = torch.sin(phase_relation) 55 | 56 | re_head = re_head.view(-1, re_relation.shape[0], re_head.shape[-1]).permute(1, 0, 2) 57 | re_tail = re_tail.view(-1, re_relation.shape[0], re_tail.shape[-1]).permute(1, 0, 2) 58 | im_head = im_head.view(-1, re_relation.shape[0], im_head.shape[-1]).permute(1, 0, 2) 59 | im_tail = im_tail.view(-1, re_relation.shape[0], im_tail.shape[-1]).permute(1, 0, 2) 60 | im_relation = im_relation.view(-1, re_relation.shape[0], im_relation.shape[-1]).permute(1, 0, 2) 61 | re_relation = re_relation.view(-1, re_relation.shape[0], re_relation.shape[-1]).permute(1, 0, 2) 62 | 63 | if mode == "head_batch": 64 | re_score = re_relation * re_tail + im_relation * im_tail 65 | im_score = re_relation * im_tail - im_relation * re_tail 66 | re_score = re_score - re_head 67 | im_score = im_score - im_head 68 | else: 69 | re_score = re_head * re_relation - im_head * im_relation 70 | im_score = re_head * im_relation + im_head * re_relation 71 | re_score = re_score - re_tail 72 | im_score = im_score - im_tail 73 | 74 | score = torch.stack([re_score, im_score], dim = 0) 75 | score = score.norm(dim = 0).sum(dim = -1) 76 | return score.permute(1, 0).flatten() 77 | 78 | def forward(self, data): 79 | batch_h = data['batch_h'] 80 | batch_t = data['batch_t'] 81 | batch_r = data['batch_r'] 82 | mode = data['mode'] 83 | h = 
self.ent_embeddings(batch_h) 84 | t = self.ent_embeddings(batch_t) 85 | r = self.rel_embeddings(batch_r) 86 | score = self.margin - self._calc(h, t, r, mode) 87 | return score 88 | 89 | def predict(self, data): 90 | score = -self.forward(data) 91 | return score.cpu().data.numpy() 92 | 93 | def regularization(self, data): 94 | batch_h = data['batch_h'] 95 | batch_t = data['batch_t'] 96 | batch_r = data['batch_r'] 97 | h = self.ent_embeddings(batch_h) 98 | t = self.ent_embeddings(batch_t) 99 | r = self.rel_embeddings(batch_r) 100 | regul = (torch.mean(h ** 2) + 101 | torch.mean(t ** 2) + 102 | torch.mean(r ** 2)) / 3 103 | return regul -------------------------------------------------------------------------------- /mmkgc/module/model/TBKGC.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import time 5 | from .Model import Model 6 | 7 | 8 | class TBKGC(Model): 9 | 10 | def __init__(self, ent_tot, rel_tot, dim=100, p_norm=1, img_emb=None, 11 | img_dim=4096, norm_flag=True, margin=None, epsilon=None, 12 | text_emb=None): 13 | super(TBKGC, self).__init__(ent_tot, rel_tot) 14 | assert img_emb is not None 15 | assert text_emb is not None 16 | self.dim = dim 17 | self.margin = margin 18 | self.epsilon = epsilon 19 | self.norm_flag = norm_flag 20 | self.p_norm = p_norm 21 | self.img_dim = img_dim 22 | self.text_dim = text_emb.shape[1] 23 | self.ent_embeddings = nn.Embedding(self.ent_tot, self.dim) 24 | self.rel_embeddings = nn.Embedding(self.rel_tot, self.dim) 25 | # newly added projection matrices and image/text embeddings 26 | self.img_proj = nn.Linear(self.img_dim, self.dim // 2) 27 | self.img_embeddings = nn.Embedding.from_pretrained(img_emb).requires_grad_(True) 28 | self.text_proj = nn.Linear(self.text_dim, self.dim // 2) 29 | self.text_embeddings = nn.Embedding.from_pretrained(text_emb).requires_grad_(True) 30 | 31 | if margin is None or epsilon is None: 32 | 
nn.init.xavier_uniform_(self.ent_embeddings.weight.data) 33 | nn.init.xavier_uniform_(self.rel_embeddings.weight.data) 34 | else: 35 | self.embedding_range = nn.Parameter( 36 | torch.Tensor([(self.margin + self.epsilon) / self.dim]), requires_grad=False 37 | ) 38 | nn.init.uniform_( 39 | tensor=self.ent_embeddings.weight.data, 40 | a=-self.embedding_range.item(), 41 | b=self.embedding_range.item() 42 | ) 43 | nn.init.uniform_( 44 | tensor=self.rel_embeddings.weight.data, 45 | a=-self.embedding_range.item(), 46 | b=self.embedding_range.item() 47 | ) 48 | 49 | if margin is not None: 50 | self.margin = nn.Parameter(torch.Tensor([margin])) 51 | self.margin.requires_grad = False 52 | self.margin_flag = True 53 | else: 54 | self.margin_flag = False 55 | 56 | def _calc(self, h, t, r, mode): 57 | if self.norm_flag: 58 | h = F.normalize(h, 2, -1) 59 | r = F.normalize(r, 2, -1) 60 | t = F.normalize(t, 2, -1) 61 | if mode != 'normal': 62 | h = h.view(-1, r.shape[0], h.shape[-1]) 63 | t = t.view(-1, r.shape[0], t.shape[-1]) 64 | r = r.view(-1, r.shape[0], r.shape[-1]) 65 | if mode == 'head_batch': 66 | score = h + (r - t) 67 | else: 68 | score = (h + r) - t 69 | score = torch.norm(score, self.p_norm, -1).flatten() 70 | return score 71 | 72 | def get_batch_ent_embs(self, data): 73 | return self.ent_embeddings(data) 74 | 75 | def get_fake_score( 76 | self, 77 | batch_h, 78 | batch_r, 79 | batch_t, 80 | mode, 81 | fake_hv=None, 82 | fake_tv=None 83 | ): 84 | if fake_hv is None or fake_tv is None: 85 | raise NotImplementedError 86 | h = self.ent_embeddings(batch_h) 87 | t = self.ent_embeddings(batch_t) 88 | r = self.rel_embeddings(batch_r) 89 | h_img_emb = self.img_proj(self.img_embeddings(batch_h)) 90 | t_img_emb = self.img_proj(self.img_embeddings(batch_t)) 91 | h_text_emb = self.text_proj(self.text_embeddings(batch_h)) 92 | t_text_emb = self.text_proj(self.text_embeddings(batch_t)) 93 | h_multimodal = torch.cat((h_img_emb, h_text_emb), dim=-1) 94 | t_multimodal = 
torch.cat((t_img_emb, t_text_emb), dim=-1) 95 | # three kinds of fake score 96 | score_hv = ( 97 | self._calc(h, t, r, mode) 98 | + self._calc(fake_hv, t_multimodal, r, mode) 99 | + self._calc(fake_hv, t, r, mode) 100 | + self._calc(h, t_multimodal, r, mode) 101 | + self._calc(h + fake_hv, t + t_multimodal, r, mode) 102 | ) 103 | score_tv = ( 104 | self._calc(h, t, r, mode) 105 | + self._calc(h_multimodal, fake_tv, r, mode) 106 | + self._calc(h_multimodal, t, r, mode) 107 | + self._calc(h, fake_tv, r, mode) 108 | + self._calc(h + h_multimodal, t + fake_tv, r, mode) 109 | ) 110 | score_htv = ( 111 | self._calc(h, t, r, mode) 112 | + self._calc(fake_hv, fake_tv, r, mode) 113 | + self._calc(fake_hv, t, r, mode) 114 | + self._calc(h, fake_tv, r, mode) 115 | + self._calc(h + fake_hv, t + fake_tv, r, mode) 116 | ) 117 | return [self.margin - score_hv, self.margin - score_tv, self.margin - score_htv], [h_multimodal, t_multimodal] 118 | 119 | 120 | def forward(self, data): 121 | batch_h = data['batch_h'] 122 | batch_t = data['batch_t'] 123 | batch_r = data['batch_r'] 124 | h_ent, h_img, t_ent, t_img = batch_h, batch_h, batch_t, batch_t 125 | mode = data['mode'] 126 | h = self.ent_embeddings(h_ent) 127 | t = self.ent_embeddings(t_ent) 128 | r = self.rel_embeddings(batch_r) 129 | h_img_emb = self.img_proj(self.img_embeddings(h_img)) 130 | t_img_emb = self.img_proj(self.img_embeddings(t_img)) 131 | h_text_emb = self.text_proj(self.text_embeddings(h_img)) 132 | t_text_emb = self.text_proj(self.text_embeddings(t_img)) 133 | h_multimodal = torch.cat((h_img_emb, h_text_emb), dim=-1) 134 | t_multimodal = torch.cat((t_img_emb, t_text_emb), dim=-1) 135 | score = ( 136 | self._calc(h, t, r, mode) 137 | + self._calc(h_multimodal, t_multimodal, r, mode) 138 | + self._calc(h_multimodal, t, r, mode) 139 | + self._calc(h, t_multimodal, r, mode) 140 | + self._calc(h + h_multimodal, t + t_multimodal, r, mode) 141 | ) 142 | if self.margin_flag: 143 | return self.margin - score 144 | else: 
145 | return score 146 | 147 | 148 | def regularization(self, data): 149 | batch_h = data['batch_h'] 150 | batch_t = data['batch_t'] 151 | batch_r = data['batch_r'] 152 | h = self.ent_embeddings(batch_h) 153 | t = self.ent_embeddings(batch_t) 154 | r = self.rel_embeddings(batch_r) 155 | regul = (torch.mean(h ** 2) + 156 | torch.mean(t ** 2) + 157 | torch.mean(r ** 2)) / 3 158 | return regul 159 | 160 | def predict(self, data): 161 | score = self.forward(data) 162 | if self.margin_flag: 163 | score = self.margin - score 164 | return score.cpu().data.numpy() 165 | else: 166 | return score.cpu().data.numpy() 167 | 168 | def set_test_mode(self, new_mode): 169 | self.test_mode = new_mode 170 | 171 | -------------------------------------------------------------------------------- /mmkgc/module/model/TransAE.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from .Model import Model 6 | 7 | class IMG_Encoder(nn.Module): 8 | def __init__(self, embedding_dim = 4096, dim = 200, margin = None, epsilon = None, dataset=None): 9 | assert dataset is not None 10 | super(IMG_Encoder, self).__init__() 11 | with open('./benchmarks/{}/entity2id.txt'.format(dataset)) as fp: 12 | entity2id = fp.readlines()[1:] 13 | entity2id = [i.split('\t')[0] for i in entity2id] 14 | self.entity2id = entity2id 15 | self.activation = nn.ReLU() 16 | self.entity_count = len(entity2id) 17 | self.dim = dim 18 | self.margin = margin 19 | self.embedding_dim = embedding_dim 20 | self.criterion = nn.MSELoss(reduction='mean') 21 | self.raw_embedding = nn.Embedding(self.entity_count, self.dim) 22 | visual_embs = torch.load("./embeddings/{}-visual.pth".format(dataset)) 23 | self.visual_embedding = nn.Embedding.from_pretrained(visual_embs) 24 | 25 | self.encoder = nn.Sequential( 26 | torch.nn.Linear(embedding_dim, 192), 27 | self.activation 28 | ) 29 | 30 | self.encoder2 = 
nn.Sequential( 31 | torch.nn.Linear(192, self.dim), 32 | self.activation 33 | ) 34 | 35 | self.decoder2 = nn.Sequential( 36 | torch.nn.Linear(self.dim, 192), 37 | self.activation 38 | ) 39 | 40 | self.decoder = nn.Sequential( 41 | torch.nn.Linear(192, embedding_dim), 42 | self.activation 43 | ) 44 | 45 | def _init_embedding(self): 46 | self.ent_embeddings = nn.Embedding(self.entity_count, self.embedding_dim) 47 | for param in self.ent_embeddings.parameters(): 48 | param.requires_grad = False 49 | nn.init.xavier_uniform_(self.ent_embeddings.weight.data) 50 | 51 | def forward(self, entity_id): 52 | v1 = self.visual_embedding(entity_id) 53 | v2 = self.encoder(v1) 54 | 55 | v2_ = self.encoder2(v2) 56 | v3_ = self.decoder2(v2_) 57 | 58 | v3 = self.decoder(v3_) 59 | loss = self.criterion(v1, v3) 60 | return v2_, loss 61 | 62 | class TransAE(nn.Module): 63 | def __init__(self, ent_tot, rel_tot, dim = 100, p_norm = 1, norm_flag = True, margin = None, epsilon = None, dataset=None, embedding_dim=None): 64 | super(TransAE, self).__init__() 65 | self.dataset = dataset 66 | self.ent_tot = ent_tot 67 | self.rel_tot = rel_tot 68 | self.dim = dim 69 | self.margin = margin 70 | self.epsilon = epsilon 71 | self.norm_flag = norm_flag 72 | self.p_norm = p_norm 73 | 74 | self.tail_embeddings = nn.Embedding(self.ent_tot, self.dim) 75 | self.rel_embeddings = nn.Embedding(self.rel_tot, self.dim) 76 | self.ent_embeddings = IMG_Encoder(dim = self.dim, margin = self.margin, epsilon = self.epsilon, dataset=dataset, embedding_dim=embedding_dim) 77 | 78 | if margin is None or epsilon is None: 79 | nn.init.xavier_uniform_(self.tail_embeddings.weight.data) 80 | nn.init.xavier_uniform_(self.rel_embeddings.weight.data) 81 | else: 82 | self.embedding_range = nn.Parameter( 83 | torch.Tensor([(self.margin + self.epsilon) / self.dim]), requires_grad=False 84 | ) 85 | nn.init.uniform_( 86 | tensor = self.tail_embeddings.weight.data, # ent_embeddings is an IMG_Encoder with no .weight; initialize the entity table instead 87 | a = -self.embedding_range.item(), 88 | b = 
self.embedding_range.item() 89 | ) 90 | nn.init.uniform_( 91 | tensor = self.rel_embeddings.weight.data, 92 | a= -self.embedding_range.item(), 93 | b= self.embedding_range.item() 94 | ) 95 | if margin != None: 96 | self.margin = nn.Parameter(torch.Tensor([margin])) 97 | self.margin.requires_grad = False 98 | self.margin_flag = True 99 | else: 100 | self.margin_flag = False 101 | 102 | def _calc(self, h, t, r, mode): 103 | if self.norm_flag: 104 | h = F.normalize(h, 2, -1) 105 | r = F.normalize(r, 2, -1) 106 | t = F.normalize(t, 2, -1) 107 | if mode != 'normal': 108 | h = h.view(-1, r.shape[0], h.shape[-1]) 109 | t = t.view(-1, r.shape[0], t.shape[-1]) 110 | r = r.view(-1, r.shape[0], r.shape[-1]) 111 | if mode == 'head_batch': 112 | score = h + (r - t) 113 | else: 114 | score = (h + r) - t 115 | score = torch.norm(score, self.p_norm, -1).flatten() 116 | return score 117 | 118 | def forward(self, data): 119 | batch_h = data['batch_h'] 120 | batch_t = data['batch_t'] 121 | batch_r = data['batch_r'] 122 | mode = data['mode'] 123 | h = self.tail_embeddings(batch_h) 124 | t = self.tail_embeddings(batch_t) 125 | r = self.rel_embeddings(batch_r) 126 | score = self._calc(h, t, r, mode) 127 | if self.margin_flag: 128 | return self.margin - score, 0 129 | else: 130 | return score, 0 131 | 132 | def regularization(self, data): 133 | batch_h = data['batch_h'] 134 | batch_t = data['batch_t'] 135 | batch_r = data['batch_r'] 136 | h = self.ent_embeddings(batch_h) 137 | t = self.ent_embeddings(batch_t) 138 | r = self.rel_embeddings(batch_r) 139 | regul = (torch.mean(h ** 2) + 140 | torch.mean(t ** 2) + 141 | torch.mean(r ** 2)) / 3 142 | return regul 143 | 144 | def predict(self, data): 145 | score = self.forward(data)[0] 146 | if self.margin_flag: 147 | score = self.margin - score 148 | return score.cpu().data.numpy() 149 | else: 150 | return score.cpu().data.numpy() 151 | 152 | def load_checkpoint(self, path): 153 | self.load_state_dict(torch.load(os.path.join(path))) 154 | 
self.eval() 155 | 156 | def save_checkpoint(self, path): 157 | torch.save(self.state_dict(), path) -------------------------------------------------------------------------------- /mmkgc/module/model/TransE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .Model import Model 5 | 6 | 7 | class TransE(Model): 8 | 9 | def __init__(self, ent_tot, rel_tot, dim=100, p_norm=1, norm_flag=True, margin=None, epsilon=None): 10 | super(TransE, self).__init__(ent_tot, rel_tot) 11 | 12 | self.dim = dim 13 | self.margin = margin 14 | self.epsilon = epsilon 15 | self.norm_flag = norm_flag 16 | self.p_norm = p_norm 17 | 18 | self.ent_embeddings = nn.Embedding(self.ent_tot, self.dim) 19 | self.rel_embeddings = nn.Embedding(self.rel_tot, self.dim) 20 | 21 | if margin is None or epsilon is None: 22 | nn.init.xavier_uniform_(self.ent_embeddings.weight.data) 23 | nn.init.xavier_uniform_(self.rel_embeddings.weight.data) 24 | else: 25 | self.embedding_range = nn.Parameter( 26 | torch.Tensor([(self.margin + self.epsilon) / self.dim]), requires_grad=False 27 | ) 28 | nn.init.uniform_( 29 | tensor=self.ent_embeddings.weight.data, 30 | a=-self.embedding_range.item(), 31 | b=self.embedding_range.item() 32 | ) 33 | nn.init.uniform_( 34 | tensor=self.rel_embeddings.weight.data, 35 | a=-self.embedding_range.item(), 36 | b=self.embedding_range.item() 37 | ) 38 | 39 | if margin is not None: 40 | self.margin = nn.Parameter(torch.Tensor([margin])) 41 | self.margin.requires_grad = False 42 | self.margin_flag = True 43 | else: 44 | self.margin_flag = False 45 | 46 | def _calc(self, h, t, r, mode): 47 | if self.norm_flag: 48 | h = F.normalize(h, 2, -1) 49 | r = F.normalize(r, 2, -1) 50 | t = F.normalize(t, 2, -1) 51 | if mode != 'normal': 52 | h = h.view(-1, r.shape[0], h.shape[-1]) 53 | t = t.view(-1, r.shape[0], t.shape[-1]) 54 | r = r.view(-1, r.shape[0], r.shape[-1]) 55 | if mode 
== 'head_batch': 56 | score = h + (r - t) 57 | else: 58 | score = (h + r) - t 59 | score = torch.norm(score, self.p_norm, -1).flatten() 60 | return score 61 | 62 | def forward(self, data): 63 | batch_h = data['batch_h'] 64 | batch_t = data['batch_t'] 65 | batch_r = data['batch_r'] 66 | mode = data['mode'] 67 | h = self.ent_embeddings(batch_h) 68 | t = self.ent_embeddings(batch_t) 69 | r = self.rel_embeddings(batch_r) 70 | score = self._calc(h, t, r, mode) 71 | if self.margin_flag: 72 | return self.margin - score 73 | else: 74 | return score 75 | 76 | def regularization(self, data): 77 | batch_h = data['batch_h'] 78 | batch_t = data['batch_t'] 79 | batch_r = data['batch_r'] 80 | h = self.ent_embeddings(batch_h) 81 | t = self.ent_embeddings(batch_t) 82 | r = self.rel_embeddings(batch_r) 83 | regul = (torch.mean(h ** 2) + 84 | torch.mean(t ** 2) + 85 | torch.mean(r ** 2)) / 3 86 | return regul 87 | 88 | def predict(self, data): 89 | score = self.forward(data) 90 | if self.margin_flag: 91 | score = self.margin - score 92 | return score.cpu().data.numpy() 93 | else: 94 | return score.cpu().data.numpy() 95 | -------------------------------------------------------------------------------- /mmkgc/module/model/VBRotatE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.autograd as autograd 3 | import torch.nn as nn 4 | from .Model import Model 5 | 6 | class VBRotatE(Model): 7 | 8 | def __init__(self, ent_tot, rel_tot, dim=100, margin=6.0, epsilon=2.0, img_emb=None, img_dim=4096, test_mode='lp'): 9 | super(VBRotatE, self).__init__(ent_tot, rel_tot) 10 | 11 | self.margin = margin 12 | self.epsilon = epsilon 13 | 14 | self.dim_e = dim * 2 15 | self.dim_r = dim 16 | 17 | self.ent_embeddings = nn.Embedding(self.ent_tot, self.dim_e) 18 | self.rel_embeddings = nn.Embedding(self.rel_tot, self.dim_r) 19 | self.img_dim = img_dim 20 | self.img_proj = nn.Linear(self.img_dim, self.dim_e) 21 | self.img_embeddings = 
nn.Embedding.from_pretrained(img_emb) 22 | self.test_mode = test_mode 23 | 24 | self.ent_embedding_range = nn.Parameter( 25 | torch.Tensor([(self.margin + self.epsilon) / self.dim_e]), 26 | requires_grad=False 27 | ) 28 | 29 | 30 | nn.init.uniform_( 31 | tensor = self.ent_embeddings.weight.data, 32 | a=-self.ent_embedding_range.item(), 33 | b=self.ent_embedding_range.item() 34 | ) 35 | 36 | self.rel_embedding_range = nn.Parameter( 37 | torch.Tensor([(self.margin + self.epsilon) / self.dim_r]), 38 | requires_grad=False 39 | ) 40 | 41 | nn.init.uniform_( 42 | tensor = self.rel_embeddings.weight.data, 43 | a=-self.rel_embedding_range.item(), 44 | b=self.rel_embedding_range.item() 45 | ) 46 | 47 | self.margin = nn.Parameter(torch.Tensor([margin])) 48 | self.margin.requires_grad = False 49 | 50 | def _calc(self, h, t, r, mode): 51 | pi = self.pi_const 52 | 53 | re_head, im_head = torch.chunk(h, 2, dim=-1) 54 | re_tail, im_tail = torch.chunk(t, 2, dim=-1) 55 | 56 | phase_relation = r / (self.rel_embedding_range.item() / pi) 57 | 58 | re_relation = torch.cos(phase_relation) 59 | im_relation = torch.sin(phase_relation) 60 | 61 | re_head = re_head.view(-1, re_relation.shape[0], re_head.shape[-1]).permute(1, 0, 2) 62 | re_tail = re_tail.view(-1, re_relation.shape[0], re_tail.shape[-1]).permute(1, 0, 2) 63 | im_head = im_head.view(-1, re_relation.shape[0], im_head.shape[-1]).permute(1, 0, 2) 64 | im_tail = im_tail.view(-1, re_relation.shape[0], im_tail.shape[-1]).permute(1, 0, 2) 65 | im_relation = im_relation.view(-1, re_relation.shape[0], im_relation.shape[-1]).permute(1, 0, 2) 66 | re_relation = re_relation.view(-1, re_relation.shape[0], re_relation.shape[-1]).permute(1, 0, 2) 67 | 68 | if mode == "head_batch": 69 | re_score = re_relation * re_tail + im_relation * im_tail 70 | im_score = re_relation * im_tail - im_relation * re_tail 71 | re_score = re_score - re_head 72 | im_score = im_score - im_head 73 | else: 74 | re_score = re_head * re_relation - im_head * im_relation 
75 | im_score = re_head * im_relation + im_head * re_relation 76 | re_score = re_score - re_tail 77 | im_score = im_score - im_tail 78 | 79 | score = torch.stack([re_score, im_score], dim = 0) 80 | score = score.norm(dim = 0).sum(dim = -1) 81 | return score.permute(1, 0).flatten() 82 | 83 | def forward(self, data, batch_size, neg_mode='normal', neg_num=1): 84 | h_ent, h_img, t_ent, t_img = None, None, None, None 85 | batch_h = data['batch_h'] 86 | batch_t = data['batch_t'] 87 | batch_r = data['batch_r'] 88 | h_ent, h_img, t_ent, t_img = batch_h, batch_h, batch_t, batch_t 89 | mode = data['mode'] 90 | h = self.ent_embeddings(h_ent) 91 | t = self.ent_embeddings(t_ent) 92 | r = self.rel_embeddings(batch_r) 93 | h_img_emb = self.img_proj(self.img_embeddings(h_img)) 94 | t_img_emb = self.img_proj(self.img_embeddings(t_img)) 95 | # print(h.shape, t.shape, r.shape, h_img_emb.shape, t_img_emb.shape) 96 | score = ( 97 | self._calc(h, t, r, mode) 98 | + self._calc(h_img_emb, t_img_emb, r, mode) 99 | + self._calc(h_img_emb, t, r, mode) 100 | + self._calc(h, t_img_emb, r, mode) 101 | ) 102 | score = self.margin - score 103 | return score 104 | 105 | def cross_modal_score_ent2img(self, data): 106 | batch_h = data['batch_h'] 107 | batch_t = data['batch_t'] 108 | batch_r = data['batch_r'] 109 | h_ent, h_img, t_ent, t_img = batch_h, batch_h, batch_t, batch_t 110 | mode = data['mode'] 111 | h = self.ent_embeddings(h_ent) 112 | r = self.rel_embeddings(batch_r) 113 | t_img_emb = self.img_proj(self.img_embeddings(t_img)) 114 | # in cross-modal link prediction, only the match between h + r and the tail image embedding is considered 115 | score = self._calc(h, t_img_emb, r, mode) 116 | score = self.margin - score 117 | return score 118 | 119 | def predict(self, data): 120 | if self.test_mode == "cmlp": 121 | score = -self.cross_modal_score_ent2img(data) 122 | else: 123 | score = -self.forward(data, batch_size=1, neg_mode='normal') 124 | return score.cpu().data.numpy() 125 | 126 | def regularization(self, data): 127 | batch_h = data['batch_h'] 128 | batch_t = 
data['batch_t'] 129 | batch_r = data['batch_r'] 130 | h = self.ent_embeddings(batch_h) 131 | t = self.ent_embeddings(batch_t) 132 | r = self.rel_embeddings(batch_r) 133 | regul = (torch.mean(h ** 2) + 134 | torch.mean(t ** 2) + 135 | torch.mean(r ** 2)) / 3 136 | return regul 137 | -------------------------------------------------------------------------------- /mmkgc/module/model/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .Model import Model 6 | from .TransE import TransE 7 | from .RotatE import RotatE 8 | from .IKRL import IKRL 9 | from .RSME import RSME 10 | from .EnsembleMMKGE import EnsembleMMKGE 11 | from .EnsembleComplEx import EnsembleComplEx 12 | from .TBKGC import TBKGC 13 | from .AdvMixRotatE import AdvMixRotatE 14 | from .TransAE import TransAE 15 | from .MMKRL import MMKRL 16 | 17 | __all__ = [ 18 | 'Model', 19 | 'TransE', 20 | 'RotatE', 21 | 'IKRL', 22 | 'RSME', 23 | 'TBKGC', 24 | 'EnsembleMMKGE', 25 | 'EnsembleComplEx', 26 | 'AdvMixRotatE', 27 | 'TransAE', 28 | 'MMKRL' 29 | ] 30 | -------------------------------------------------------------------------------- /mmkgc/module/strategy/MMKRLNegativeSampling.py: -------------------------------------------------------------------------------- 1 | from .Strategy import Strategy 2 | 3 | 4 | class MMKRLNegativeSampling(Strategy): 5 | 6 | def __init__(self, model=None, loss=None, batch_size=256, regul_rate=0.0, l3_regul_rate=0.0): 7 | super(MMKRLNegativeSampling, self).__init__() 8 | self.model = model 9 | self.loss = loss 10 | self.batch_size = batch_size 11 | self.regul_rate = regul_rate 12 | self.l3_regul_rate = l3_regul_rate 13 | 14 | 15 | def _get_positive_score(self, score): 16 | positive_score = score[:self.batch_size] 17 | positive_score = positive_score.view(-1, self.batch_size).permute(1, 0) 18 | return positive_score 
19 | 20 | def _get_negative_score(self, score): 21 | negative_score = score[self.batch_size:] 22 | negative_score = negative_score.view(-1, self.batch_size).permute(1, 0) 23 | return negative_score 24 | 25 | def forward(self, data, fast_return=False): 26 | score, ka_loss = self.model(data, mse=True) 27 | p_score = self._get_positive_score(score) 28 | if fast_return: 29 | return p_score 30 | n_score = self._get_negative_score(score) 31 | loss_res = self.loss(p_score, n_score) + ka_loss 32 | if self.regul_rate != 0: 33 | loss_res += self.regul_rate * self.model.regularization(data) 34 | if self.l3_regul_rate != 0: 35 | loss_res += self.l3_regul_rate * self.model.l3_regularization() 36 | return loss_res, p_score 37 | -------------------------------------------------------------------------------- /mmkgc/module/strategy/NegativeSampling.py: -------------------------------------------------------------------------------- 1 | from .Strategy import Strategy 2 | 3 | 4 | class NegativeSampling(Strategy): 5 | 6 | def __init__(self, model=None, loss=None, batch_size=256, regul_rate=0.0, l3_regul_rate=0.0): 7 | super(NegativeSampling, self).__init__() 8 | self.model = model 9 | self.loss = loss 10 | self.batch_size = batch_size 11 | self.regul_rate = regul_rate 12 | self.l3_regul_rate = l3_regul_rate 13 | 14 | 15 | def _get_positive_score(self, score): 16 | positive_score = score[:self.batch_size] 17 | positive_score = positive_score.view(-1, self.batch_size).permute(1, 0) 18 | return positive_score 19 | 20 | def _get_negative_score(self, score): 21 | negative_score = score[self.batch_size:] 22 | negative_score = negative_score.view(-1, self.batch_size).permute(1, 0) 23 | return negative_score 24 | 25 | def forward(self, data, fast_return=False): 26 | score = self.model(data) 27 | p_score = self._get_positive_score(score) 28 | if fast_return: 29 | return p_score 30 | n_score = self._get_negative_score(score) 31 | loss_res = self.loss(p_score, n_score) 32 | if self.regul_rate 
!= 0: 33 | loss_res += self.regul_rate * self.model.regularization(data) 34 | if self.l3_regul_rate != 0: 35 | loss_res += self.l3_regul_rate * self.model.l3_regularization() 36 | return loss_res, p_score 37 | -------------------------------------------------------------------------------- /mmkgc/module/strategy/Strategy.py: -------------------------------------------------------------------------------- 1 | from ..BaseModule import BaseModule 2 | 3 | class Strategy(BaseModule): 4 | 5 | def __init__(self): 6 | super(Strategy, self).__init__() -------------------------------------------------------------------------------- /mmkgc/module/strategy/TransAENegativeSampling.py: -------------------------------------------------------------------------------- 1 | from .Strategy import Strategy 2 | 3 | 4 | class TransAENegativeSampling(Strategy): 5 | 6 | def __init__(self, model=None, loss=None, batch_size=256, regul_rate=0.0, l3_regul_rate=0.0): 7 | super(TransAENegativeSampling, self).__init__() 8 | self.model = model 9 | self.loss = loss 10 | self.batch_size = batch_size 11 | self.regul_rate = regul_rate 12 | self.l3_regul_rate = l3_regul_rate 13 | 14 | 15 | def _get_positive_score(self, score): 16 | positive_score = score[:self.batch_size] 17 | positive_score = positive_score.view(-1, self.batch_size).permute(1, 0) 18 | return positive_score 19 | 20 | def _get_negative_score(self, score): 21 | negative_score = score[self.batch_size:] 22 | negative_score = negative_score.view(-1, self.batch_size).permute(1, 0) 23 | return negative_score 24 | 25 | def forward(self, data, fast_return=False): 26 | score, hloss = self.model(data) 27 | p_score = self._get_positive_score(score) 28 | if fast_return: 29 | return p_score 30 | n_score = self._get_negative_score(score) 31 | loss_res = self.loss(p_score, n_score) + hloss 32 | if self.regul_rate != 0: 33 | loss_res += self.regul_rate * self.model.regularization(data) 34 | if self.l3_regul_rate != 0: 35 | loss_res += 
self.l3_regul_rate * self.model.l3_regularization() 36 | return loss_res, p_score 37 | -------------------------------------------------------------------------------- /mmkgc/module/strategy/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .Strategy import Strategy 6 | from .NegativeSampling import NegativeSampling 7 | from .TransAENegativeSampling import TransAENegativeSampling 8 | from .MMKRLNegativeSampling import MMKRLNegativeSampling 9 | 10 | __all__ = [ 11 | 'Strategy', 12 | 'NegativeSampling', 13 | 'TransAENegativeSampling', 14 | 'MMKRLNegativeSampling' 15 | ] -------------------------------------------------------------------------------- /mmkgc/release/Base.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjukg/AdaMF-MAT/50930e9b28aed57133bedfc281391185ac9b68a4/mmkgc/release/Base.so -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.23.3 2 | scikit_learn==1.1.2 3 | torch==1.9.1+cu111 4 | tqdm==4.64.1 5 | -------------------------------------------------------------------------------- /run_adamf_mat.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import mmkgc 4 | from mmkgc.config import Tester, AdvMixTrainer 5 | from mmkgc.module.model import AdvMixRotatE 6 | from mmkgc.module.loss import SigmoidLoss 7 | from mmkgc.module.strategy import NegativeSampling 8 | from mmkgc.data import TrainDataLoader, TestDataLoader 9 | from mmkgc.adv.modules import MultiGenerator 10 | 11 | from args import get_args 12 | 13 | if __name__ == "__main__": 14 | args = get_args() 15 | print(args) 16 
| # set the seed 17 | torch.manual_seed(args.seed) 18 | torch.cuda.manual_seed_all(args.seed) 19 | # dataloader for training 20 | train_dataloader = TrainDataLoader( 21 | in_path="./benchmarks/" + args.dataset + '/', 22 | batch_size=args.batch_size, 23 | threads=8, 24 | sampling_mode="normal", 25 | bern_flag=1, 26 | filter_flag=1, 27 | neg_ent=args.neg_num, 28 | neg_rel=0 29 | ) 30 | # dataloader for test 31 | test_dataloader = TestDataLoader( 32 | "./benchmarks/" + args.dataset + '/', "link") 33 | img_emb = torch.load('./embeddings/' + args.dataset + '-visual.pth') 34 | text_emb = torch.load('./embeddings/' + args.dataset + '-textual.pth') 35 | # define the model 36 | kge_score = AdvMixRotatE( 37 | ent_tot=train_dataloader.get_ent_tot(), 38 | rel_tot=train_dataloader.get_rel_tot(), 39 | dim=args.dim, 40 | margin=args.margin, 41 | epsilon=2.0, 42 | img_emb=img_emb, 43 | text_emb=text_emb 44 | ) 45 | print(kge_score) 46 | # define the loss function 47 | model = NegativeSampling( 48 | model=kge_score, 49 | loss=SigmoidLoss(adv_temperature=args.adv_temp), 50 | batch_size=train_dataloader.get_batch_size(), 51 | ) 52 | 53 | adv_generator = MultiGenerator( 54 | noise_dim=64, 55 | structure_dim=2*args.dim, 56 | img_dim=2*args.dim 57 | ) 58 | # train the model 59 | trainer = AdvMixTrainer( 60 | model=model, 61 | data_loader=train_dataloader, 62 | train_times=args.epoch, 63 | alpha=args.learning_rate, 64 | use_gpu=True, 65 | opt_method='Adam', 66 | generator=adv_generator, 67 | lrg=args.lrg, 68 | mu=args.mu 69 | ) 70 | 71 | trainer.run() 72 | kge_score.save_checkpoint(args.save) 73 | 74 | # test the model 75 | kge_score.load_checkpoint(args.save) 76 | tester = Tester(model=kge_score, data_loader=test_dataloader, use_gpu=True) 77 | tester.run_link_prediction(type_constrain=False) 78 | -------------------------------------------------------------------------------- /scripts/run_db15k.sh: -------------------------------------------------------------------------------- 1 | 
DATA=DB15K 2 | EMB_DIM=250 3 | NUM_BATCH=1024 4 | MARGIN=12 5 | LR=2e-5 6 | NEG_NUM=128 7 | EPOCH=1000 8 | 9 | CUDA_VISIBLE_DEVICES=0 nohup python run_adamf_mat.py -dataset=$DATA \ 10 | -batch_size=$NUM_BATCH \ 11 | -margin=$MARGIN \ 12 | -epoch=$EPOCH \ 13 | -dim=$EMB_DIM \ 14 | -mu=0 \ 15 | -save=./checkpoint/$DATA-$NUM_BATCH-$EMB_DIM-$NEG_NUM-$MARGIN-$LR-$EPOCH \ 16 | -neg_num=$NEG_NUM \ 17 | -learning_rate=$LR > $DATA-$EMB_DIM-$NUM_BATCH-$NEG_NUM-$MARGIN-$EPOCH.txt & 18 | -------------------------------------------------------------------------------- /scripts/run_mkgw.sh: -------------------------------------------------------------------------------- 1 | DATA=MKG-W 2 | EMB_DIM=200 3 | NUM_BATCH=1024 4 | MARGIN=12 5 | LR=2e-5 6 | NEG_NUM=128 7 | EPOCH=1000 8 | 9 | CUDA_VISIBLE_DEVICES=0 nohup python run_adamf_mat.py -dataset=$DATA \ 10 | -batch_size=$NUM_BATCH \ 11 | -margin=$MARGIN \ 12 | -epoch=$EPOCH \ 13 | -dim=$EMB_DIM \ 14 | -mu=0 \ 15 | -save=./checkpoint/$DATA-$NUM_BATCH-$EMB_DIM-$NEG_NUM-$MARGIN-$LR-$EPOCH \ 16 | -neg_num=$NEG_NUM \ 17 | -learning_rate=$LR > $DATA-$EMB_DIM-$NUM_BATCH-$NEG_NUM-$MARGIN-$EPOCH.txt & 18 | -------------------------------------------------------------------------------- /scripts/run_mkgy.sh: -------------------------------------------------------------------------------- 1 | DATA=MKG-Y 2 | EMB_DIM=200 3 | NUM_BATCH=1024 4 | MARGIN=4 5 | LR=2e-5 6 | NEG_NUM=128 7 | EPOCH=1000 8 | 9 | CUDA_VISIBLE_DEVICES=0 nohup python run_adamf_mat.py -dataset=$DATA \ 10 | -batch_size=$NUM_BATCH \ 11 | -margin=$MARGIN \ 12 | -epoch=$EPOCH \ 13 | -dim=$EMB_DIM \ 14 | -mu=0 \ 15 | -save=./checkpoint/$DATA-$NUM_BATCH-$EMB_DIM-$NEG_NUM-$MARGIN-$LR-$EPOCH \ 16 | -neg_num=$NEG_NUM \ 17 | -learning_rate=$LR > $DATA-$EMB_DIM-$NUM_BATCH-$NEG_NUM-$MARGIN-$EPOCH.txt & 18 | --------------------------------------------------------------------------------
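All three sampling strategies above (`NegativeSampling`, `TransAENegativeSampling`, `MMKRLNegativeSampling`) split the model's flat score vector the same way: the first `batch_size` entries are positive-triple scores and the rest are negatives, regrouped per sample by `view(-1, batch_size).permute(1, 0)`. A minimal pure-Python sketch of that reshaping (the toy scores and the `view_then_permute` helper are illustrative, not part of the repo):

```python
batch_size = 2
neg_num = 3
# Score layout the strategies assume: first `batch_size` scores are positives,
# followed by batch_size * neg_num negative scores, laid out so that column i
# of the (-1, batch_size) reshape belongs to sample i.
score = [10.0, 20.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

def view_then_permute(flat, batch_size):
    # List-based equivalent of tensor.view(-1, batch_size).permute(1, 0):
    # cut the flat list into rows of length batch_size, then transpose.
    rows = [flat[i:i + batch_size] for i in range(0, len(flat), batch_size)]
    return [list(col) for col in zip(*rows)]

pos = view_then_permute(score[:batch_size], batch_size)  # one positive per sample
neg = view_then_permute(score[batch_size:], batch_size)  # neg_num negatives per sample
print(pos)  # [[10.0], [20.0]]
print(neg)  # [[1.0, 3.0, 5.0], [2.0, 4.0, 6.0]]
```

The transpose shows why the negatives must be interleaved by sample in the flat vector: sample 0 ends up with negatives 1.0, 3.0, 5.0 and sample 1 with 2.0, 4.0, 6.0, matching the `(batch_size, neg_num)` shape the margin/sigmoid losses consume.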