├── .gitignore ├── IMG ├── logo.png ├── logo1.png ├── logo2.png └── model.jpg ├── LICENSE ├── README.md ├── base_model.py ├── data.py ├── load_data.py ├── models.py ├── opt.py ├── run.sh ├── run.slurm ├── run_dbp.sh ├── run_dbp.slurm ├── run_oea.sh ├── run_oea.slurm ├── train.py ├── utils.py └── vis.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | __pycache__ 3 | results 4 | *.log 5 | vis -------------------------------------------------------------------------------- /IMG/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyyf2002/ASGEA/57eb3216f273db03e2aab5db1e38de65917b9f37/IMG/logo.png -------------------------------------------------------------------------------- /IMG/logo1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyyf2002/ASGEA/57eb3216f273db03e2aab5db1e38de65917b9f37/IMG/logo1.png -------------------------------------------------------------------------------- /IMG/logo2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyyf2002/ASGEA/57eb3216f273db03e2aab5db1e38de65917b9f37/IMG/logo2.png -------------------------------------------------------------------------------- /IMG/model.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyyf2002/ASGEA/57eb3216f273db03e2aab5db1e38de65917b9f37/IMG/model.jpg -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Luo Yangyifei 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software 
without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | Logo 3 |
4 | 5 | # 🏕️ [ASGEA: Exploiting Logic Rules from Align-Subgraphs for Entity Alignment](https://arxiv.org/abs/2402.11000) 6 | 7 | [![license](https://img.shields.io/github/license/mashape/apistatus.svg?maxAge=2592000)](https://github.com/zjukg/MEAformer/blob/main/licence) 8 | [![arxiv badge](https://img.shields.io/badge/arxiv-2402.11000-red)](https://arxiv.org/abs/2402.11000) 9 | [![Pytorch](https://img.shields.io/badge/PyTorch-%23EE4C2C.svg?e&logo=PyTorch&logoColor=white)](https://pytorch.org/) 10 | 11 | 12 | >This paper proposes the Align-Subgraph Entity Alignment (ASGEA) framework to exploit logic rules from Align-Subgraphs. ASGEA uses anchor links as bridges to construct Align-Subgraphs and spreads along the paths across KGs, which distinguishes it from the embedding-based methods. 13 | 14 |
15 | 16 |
17 | 18 | 19 | ## 🔬 Dependencies 20 | ``` 21 | pytorch 1.12.0 22 | torch_geometric 2.2.0 23 | torch_scatter 2.0.9 24 | transformers 4.26.1 25 | ``` 26 | 27 | ## 🚀 Train 28 | 29 | - **Quick start**: Use the provided script files to train ASGEA-MM. 30 | 31 | ```bash 32 | # FBDB15K & FBYG15K 33 | >> bash run.sh FB 34 | # DBP15K 35 | >> bash run_dbp.sh DBP 36 | # Multi OpenEA 37 | >> bash run_oea.sh OEA 38 | ``` 39 | 40 | - **❗Tip**: If you are using Slurm, you can change the `.sh` file from 41 | 42 | ```bash 43 | datas="FBDB15K FBYG15K" 44 | rates="0.2 0.5 0.8" 45 | expn=$1 46 | if [ ! -d "results/${expn}" ]; then 47 | mkdir results/${expn} 48 | fi 49 | if [ ! -d "results/${expn}/backup" ]; then 50 | mkdir results/${expn}/backup 51 | fi 52 | cp *.py results/${expn}/backup/ 53 | for data in $datas ; do 54 | for rate in $rates ; do 55 | python train.py --data_split norm --n_batch 4 --n_layer 5 --lr 0.001 --data_choice ${data} --data_rate ${rate} --exp_name ${expn} --mm 1 --img_dim 4096 56 | # echo "sbatch -o ${data}_${rate}.log run.slurm 4 5 0.001 ${data} ${rate}" 57 | # sbatch -o ${expn}_${data}_${rate}.log run.slurm 4 5 0.001 ${data} ${rate} ${expn} 58 | done 59 | done 60 | ``` 61 | 62 | to 63 | 64 | ```bash 65 | datas="FBDB15K FBYG15K" 66 | rates="0.2 0.5 0.8" 67 | expn=$1 68 | if [ ! -d "results/${expn}" ]; then 69 | mkdir results/${expn} 70 | fi 71 | if [ ! -d "results/${expn}/backup" ]; then 72 | mkdir results/${expn}/backup 73 | fi 74 | cp *.py results/${expn}/backup/ 75 | for data in $datas ; do 76 | for rate in $rates ; do 77 | echo "sbatch -o ${data}_${rate}.log run.slurm 4 5 0.001 ${data} ${rate}" 78 | sbatch -o ${expn}_${data}_${rate}.log run.slurm 4 5 0.001 ${data} ${rate} ${expn} 79 | done 80 | done 81 | ``` 82 | 83 | - **For ASGEA-Stru**: Pass `--mm 0` instead of `--mm 1` to `train.py`.
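The nested loop in `run.sh` sweeps every dataset / seed-ratio pair, where `--data_rate` is the fraction of aligned entity pairs used for training (see `load_eva_data` in `data.py`). As a rough sketch, the same grid of `train.py` invocations can be generated in Python, which also makes the ASGEA-MM / ASGEA-Stru switch explicit via `--mm`. `build_commands` is a hypothetical helper, not part of the repository; the flag values are copied from `run.sh`, and keeping `--img_dim 4096` when `--mm 0` is an assumption left unchanged here.

```python
import shlex
from itertools import product


def build_commands(expn, mm=1, datas=("FBDB15K", "FBYG15K"), rates=("0.2", "0.5", "0.8")):
    """Return one `python train.py ...` command string per (dataset, rate) pair.

    Hypothetical helper mirroring run.sh: mm=1 trains ASGEA-MM,
    mm=0 trains the structure-only ASGEA-Stru.
    """
    commands = []
    # product() iterates datasets in the outer loop and rates in the inner loop,
    # matching the nested for-loops in run.sh.
    for data, rate in product(datas, rates):
        argv = [
            "python", "train.py",
            "--data_split", "norm",
            "--n_batch", "4",
            "--n_layer", "5",
            "--lr", "0.001",
            "--data_choice", data,
            "--data_rate", rate,
            "--exp_name", expn,
            "--mm", str(mm),
            "--img_dim", "4096",
        ]
        # shlex.join quotes arguments only when the shell would need it.
        commands.append(shlex.join(argv))
    return commands


if __name__ == "__main__":
    for cmd in build_commands("FB", mm=0):
        print(cmd)
```

Each printed line can be run directly or wrapped in an `sbatch` submission, which keeps the sweep definition in one place.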
84 | 85 | 86 | ## 📚 Dataset 87 | ❗NOTE: Download the data from [ufile](https://ufile.io/kzkkfayd) (1.69 GB) and unzip it so that the files match the following hierarchy: 88 | 89 | ``` 90 | ROOT 91 | ├── data 92 | │ └── mmkg 93 | └── ASGEA 94 | ``` 95 | 96 | #### Code Path 97 | 98 |
99 | 👈 🔎 Click 100 | 101 | ``` 102 | ASGEA 103 | ├── base_model.py 104 | ├── data.py 105 | ├── load_data.py 106 | ├── models.py 107 | ├── opt.py 108 | ├── README.md 109 | ├── run.sh 110 | ├── run.slurm 111 | ├── run_dbp.sh 112 | ├── run_dbp.slurm 113 | ├── run_oea.sh 114 | ├── run_oea.slurm 115 | ├── train.py 116 | ├── utils.py 117 | └── vis.py 118 | ``` 119 | 120 |
121 | 122 | #### Data Path 123 |
124 | 👈 🔎 Click 125 | 126 | ``` 127 | mmkg 128 | ├─ DBP15K 129 | │ ├─ fr_en 130 | │ │ ├─ att_features100.npy 131 | │ │ ├─ att_features500.npy 132 | │ │ ├─ att_rel_features100.npy 133 | │ │ ├─ att_rel_features500.npy 134 | │ │ ├─ att_val_features100.npy 135 | │ │ ├─ att_val_features500.npy 136 | │ │ ├─ en_att_triples 137 | │ │ ├─ ent_ids_1 138 | │ │ ├─ ent_ids_2 139 | │ │ ├─ fr_att_triples 140 | │ │ ├─ ill_ent_ids 141 | │ │ ├─ training_attrs_1 142 | │ │ ├─ training_attrs_2 143 | │ │ ├─ triples_1 144 | │ │ └─ triples_2 145 | │ ├─ ja_en 146 | │ │ ├─ att_features100.npy 147 | │ │ ├─ att_features500.npy 148 | │ │ ├─ att_rel_features100.npy 149 | │ │ ├─ att_rel_features500.npy 150 | │ │ ├─ att_val_features100.npy 151 | │ │ ├─ att_val_features500.npy 152 | │ │ ├─ en_att_triples 153 | │ │ ├─ ent_ids_1 154 | │ │ ├─ ent_ids_2 155 | │ │ ├─ ill_ent_ids 156 | │ │ ├─ ja_att_triples 157 | │ │ ├─ training_attrs_1 158 | │ │ ├─ training_attrs_2 159 | │ │ ├─ triples_1 160 | │ │ └─ triples_2 161 | │ ├─ translated_ent_name 162 | │ │ ├─ dbp_fr_en.json 163 | │ │ ├─ dbp_ja_en.json 164 | │ │ └─ dbp_zh_en.json 165 | │ └─ zh_en 166 | │ ├─ att_features100.npy 167 | │ ├─ att_features500.npy 168 | │ ├─ att_rel_features100.npy 169 | │ ├─ att_rel_features500.npy 170 | │ ├─ att_val_features100.npy 171 | │ ├─ att_val_features500.npy 172 | │ ├─ en_att_triples 173 | │ ├─ ent_ids_1 174 | │ ├─ ent_ids_2 175 | │ ├─ ill_ent_ids 176 | │ ├─ rule_test.txt 177 | │ ├─ rule_train.txt 178 | │ ├─ training_attrs_1 179 | │ ├─ training_attrs_2 180 | │ ├─ triples_1 181 | │ ├─ triples_2 182 | │ └─ zh_att_triples 183 | ├─ FBDB15K 184 | │ └─ norm 185 | │ ├─ DB15K_NumericalTriples.txt 186 | │ ├─ FB15K_NumericalTriples.txt 187 | │ ├─ att_features.npy 188 | │ ├─ att_rel_features.npy 189 | │ ├─ att_val_features.npy 190 | │ ├─ ent_ids_1 191 | │ ├─ ent_ids_2 192 | │ ├─ fbid2name.txt 193 | │ ├─ id2relation.txt 194 | │ ├─ ill_ent_ids 195 | │ ├─ training_attrs_1 196 | │ ├─ training_attrs_2 197 | │ ├─ triples_1 198 | │ └─ 
triples_2 199 | ├─ FBYG15K 200 | │ └─ norm 201 | │ ├─ FB15K_NumericalTriples.txt 202 | │ ├─ YAGO15K_NumericalTriples.txt 203 | │ ├─ att_features.npy 204 | │ ├─ att_rel_features.npy 205 | │ ├─ att_val_features.npy 206 | │ ├─ ent_ids_1 207 | │ ├─ ent_ids_2 208 | │ ├─ fbid2name.txt 209 | │ ├─ id2relation.txt 210 | │ ├─ ill_ent_ids 211 | │ ├─ training_attrs_1 212 | │ ├─ training_attrs_2 213 | │ ├─ triples_1 214 | │ └─ triples_2 215 | ├─ MEAformer 216 | ├─ OpenEA 217 | │ ├─ OEA_D_W_15K_V1 218 | │ │ ├─ att_features.npy 219 | │ │ ├─ att_features500.npy 220 | │ │ ├─ att_rel_features.npy 221 | │ │ ├─ att_rel_features500.npy 222 | │ │ ├─ att_val_features.npy 223 | │ │ ├─ att_val_features500.npy 224 | │ │ ├─ attr_triples_1 225 | │ │ ├─ attr_triples_2 226 | │ │ ├─ ent_ids_1 227 | │ │ ├─ ent_ids_2 228 | │ │ ├─ ill_ent_ids 229 | │ │ ├─ rel_ids 230 | │ │ ├─ training_attrs_1 231 | │ │ ├─ training_attrs_2 232 | │ │ ├─ triples_1 233 | │ │ └─ triples_2 234 | │ ├─ OEA_D_W_15K_V2 235 | │ │ ├─ att_features.npy 236 | │ │ ├─ att_features500.npy 237 | │ │ ├─ att_rel_features.npy 238 | │ │ ├─ att_rel_features500.npy 239 | │ │ ├─ att_val_features.npy 240 | │ │ ├─ att_val_features500.npy 241 | │ │ ├─ attr_triples_1 242 | │ │ ├─ attr_triples_2 243 | │ │ ├─ ent_ids_1 244 | │ │ ├─ ent_ids_2 245 | │ │ ├─ ill_ent_ids 246 | │ │ ├─ rel_ids 247 | │ │ ├─ training_attrs_1 248 | │ │ ├─ training_attrs_2 249 | │ │ ├─ triples_1 250 | │ │ └─ triples_2 251 | │ ├─ OEA_D_Y_15K_V1 252 | │ │ ├─ 721_5fold 253 | │ │ │ ├─ 1 254 | │ │ │ │ ├─ test_links 255 | │ │ │ │ ├─ train_links 256 | │ │ │ │ └─ valid_links 257 | │ │ │ ├─ 2 258 | │ │ │ │ ├─ test_links 259 | │ │ │ │ ├─ train_links 260 | │ │ │ │ └─ valid_links 261 | │ │ │ ├─ 3 262 | │ │ │ │ ├─ test_links 263 | │ │ │ │ ├─ train_links 264 | │ │ │ │ └─ valid_links 265 | │ │ │ ├─ 4 266 | │ │ │ │ ├─ test_links 267 | │ │ │ │ ├─ train_links 268 | │ │ │ │ └─ valid_links 269 | │ │ │ └─ 5 270 | │ │ │ ├─ test_links 271 | │ │ │ ├─ train_links 272 | │ │ │ └─ valid_links 273 | │ 
│ ├─ attr_triples_1 274 | │ │ ├─ attr_triples_2 275 | │ │ ├─ ent_ids_1 276 | │ │ ├─ ent_ids_2 277 | │ │ ├─ ent_links 278 | │ │ ├─ ill_ent_ids 279 | │ │ ├─ rel_ids 280 | │ │ ├─ rel_triples_1 281 | │ │ ├─ rel_triples_2 282 | │ │ ├─ triples_1 283 | │ │ └─ triples_2 284 | │ ├─ OEA_D_Y_15K_V2 285 | │ │ ├─ 721_5fold 286 | │ │ │ ├─ 1 287 | │ │ │ │ ├─ test_links 288 | │ │ │ │ ├─ train_links 289 | │ │ │ │ └─ valid_links 290 | │ │ │ ├─ 2 291 | │ │ │ │ ├─ test_links 292 | │ │ │ │ ├─ train_links 293 | │ │ │ │ └─ valid_links 294 | │ │ │ ├─ 3 295 | │ │ │ │ ├─ test_links 296 | │ │ │ │ ├─ train_links 297 | │ │ │ │ └─ valid_links 298 | │ │ │ ├─ 4 299 | │ │ │ │ ├─ test_links 300 | │ │ │ │ ├─ train_links 301 | │ │ │ │ └─ valid_links 302 | │ │ │ └─ 5 303 | │ │ │ ├─ test_links 304 | │ │ │ ├─ train_links 305 | │ │ │ └─ valid_links 306 | │ │ ├─ attr_triples_1 307 | │ │ ├─ attr_triples_2 308 | │ │ ├─ ent_ids_1 309 | │ │ ├─ ent_ids_2 310 | │ │ ├─ ent_links 311 | │ │ ├─ ill_ent_ids 312 | │ │ ├─ rel_ids 313 | │ │ ├─ rel_triples_1 314 | │ │ ├─ rel_triples_2 315 | │ │ ├─ triples_1 316 | │ │ └─ triples_2 317 | │ ├─ OEA_EN_DE_15K_V1 318 | │ │ ├─ att_features.npy 319 | │ │ ├─ att_features500.npy 320 | │ │ ├─ att_rel_features.npy 321 | │ │ ├─ att_rel_features500.npy 322 | │ │ ├─ att_val_features.npy 323 | │ │ ├─ att_val_features500.npy 324 | │ │ ├─ attr_triples_1 325 | │ │ ├─ attr_triples_2 326 | │ │ ├─ ent_ids_1 327 | │ │ ├─ ent_ids_2 328 | │ │ ├─ ill_ent_ids 329 | │ │ ├─ rel_ids 330 | │ │ ├─ training_attrs_1 331 | │ │ ├─ training_attrs_2 332 | │ │ ├─ triples_1 333 | │ │ └─ triples_2 334 | │ ├─ OEA_EN_DE_15K_V2 335 | │ │ ├─ 721_5fold 336 | │ │ │ ├─ 1 337 | │ │ │ │ ├─ test_links 338 | │ │ │ │ ├─ train_links 339 | │ │ │ │ └─ valid_links 340 | │ │ │ ├─ 2 341 | │ │ │ │ ├─ test_links 342 | │ │ │ │ ├─ train_links 343 | │ │ │ │ └─ valid_links 344 | │ │ │ ├─ 3 345 | │ │ │ │ ├─ test_links 346 | │ │ │ │ ├─ train_links 347 | │ │ │ │ └─ valid_links 348 | │ │ │ ├─ 4 349 | │ │ │ │ ├─ test_links 350 | │ │ │ │ 
├─ train_links 351 | │ │ │ │ └─ valid_links 352 | │ │ │ └─ 5 353 | │ │ │ ├─ test_links 354 | │ │ │ ├─ train_links 355 | │ │ │ └─ valid_links 356 | │ │ ├─ attr_triples_1 357 | │ │ ├─ attr_triples_2 358 | │ │ ├─ ent_ids_1 359 | │ │ ├─ ent_ids_2 360 | │ │ ├─ ent_links 361 | │ │ ├─ ill_ent_ids 362 | │ │ ├─ rel_ids 363 | │ │ ├─ rel_triples_1 364 | │ │ ├─ rel_triples_2 365 | │ │ ├─ triples_1 366 | │ │ └─ triples_2 367 | │ ├─ OEA_EN_FR_15K_V1 368 | │ │ ├─ att_features.npy 369 | │ │ ├─ att_rel_features.npy 370 | │ │ ├─ att_val_features.npy 371 | │ │ ├─ attr_triples_1 372 | │ │ ├─ attr_triples_2 373 | │ │ ├─ ent_ids_1 374 | │ │ ├─ ent_ids_2 375 | │ │ ├─ ill_ent_ids 376 | │ │ ├─ rel_ids 377 | │ │ ├─ training_attrs_1 378 | │ │ ├─ training_attrs_2 379 | │ │ ├─ triples_1 380 | │ │ └─ triples_2 381 | │ ├─ OEA_EN_FR_15K_V2 382 | │ │ ├─ 721_5fold 383 | │ │ │ ├─ 1 384 | │ │ │ │ ├─ test_links 385 | │ │ │ │ ├─ train_links 386 | │ │ │ │ └─ valid_links 387 | │ │ │ ├─ 2 388 | │ │ │ │ ├─ test_links 389 | │ │ │ │ ├─ train_links 390 | │ │ │ │ └─ valid_links 391 | │ │ │ ├─ 3 392 | │ │ │ │ ├─ test_links 393 | │ │ │ │ ├─ train_links 394 | │ │ │ │ └─ valid_links 395 | │ │ │ ├─ 4 396 | │ │ │ │ ├─ test_links 397 | │ │ │ │ ├─ train_links 398 | │ │ │ │ └─ valid_links 399 | │ │ │ └─ 5 400 | │ │ │ ├─ test_links 401 | │ │ │ ├─ train_links 402 | │ │ │ └─ valid_links 403 | │ │ ├─ attr_triples_1 404 | │ │ ├─ attr_triples_2 405 | │ │ ├─ ent_ids_1 406 | │ │ ├─ ent_ids_2 407 | │ │ ├─ ent_links 408 | │ │ ├─ ill_ent_ids 409 | │ │ ├─ rel_ids 410 | │ │ ├─ rel_triples_1 411 | │ │ ├─ rel_triples_2 412 | │ │ ├─ triples_1 413 | │ │ └─ triples_2 414 | │ ├─ pkl 415 | │ │ ├─ OEA_D_W_15K_V1_id_img_feature_dict.pkl 416 | │ │ ├─ OEA_D_W_15K_V2_id_img_feature_dict.pkl 417 | │ │ ├─ OEA_EN_DE_15K_V1_id_img_feature_dict.pkl 418 | │ │ └─ OEA_EN_FR_15K_V1_id_img_feature_dict.pkl 419 | │ └─ data.py 420 | ├─ dump 421 | ├─ embedding 422 | │ ├─ dbp_fr_en_char.pkl 423 | │ ├─ dbp_fr_en_name.pkl 424 | │ ├─ dbp_ja_en_char.pkl 425 | │ 
├─ dbp_ja_en_name.pkl 426 | │ ├─ dbp_zh_en_char.pkl 427 | │ ├─ dbp_zh_en_name.pkl 428 | │ └─ glove.6B.300d.txt 429 | └─ pkls 430 | ├─ FBDB15K_id_img_feature_dict.pkl 431 | ├─ FBYG15K_id_img_feature_dict.pkl 432 | ├─ dbpedia_wikidata_15k_dense_GA_id_img_feature_dict.pkl 433 | ├─ dbpedia_wikidata_15k_norm_GA_id_img_feature_dict.pkl 434 | ├─ fr_en_GA_id_img_feature_dict.pkl 435 | ├─ ja_en_GA_id_img_feature_dict.pkl 436 | └─ zh_en_GA_id_img_feature_dict.pkl 437 | ``` 438 | 439 |
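`load_eva_data` in `data.py` builds its input paths from `--data_choice` and `--data_split`, so a misplaced file only surfaces as a runtime error. A minimal pre-flight check can list which expected files are missing. `expected_files` and `missing_files` are hypothetical helpers, not part of the repository. Their file list is inferred from the `load_eva_data`, `read_raw_data`, and `get_ids` code in this repository and covers only the FB and DBP15K datasets (OpenEA is omitted). The `../data/mmkg` location assumes the scripts run from `ROOT/ASGEA` with the layout shown above.

```python
import os.path as osp

# Files read by read_raw_data() for every dataset.
CORE_FILES = ["ent_ids_1", "ent_ids_2", "ill_ent_ids", "triples_1", "triples_2"]


def expected_files(data_path, data_choice, data_split, mm=True):
    """List the paths load_eva_data() is expected to open (FB* and DBP15K only)."""
    file_dir = osp.join(data_path, data_choice, data_split)
    files = [osp.join(file_dir, name) for name in CORE_FILES]
    if data_choice in ("FBDB15K", "FBYG15K"):
        right_kg = "DB15K" if "DB" in data_choice else "YAGO15K"
        # Entity names, relation names, and numerical attribute triples.
        for name in ("fbid2name.txt", "id2relation.txt",
                     "FB15K_NumericalTriples.txt", f"{right_kg}_NumericalTriples.txt"):
            files.append(osp.join(file_dir, name))
        img = f"{data_choice}_id_img_feature_dict.pkl"
    elif data_choice == "DBP15K":
        # e.g. data_split "fr_en" -> fr_att_triples and en_att_triples
        for lang in data_split.split("_"):
            files.append(osp.join(file_dir, f"{lang}_att_triples"))
        img = f"{data_split}_GA_id_img_feature_dict.pkl"
    else:
        raise ValueError(f"unsupported data_choice: {data_choice}")
    if mm:
        # Image features are only loaded for the multi-modal variant (--mm 1).
        files.append(osp.join(data_path, "pkls", img))
    return files


def missing_files(data_path, data_choice, data_split, mm=True):
    """Return the expected paths that do not exist on disk."""
    return [p for p in expected_files(data_path, data_choice, data_split, mm)
            if not osp.exists(p)]


if __name__ == "__main__":
    data_path = osp.join("..", "data", "mmkg")  # assumed location
    for choice, split in [("FBDB15K", "norm"), ("FBYG15K", "norm"), ("DBP15K", "fr_en")]:
        missing = missing_files(data_path, choice, split)
        print(choice, split, "OK" if not missing else missing)
```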
440 | 441 | ## 🤝 Cite: 442 | 443 | Please consider citing this paper if you use the `code` or `data` from our work. 444 | Thanks a lot :) 445 | ``` 446 | @article{DBLP:journals/corr/abs-2402-11000, 447 | author = {Yangyifei Luo and 448 | Zhuo Chen and 449 | Lingbing Guo and 450 | Qian Li and 451 | Wenxuan Zeng and 452 | Zhixin Cai and 453 | Jianxin Li}, 454 | title = {{ASGEA:} Exploiting Logic Rules from Align-Subgraphs for Entity Alignment}, 455 | journal = {CoRR}, 456 | volume = {abs/2402.11000}, 457 | year = {2024} 458 | } 459 | ``` 460 | -------------------------------------------------------------------------------- /base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import time 4 | from tqdm import tqdm 5 | from torch.optim import Adam 6 | from torch.optim.lr_scheduler import ExponentialLR 7 | from models import MASGNN 8 | from utils import cal_ranks, cal_performance 9 | 10 | class BaseModel(object): 11 | def __init__(self, args, loader): 12 | self.model = MASGNN(args, loader) 13 | self.model.cuda() 14 | 15 | self.loader = loader 16 | self.n_ent = loader.n_ent 17 | self.n_batch = args.n_batch 18 | self.n_rel = loader.n_rel 19 | self.left_ents = loader.left_ents 20 | self.right_ents = loader.right_ents 21 | self.shuffle = args.shuffle 22 | 23 | self.n_train = loader.n_train 24 | # self.n_valid = loader.n_valid 25 | self.n_test = loader.n_test 26 | self.n_layer = args.n_layer 27 | 28 | self.optimizer = Adam(self.model.parameters(), lr=args.lr, weight_decay=args.lamb) 29 | self.scheduler = ExponentialLR(self.optimizer, args.decay_rate) 30 | self.t_time = 0 31 | 32 | def train_batch(self,): 33 | epoch_loss = 0 34 | i = 0 35 | 36 | batch_size = self.n_batch 37 | n_batch = self.n_train // batch_size + (self.n_train % batch_size > 0) 38 | if self.shuffle: 39 | self.loader.shuffle_train() 40 | 41 | t_time = time.time() 42 | self.model.train() 43 | for i in
tqdm(range(n_batch)): 44 | start = i*batch_size 45 | end = min(self.n_train, (i+1)*batch_size) 46 | batch_idx = np.arange(start, end) 47 | triple = self.loader.get_batch(batch_idx) 48 | 49 | self.model.zero_grad() 50 | scores = self.model(triple[:,0]) 51 | 52 | pos_scores = scores[[torch.arange(len(scores)).cuda(),torch.LongTensor(triple[:,2]).cuda()]] 53 | max_n = torch.max(scores, 1, keepdim=True)[0] 54 | loss = torch.sum(- pos_scores + max_n + torch.log(torch.sum(torch.exp(scores - max_n),1))) 55 | # gamma = 0.1 56 | # lambd = 1 57 | # tau = 1 58 | # max_n = torch.max(scores, 1, keepdim=True)[0] 59 | # scores = max_n - scores 60 | # pos_scores = scores[[torch.arange(len(scores)).cuda(), torch.LongTensor(triple[:, 2]).cuda()]] 61 | # # extend pos_scores to scores 62 | # pos_scores = pos_scores.unsqueeze(-1) 63 | # l = gamma + pos_scores - scores 64 | # ln = (l - l.mean(dim=-1, keepdim=True).detach()) / l.std(dim=-1, keepdim=True).detach() 65 | # # ln = (l - mu) / torch.sqrt(sig + 1e-6) 66 | # loss = torch.sum(torch.log(1 + torch.sum(torch.exp(lambd * ln + tau), 1))) 67 | 68 | loss.backward() 69 | self.optimizer.step() 70 | 71 | # avoid NaN 72 | for p in self.model.parameters(): 73 | X = p.data.clone() 74 | flag = X != X 75 | X[flag] = np.random.random() 76 | p.data.copy_(X) 77 | epoch_loss += loss.item() 78 | self.scheduler.step() 79 | self.t_time += time.time() - t_time 80 | 81 | t_mrr,t_h1, t_h3, t_h5, t_h10, out_str = self.evaluate() 82 | return t_mrr,t_h1, t_h3, t_h5, t_h10, out_str 83 | 84 | def evaluate(self, ): 85 | batch_size = self.n_batch 86 | i_time = time.time() 87 | n_data = self.n_test 88 | n_batch = n_data // batch_size + (n_data % batch_size > 0) 89 | ranking = [] 90 | self.model.eval() 91 | for i in range(n_batch): 92 | start = i*batch_size 93 | end = min(n_data, (i+1)*batch_size) 94 | batch_idx = np.arange(start, end) 95 | triple = self.loader.get_batch(batch_idx, data='test') 96 | subs, rels, objs = triple[:,0],triple[:,1],triple[:,2] 97 | 
is_lefts = rels == self.n_rel*2+1 98 | scores = self.model(subs,'test').data.cpu().numpy() 99 | 100 | ranks = cal_ranks(scores, objs, is_lefts, len(self.left_ents)) 101 | ranking += ranks 102 | ranking = np.array(ranking) 103 | t_mrr, t_h1, t_h3, t_h5, t_h10 = cal_performance(ranking) 104 | i_time = time.time() - i_time 105 | 106 | out_str = '[TEST] MRR:%.4f H@1:%.4f H@3:%.4f H@5:%.4f H@10:%.4f \t[TIME] inference:%.4f\n' % (t_mrr, t_h1, t_h3, t_h5, t_h10, i_time) 107 | return t_mrr,t_h1, t_h3, t_h5, t_h10, out_str 108 | 109 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import json 4 | import numpy as np 5 | import pdb 6 | import torch.distributed as dist 7 | import os 8 | import os.path as osp 9 | from collections import Counter 10 | import pickle 11 | import torch.nn.functional as F 12 | from transformers import BertTokenizer 13 | import torch.distributed 14 | from tqdm import tqdm 15 | import re 16 | 17 | from utils import get_topk_indices, get_adjr 18 | 19 | 20 | class EADataset(torch.utils.data.Dataset): 21 | def __init__(self, data): 22 | self.data = data 23 | 24 | def __len__(self): 25 | return len(self.data) 26 | 27 | def __getitem__(self, index): 28 | return self.data[index] 29 | 30 | 31 | class Collator_base(object): 32 | def __init__(self, args): 33 | self.args = args 34 | 35 | def __call__(self, batch): 36 | # pdb.set_trace() 37 | 38 | return np.array(batch) 39 | 40 | 41 | # def load_data(logger, args): 42 | # assert args.data_choice in ["DWY", "DBP15K", "FBYG15K", "FBDB15K"] 43 | # if args.data_choice in ["DWY", "DBP15K", "FBYG15K", "FBDB15K"]: 44 | # KGs, non_train, train_ill, test_ill, eval_ill, test_ill_ = load_eva_data(logger, args) 45 | # 46 | # elif args.data_choice in ["FBYG15K_attr", "FBDB15K_attr"]: 47 | # pass 48 | # 49 | # return KGs, non_train, train_ill, test_ill, eval_ill, 
test_ill_ 50 | # 51 | 52 | 53 | def load_eva_data(args): 54 | if "OEA" in args.data_choice: 55 | file_dir = osp.join(args.data_path, "OpenEA", args.data_choice) 56 | else: 57 | file_dir = osp.join(args.data_path, args.data_choice, args.data_split) 58 | lang_list = [1, 2] 59 | ent2id_dict, ills, triples, r_hs, r_ts, ids = read_raw_data(file_dir, lang_list) 60 | e1 = os.path.join(file_dir, 'ent_ids_1') 61 | e2 = os.path.join(file_dir, 'ent_ids_2') 62 | left_ents,left_id2name = get_ids(e1,file_dir) 63 | right_ents,right_id2name = get_ids(e2,file_dir) 64 | id2name = {**left_id2name, **right_id2name} 65 | if not args.data_choice == "DBP15K" and not args.data_choice == "OpenEA": 66 | id2rel = get_id2rel(os.path.join(file_dir, 'id2relation.txt')) 67 | elif args.data_choice == "OpenEA": 68 | id2rel = get_id2rel(os.path.join(file_dir, 'rel_ids')) 69 | else: 70 | id2rel = None 71 | ENT_NUM = len(ent2id_dict) 72 | REL_NUM = len(r_hs) 73 | np.random.shuffle(ills) 74 | if args.mm: 75 | if args.data_choice == "OpenEA": 76 | img_vec_path = osp.join(args.data_path, f"OpenEA/pkl/{args.data_split}_id_img_feature_dict.pkl") 77 | elif "FB" in file_dir: 78 | img_vec_path = osp.join(args.data_path, f"pkls/{args.data_choice}_id_img_feature_dict.pkl") 79 | else: 80 | # fr_en 81 | split = file_dir.split("/")[-1] 82 | img_vec_path = osp.join(args.data_path, "pkls", args.data_split + "_GA_id_img_feature_dict.pkl") 83 | 84 | assert osp.exists(img_vec_path) 85 | img_features = load_img(ENT_NUM, img_vec_path) 86 | print(f"image feature shape:{img_features.shape}") 87 | 88 | if args.word_embedding == "glove": 89 | word2vec_path = os.path.join(args.data_path, "embedding", "glove.6B.300d.txt") 90 | elif args.word_embedding == 'bert': 91 | pass 92 | else: 93 | raise Exception("error word embedding") 94 | else: 95 | img_features = None 96 | 97 | name_features = None 98 | char_features = None 99 | # if args.data_choice == "DBP15K" and (args.w_name or args.w_char): 100 | 101 | # assert 
osp.exists(word2vec_path) 102 | # ent_vec, char_features = load_word_char_features(ENT_NUM, word2vec_path, args) 103 | # name_features = F.normalize(torch.Tensor(ent_vec)) 104 | # char_features = F.normalize(torch.Tensor(char_features)) 105 | # print(f"name feature shape:{name_features.shape}") 106 | # print(f"char feature shape:{char_features.shape}") 107 | img_ill = None 108 | if args.mm: 109 | input_features = F.normalize(torch.Tensor(img_features)) 110 | img_ill = visual_pivot_induction(args, left_ents, right_ents, input_features, ills) 111 | 112 | train_ill = np.array(ills[:int(len(ills) // 1 * args.data_rate)], dtype=np.int32) 113 | 114 | test_ill_ = ills[int(len(ills) // 1 * args.data_rate):] 115 | test_ill = np.array(test_ill_, dtype=np.int32) 116 | 117 | test_left = torch.LongTensor(test_ill[:, 0].squeeze()) 118 | test_right = torch.LongTensor(test_ill[:, 1].squeeze()) 119 | 120 | left_non_train = list(set(left_ents) - set(train_ill[:, 0].tolist())) 121 | 122 | right_non_train = list(set(right_ents) - set(train_ill[:, 1].tolist())) 123 | 124 | print(f"#left entity : {len(left_ents)}, #right entity: {len(right_ents)}") 125 | print(f"#left entity not in train set: {len(left_non_train)}, #right entity not in train set: {len(right_non_train)}") 126 | 127 | rel_features = load_relation(ENT_NUM, triples, 1000) 128 | print(f"relation feature shape:{rel_features.shape}") 129 | if 'OpenEA' in args.data_choice: 130 | a1 = os.path.join(file_dir, f'attr_triples_1') 131 | a2 = os.path.join(file_dir, f'attr_triples_2') 132 | att_features, num_att_left, num_att_right = load_attr_withNums(['oea', 'oea'], [a1, a2], ent2id_dict, file_dir, 133 | topk=args.topk) 134 | elif 'FB' in args.data_choice: 135 | a1 = os.path.join(file_dir, 'FB15K_NumericalTriples.txt') 136 | a2 = os.path.join(file_dir, 'DB15K_NumericalTriples.txt') if 'DB' in args.data_choice else os.path.join(file_dir, 'YAGO15K_NumericalTriples.txt') 137 | att_features, num_att_left, num_att_right = 
load_attr_withNums(['FB15K','DB15K'] if 'DB' in args.data_choice else ['FB15K','YAGO15K'],[a1, a2], ent2id_dict, file_dir, topk=0) 138 | else: 139 | att1,att2 = args.data_split.split('_') 140 | a1 = os.path.join(file_dir, f'{att1}_att_triples') 141 | a2 = os.path.join(file_dir, f'{att2}_att_triples') 142 | att_features, num_att_left, num_att_right = load_attr_withNums([att1,att2],[a1, a2], ent2id_dict, file_dir, topk=args.topk) 143 | print(f"attribute feature shape:{len(att_features)}") 144 | print("-----dataset summary-----") 145 | print(f"dataset:\t\t {file_dir}") 146 | print(f"triple num:\t {len(triples)}") 147 | print(f"entity num:\t {ENT_NUM}") 148 | print(f"relation num:\t {REL_NUM}") 149 | print(f"train ill num:\t {train_ill.shape[0]} \t test ill num:\t {test_ill.shape[0]}") 150 | print("-------------------------") 151 | 152 | eval_ill = None 153 | input_idx = torch.LongTensor(np.arange(ENT_NUM)) 154 | 155 | # pdb.set_trace() 156 | # train_ill = EADataset(train_ill) 157 | # test_ill = EADataset(test_ill) 158 | 159 | return { 160 | 'ent_num': ENT_NUM, 161 | 'rel_num': REL_NUM, 162 | 'images_list': img_features, 163 | 'rel_features': rel_features, 164 | 'att_features': att_features, 165 | 'num_att_left': num_att_left, 166 | 'num_att_right': num_att_right, 167 | 'name_features': name_features, 168 | 'char_features': char_features, 169 | 'input_idx': input_idx, 170 | 'triples': triples, 171 | 'id2name':id2name, 172 | 'id2rel':id2rel, 173 | 'img_ill':img_ill 174 | }, {"left": left_non_train, "right": right_non_train},left_ents,right_ents, train_ill, test_ill, eval_ill, test_ill_ 175 | 176 | 177 | def load_word2vec(path, dim=300): 178 | """ 179 | glove or fasttext embedding 180 | """ 181 | # print('\n', path) 182 | word2vec = dict() 183 | err_num = 0 184 | err_list = [] 185 | 186 | with open(path, 'r', encoding='utf-8') as file: 187 | for line in tqdm(file.readlines(), desc="load word embedding"): 188 | line = line.strip('\n').split(' ') 189 | if len(line) != dim 
+ 1: 190 | continue 191 | try: 192 | v = np.array(list(map(float, line[1:])), dtype=np.float64) 193 | word2vec[line[0].lower()] = v 194 | except: 195 | err_num += 1 196 | err_list.append(line[0]) 197 | continue 198 | file.close() 199 | print("err list ", err_list) 200 | print("err num ", err_num) 201 | return word2vec 202 | 203 | 204 | def load_char_bigram(path): 205 | """ 206 | character bigrams of translated entity names 207 | """ 208 | # load the translated entity names 209 | ent_names = json.load(open(path, "r")) 210 | # generate the bigram dictionary 211 | char2id = {} 212 | count = 0 213 | for _, name in ent_names: 214 | for word in name: 215 | word = word.lower() 216 | for idx in range(len(word) - 1): 217 | if word[idx:idx + 2] not in char2id: 218 | char2id[word[idx:idx + 2]] = count 219 | count += 1 220 | return ent_names, char2id 221 | 222 | 223 | def load_word_char_features(node_size, word2vec_path, args): 224 | """ 225 | node_size : ent num 226 | """ 227 | name_path = os.path.join(args.data_path, "DBP15K", "translated_ent_name", "dbp_" + args.data_split + ".json") 228 | assert osp.exists(name_path) 229 | save_path_name = os.path.join(args.data_path, "embedding", f"dbp_{args.data_split}_name.pkl") 230 | save_path_char = os.path.join(args.data_path, "embedding", f"dbp_{args.data_split}_char.pkl") 231 | if osp.exists(save_path_name) and osp.exists(save_path_char): 232 | print(f"load entity name emb from {save_path_name} ... ") 233 | ent_vec = pickle.load(open(save_path_name, "rb")) 234 | print(f"load entity char emb from {save_path_char} ... 
") 235 | char_vec = pickle.load(open(save_path_char, "rb")) 236 | return ent_vec, char_vec 237 | 238 | word_vecs = load_word2vec(word2vec_path) 239 | ent_names, char2id = load_char_bigram(name_path) 240 | 241 | # generate the word-level features and char-level features 242 | 243 | ent_vec = np.zeros((node_size, 300)) 244 | char_vec = np.zeros((node_size, len(char2id))) 245 | for i, name in ent_names: 246 | k = 0 247 | for word in name: 248 | word = word.lower() 249 | if word in word_vecs: 250 | ent_vec[i] += word_vecs[word] 251 | k += 1 252 | for idx in range(len(word) - 1): 253 | char_vec[i, char2id[word[idx:idx + 2]]] += 1 254 | if k: 255 | ent_vec[i] /= k 256 | else: 257 | ent_vec[i] = np.random.random(300) - 0.5 258 | 259 | if np.sum(char_vec[i]) == 0: 260 | char_vec[i] = np.random.random(len(char2id)) - 0.5 261 | ent_vec[i] = ent_vec[i] / np.linalg.norm(ent_vec[i]) 262 | char_vec[i] = char_vec[i] / np.linalg.norm(char_vec[i]) 263 | 264 | with open(save_path_name, 'wb') as f: 265 | pickle.dump(ent_vec, f) 266 | with open(save_path_char, 'wb') as f: 267 | pickle.dump(char_vec, f) 268 | print("save entity emb done. 
") 269 | return ent_vec, char_vec 270 | 271 | 272 | def visual_pivot_induction(args, left_ents, right_ents, img_features, ills): 273 | 274 | l_img_f = img_features[left_ents] # left images 275 | r_img_f = img_features[right_ents] # right images 276 | 277 | img_sim = l_img_f.mm(r_img_f.t()) 278 | topk = args.img_ill_k 279 | two_d_indices = get_topk_indices(img_sim, topk * 100) 280 | del l_img_f, r_img_f, img_sim 281 | 282 | visual_links = [] 283 | used_inds = [] 284 | count = 0 285 | for ind in two_d_indices: 286 | if left_ents[ind[0]] in used_inds: 287 | continue 288 | if right_ents[ind[1]] in used_inds: 289 | continue 290 | used_inds.append(left_ents[ind[0]]) 291 | used_inds.append(right_ents[ind[1]]) 292 | visual_links.append((left_ents[ind[0]], right_ents[ind[1]])) 293 | count += 1 294 | if count == topk: 295 | break 296 | 297 | count = 0.0 298 | for link in visual_links: 299 | if link in ills: 300 | count = count + 1 301 | print(f"{(count / len(visual_links) * 100):.2f}% in true links") 302 | print(f"visual links length: {(len(visual_links))}") 303 | train_ill = np.array(visual_links, dtype=np.int32) 304 | return train_ill 305 | 306 | 307 | def read_raw_data(file_dir, lang=[1, 2]): 308 | """ 309 | Read DBP15k/DWY15k dataset. 310 | Parameters 311 | ---------- 312 | file_dir: root of the dataset. 
313 | Returns 314 | ------- 315 | ent2id_dict : A dict mapping from entity name to ids 316 | ills: inter-lingual links (specified by ids) 317 | triples: a list of tuples (ent_id_1, relation_id, ent_id_2) 318 | r_hs: a dictionary containing mappings of relations to a list of entities that are head entities of the relation 319 | r_ts: a dictionary containing mappings of relations to a list of entities that are tail entities of the relation 320 | ids: all ids as a list 321 | """ 322 | print('loading raw data...') 323 | 324 | def read_file(file_paths): 325 | tups = [] 326 | for file_path in file_paths: 327 | with open(file_path, "r", encoding="utf-8") as fr: 328 | for line in fr: 329 | params = line.strip("\n").split("\t") 330 | tups.append(tuple([int(x) for x in params])) 331 | return tups 332 | 333 | def read_dict(file_paths): 334 | ent2id_dict = {} 335 | ids = [] 336 | for file_path in file_paths: 337 | id = set() 338 | with open(file_path, "r", encoding="utf-8") as fr: 339 | for line in fr: 340 | params = line.strip("\n").split("\t") 341 | ent2id_dict[params[1]] = int(params[0]) 342 | id.add(int(params[0])) 343 | ids.append(id) 344 | return ent2id_dict, ids 345 | ent2id_dict, ids = read_dict([file_dir + "/ent_ids_" + str(i) for i in lang]) 346 | ills = read_file([file_dir + "/ill_ent_ids"]) 347 | triples = read_file([file_dir + "/triples_" + str(i) for i in lang]) 348 | r_hs, r_ts = {}, {} 349 | for (h, r, t) in triples: 350 | if r not in r_hs: 351 | r_hs[r] = set() 352 | if r not in r_ts: 353 | r_ts[r] = set() 354 | r_hs[r].add(h) 355 | r_ts[r].add(t) 356 | assert len(r_hs) == len(r_ts) 357 | return ent2id_dict, ills, triples, r_hs, r_ts, ids 358 | 359 | 360 | def loadfile(fn, num=1): 361 | print('loading a file...' 
+ fn) 362 | ret = [] 363 | with open(fn, encoding='utf-8') as f: 364 | for line in f: 365 | th = line[:-1].split('\t') 366 | x = [] 367 | for i in range(num): 368 | x.append(int(th[i])) 369 | ret.append(tuple(x)) 370 | return ret 371 | 372 | 373 | def get_ids(fn,file_dir): 374 | ids = [] 375 | id2name = {} 376 | fbid2name = {} 377 | if 'FB' in fn: 378 | with open(os.path.join(file_dir, 'fbid2name.txt'), encoding='utf-8') as f: 379 | for line in f: 380 | th = line[:-1].split('\t') 381 | fbid2name[th[0]] = th[1] 382 | with open(fn, encoding='utf-8') as f: 383 | for line in f: 384 | th = line[:-1].split('\t') 385 | ids.append(int(th[0])) 386 | name = th[1] 387 | if ''==s[-1]: 436 | s = s[1:-1] 437 | t = s.split('/')[-1].replace('_',' ') 438 | t_ = ' '.join(split_camel_case(t)) 439 | if t_ == '': 440 | return t 441 | return t_ 442 | 443 | 444 | def dbp_value(s): 445 | # print(s) 446 | if '^^' in s: 447 | s = s.split("^^")[0] 448 | if ('<' == s[0] and '>' == s[-1]) or ('\"' == s[0] and '\"' == s[-1]): 449 | s = s[1:-1] 450 | elif '@' in s and s.index('@')>0: 451 | s = '@'.join(s.split('@')[:-1]) 452 | if ('<' == s[0] and '>' == s[-1]) or ('\"' == s[0] and '\"' == s[-1]): 453 | s = s[1:-1] 454 | # print(s) 455 | if s[-1]=='\"': 456 | s = s[:-1] 457 | else: 458 | if ('<' == s[0] and '>' == s[-1]) or ('\"' == s[0] and '\"' == s[-1]): 459 | s = s[1:-1] 460 | return s 461 | if 'e' in s: 462 | return s 463 | 464 | if '-' not in s[1:]: 465 | return s 466 | try: 467 | s_ = s.split('-') 468 | y = int(s_[0].replace('#','0')) 469 | m = int(s_[1]) if s_[1]!='##'else 1 470 | d = int(s_[2]) if s_[2]!='##' and s_[2]!='' else 1 471 | return y + (m-1)/12 +(d-1)/30/12 472 | except: 473 | return s 474 | 475 | 476 | 477 | def load_attr_withNums(datas,fns, ent2id_dict, file_dir, topk=0): 478 | ans = [load_attr_withNum(data,fn,ent2id_dict) for data,fn in zip(datas,fns)] 479 | if topk!=0: 480 | 481 | rels = [] 482 | rels2index = {} 483 | rels2times = {} 484 | cur = 0 485 | att2rel = [] 486 | 
for i, att in enumerate(ans[0]+ans[1]): 487 | if att[1] not in rels2index: 488 | rels2index[att[1]] = cur 489 | rels.append(att[1]) 490 | cur += 1 491 | rels2times[att[1]] = 0 492 | rels2times[att[1]] += 1 493 | att2rel.append(rels2index[att[1]]) 494 | att2rel = np.array(att2rel) 495 | 496 | rels_left = [] 497 | rels2index_left = {} 498 | cur = 0 499 | att2rel_left = [] 500 | for i, att in enumerate(ans[0]): 501 | if att[1] not in rels2index_left: 502 | rels2index_left[att[1]] = cur 503 | rels_left.append(att[1]) 504 | cur += 1 505 | att2rel_left.append(rels2index_left[att[1]]) 506 | att2rel_left = np.array(att2rel_left) 507 | 508 | 509 | rels_right = [] 510 | rels2index_right = {} 511 | cur = 0 512 | att2rel_right = [] 513 | for i, att in enumerate(ans[1]): 514 | if att[1] not in rels2index_right: 515 | rels2index_right[att[1]] = cur 516 | rels_right.append(att[1]) 517 | cur += 1 518 | att2rel_right.append(rels2index_right[att[1]]) 519 | att2rel_right = np.array(att2rel_right) 520 | 521 | rels_right = set(rels_right) 522 | rels_left = set(rels_left) 523 | rels_inter = rels_left.intersection(rels_right) 524 | if len(rels_inter)==0: 525 | rels_inter = rels 526 | # select topk 527 | rels_inter = sorted(rels_inter, key=lambda x: rels2times[x], reverse=True)[:topk] 528 | 529 | ans_ = [] 530 | for i in ans[0]: 531 | if i[1] in rels_inter: 532 | ans_.append(i) 533 | num_left = len(ans_) 534 | for i in ans[1]: 535 | if i[1] in rels_inter: 536 | ans_.append(i) 537 | num_right = len(ans_)-num_left 538 | return ans_,num_left,num_right 539 | 540 | 541 | 542 | # num_att_left = len(rels2index) 543 | # att_rel_features = np.load(os.path.join(file_dir, 'att_rel_features.npy'), allow_pickle=True) 544 | # rels = torch.FloatTensor(att_rel_features).cuda() 545 | # sim_rels_left = torch.mm(rels[:num_att_left], rels[num_att_left:].T) 546 | # sim_rels_right = torch.mm(rels[num_att_left:], rels[:num_att_left].T) 547 | # # get the max sim at row 548 | # sim_rels_left = 
torch.max(sim_rels_left, dim=1)[0] 549 | # sim_rels_right = torch.max(sim_rels_right, dim=1)[0] 550 | # # get the topk rels 551 | # topk_rels_left = torch.topk(sim_rels_left, topk, dim=0)[1] 552 | # topk_rels_right = torch.topk(sim_rels_right, topk, dim=0)[1] 553 | # 554 | # topk_rels_left = topk_rels_left.cpu().numpy() 555 | # topk_rels_right = topk_rels_right.cpu().numpy() 556 | # # topk_rels = np.concatenate([topk_rels_left,topk_rels_right+num_att_left]) 557 | # 558 | # 559 | # # contain topkrels 560 | # common_elements = np.in1d(att2rel, topk_rels_left) 561 | # common_elements_indices = list(np.where(common_elements)[0]) 562 | # ans_ = [] 563 | # for i in common_elements_indices: 564 | # ans_.append(ans[0][i]) 565 | # num_left = len(ans_) 566 | # 567 | # rels = [] 568 | # rels2index = {} 569 | # cur = 0 570 | # att2rel = [] 571 | # for i,att in enumerate(ans[1]): 572 | # if att[1] not in rels2index: 573 | # rels2index[att[1]] = cur 574 | # rels.append(att[1]) 575 | # cur += 1 576 | # att2rel.append(rels2index[att[1]]) 577 | # att2rel = np.array(att2rel) 578 | # # contain topkrels 579 | # common_elements = np.in1d(att2rel, topk_rels_right) 580 | # common_elements_indices = list(np.where(common_elements)[0]) 581 | # for i in common_elements_indices: 582 | # ans_.append(ans[1][i]) 583 | # num_right = len(ans_) - num_left 584 | # return ans_,num_left,num_right 585 | 586 | 587 | 588 | 589 | 590 | return ans[0]+ans[1], len(ans[0]), len(ans[1]) 591 | def load_attr_withNum(data, fn, ent2id): 592 | 593 | with open(fn, 'r',encoding='utf-8') as f: 594 | Numericals = f.readlines() 595 | if data == 'FB15K' or data == 'DB15K' or data=='YAGO15K': 596 | Numericals_ = list(set(Numericals)) 597 | Numericals_.sort(key = Numericals.index) 598 | Numericals = Numericals_ 599 | 600 | if data=='FB15K': 601 | Numericals = [i[:-1].split('\t') for i in Numericals] 602 | Numericals = [(ent2id[i[0]], i[1][1:-1].replace('http://rdf.freebase.com/ns/', '').split('.')[-1].replace('_',' '), 
i[2]) for i in 603 | Numericals] 604 | elif data=='DB15K': 605 | Numericals = [i[:-1].split(' ') if '\t' not in i else i[:-1].split('\t') for i in Numericals] 606 | Numericals = [(ent2id[i[0]], db_str(i[1]), db_time(i[2])) for i in Numericals] 607 | 608 | elif data=='YAGO15K': 609 | Numericals = [i[:-1].split(' ') if '\t' not in i else i[:-1].split('\t') for i in Numericals] 610 | Numericals = [(ent2id[i[0]], db_str(i[1]), db_time(i[2])) for i in Numericals] 611 | elif data=='oea': 612 | Numericals = [i[:-1].split('\t') for i in Numericals] 613 | Numericals = [(ent2id[i[0]], dbp_str(i[1]), dbp_value(i[2])) for i in Numericals] 614 | else: 615 | Numericals = [i[:-1].split(' ') if '\t' not in i else i[:-1].split('\t') for i in Numericals] 616 | Numericals = [(ent2id[i[0][1:-1]], dbp_str(i[1]), dbp_value(' '.join(i[2:]))) for i in Numericals] 617 | 618 | return Numericals 619 | 620 | 621 | # Only the topA most frequent attributes are kept, to save space 622 | def load_attr(fns, e, ent2id, topA=1000): 623 | cnt = {} 624 | for fn in fns: 625 | with open(fn, 'r', encoding='utf-8') as f: 626 | for line in f: 627 | th = line[:-1].split('\t') 628 | if th[0] not in ent2id: 629 | continue 630 | for i in range(1, len(th)): 631 | if th[i] not in cnt: 632 | cnt[th[i]] = 1 633 | else: 634 | cnt[th[i]] += 1 635 | fre = [(k, cnt[k]) for k in sorted(cnt, key=cnt.get, reverse=True)] 636 | attr2id = {} 637 | # keep at most topA attributes (fewer if the data has fewer) 638 | topA = min(topA, len(fre)) 639 | for i in range(topA): 640 | attr2id[fre[i][0]] = i 641 | attr = np.zeros((e, topA), dtype=np.float32) 642 | for fn in fns: 643 | with open(fn, 'r', encoding='utf-8') as f: 644 | for line in f: 645 | th = line[:-1].split('\t') 646 | if th[0] in ent2id: 647 | for i in range(1, len(th)): 648 | if th[i] in attr2id: 649 | attr[ent2id[th[0]]][attr2id[th[i]]] = 1.0 650 | return attr 651 | 652 | 653 | def load_relation(e, KG, topR=1000): 654 | # rel_mat has shape (e, topR), e.g. (39654, 1000): per-entity counts of each top-R relation, as head or tail 655 | rel_mat = np.zeros((e, topR), dtype=np.float32) 656 | rels =
np.array(KG)[:, 1] 657 | top_rels = Counter(rels).most_common(topR) 658 | rel_index_dict = {r: i for i, (r, cnt) in enumerate(top_rels)} 659 | for tri in KG: 660 | h = tri[0] 661 | r = tri[1] 662 | o = tri[2] 663 | if r in rel_index_dict: 664 | rel_mat[h][rel_index_dict[r]] += 1. 665 | rel_mat[o][rel_index_dict[r]] += 1. 666 | return rel_mat 667 | 668 | 669 | def load_json_embd(path): 670 | embd_dict = {} 671 | with open(path, encoding='utf-8') as f: 672 | for line in f: 673 | example = json.loads(line.strip()) 674 | vec = np.array([float(e) for e in example['feature'].split()]) 675 | embd_dict[int(example['guid'])] = vec 676 | return embd_dict 677 | 678 | 679 | def load_img(e_num, path): 680 | with open(path, "rb") as f: img_dict = pickle.load(f) 681 | # entities without an image get a random vector drawn from the per-dimension mean and std of the known image embeddings 682 | imgs_np = np.array(list(img_dict.values())) 683 | mean = np.mean(imgs_np, axis=0) 684 | std = np.std(imgs_np, axis=0) 685 | # img_embd = np.array([np.zeros_like(img_dict[0]) for i in range(e_num)]) # no image 686 | # img_embd = np.array([img_dict[i] if i in img_dict else np.zeros_like(img_dict[0]) for i in range(e_num)]) 687 | 688 | img_embd = np.array([img_dict[i] if i in img_dict else np.random.normal(mean, std, mean.shape[0]) for i in range(e_num)]) 689 | print(f"{(100 * len(img_dict) / e_num):.2f}% of entities have images") 690 | return img_embd 691 | -------------------------------------------------------------------------------- /load_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import torch 5 | from scipy.sparse import csr_matrix 6 | import numpy as np 7 | from collections import defaultdict 8 | from data import load_eva_data 9 | import pickle 10 | from tqdm import tqdm 11 | # import lmdb 12 | class DataLoader: 13 | def __init__(self, args): 14 | 15 | KGs, non_train, left_ents, right_ents, train_ill, test_ill, eval_ill, test_ill_ = load_eva_data(args) 16 | ent_num =
KGs['ent_num'] 17 | rel_num = KGs['rel_num'] 18 | self.img_ill = KGs['img_ill'] 19 | self.use_img_ill = args.use_img_ill 20 | self.images_list = KGs['images_list'] 21 | self.rel_features = KGs['rel_features'] 22 | self.att_features = KGs['att_features'] 23 | self.num_att_left = KGs['num_att_left'] 24 | self.num_att_right = KGs['num_att_right'] 25 | self.id2name = KGs['id2name'] 26 | self.id2rel = KGs['id2rel'] 27 | self.left_ents = [i for i in range(len(left_ents))] 28 | self.right_ents = [len(left_ents) + i for i in range(len(right_ents))] 29 | old_ids = np.array(left_ents+right_ents) 30 | # new_ids = torch.arange(len(self.left_ents+self.right_ents)) 31 | # old2new = torch.zeros(len(self.left_ents+self.right_ents)).long() 32 | # old2new[old_ids] = new_ids 33 | # self.old2new = old2new 34 | self.old_ids = old_ids 35 | if args.mm: 36 | self.images_list = self.images_list[self.old_ids] 37 | self.old2new_dict = {oldid:newid for newid,oldid in enumerate(left_ents+right_ents)} 38 | triples = KGs['triples'] 39 | triples = [(self.old2new_dict[tri[0]],tri[1],self.old2new_dict[tri[2]]) for tri in triples] 40 | train_ill = np.array([(self.old2new_dict[tri[0]],self.old2new_dict[tri[1]]) for tri in train_ill]) 41 | test_ill = np.array([(self.old2new_dict[tri[0]],self.old2new_dict[tri[1]]) for tri in test_ill]) 42 | if args.mm: 43 | self.img_ill = np.array([(self.old2new_dict[tri[0]],self.old2new_dict[tri[1]]) for tri in self.img_ill]) 44 | 45 | 46 | # self.att_features_text = np.array(KGs['att_features']) 47 | self.att2rel ,self.rels = self.process_rels(self.att_features) 48 | self.att_ids = [self.old2new_dict[i[0]] for i in self.att_features] 49 | 50 | self.ids_att = {} 51 | for att_index,ids in enumerate(self.att_ids): 52 | if ids not in self.ids_att: 53 | self.ids_att[ids] = [] 54 | self.ids_att[ids].append(att_index) 55 | # self.test_cache_url = os.path.join(args.data_path, args.data_choice, args.data_split, f'test_{args.data_rate}') 56 | # self.test_cache = {} 57 | 58 | 
if args.mm: 59 | if args.topk == 0: 60 | if os.path.exists(os.path.join(args.data_path, args.data_choice, args.data_split, 'att_features.npy')): 61 | self.att_features = np.load(os.path.join(args.data_path, args.data_choice, args.data_split, 'att_features.npy'), allow_pickle=True) 62 | self.att_rel_features = np.load(os.path.join(args.data_path, args.data_choice, args.data_split, 'att_rel_features.npy'), allow_pickle=True) 63 | self.att_val_features = np.load(os.path.join(args.data_path, args.data_choice, args.data_split, 'att_val_features.npy'), allow_pickle=True) 64 | else: 65 | self.att_features, self.att_rel_features,self.att_val_features = self.bert_feature() 66 | np.save(os.path.join(args.data_path, args.data_choice, args.data_split, 'att_features.npy'), self.att_features) 67 | np.save(os.path.join(args.data_path, args.data_choice, args.data_split, 'att_rel_features.npy'), self.att_rel_features) 68 | np.save(os.path.join(args.data_path, args.data_choice, args.data_split, 'att_val_features.npy'), self.att_val_features) 69 | else: 70 | if os.path.exists(os.path.join(args.data_path, args.data_choice, args.data_split, f'att_features{args.topk}.npy')): 71 | self.att_features = np.load(os.path.join(args.data_path, args.data_choice, args.data_split, f'att_features{args.topk}.npy'), allow_pickle=True) 72 | self.att_rel_features = np.load(os.path.join(args.data_path, args.data_choice, args.data_split, f'att_rel_features{args.topk}.npy'), allow_pickle=True) 73 | self.att_val_features = np.load(os.path.join(args.data_path, args.data_choice, args.data_split, f'att_val_features{args.topk}.npy'), allow_pickle=True) 74 | else: 75 | self.att_features, self.att_rel_features,self.att_val_features = self.bert_feature() 76 | np.save(os.path.join(args.data_path, args.data_choice, args.data_split, f'att_features{args.topk}.npy'), self.att_features) 77 | np.save(os.path.join(args.data_path, args.data_choice, args.data_split, f'att_rel_features{args.topk}.npy'), 
self.att_rel_features) 78 | np.save(os.path.join(args.data_path, args.data_choice, args.data_split, f'att_val_features{args.topk}.npy'), self.att_val_features) 79 | # for i1,i2 in train_ill: 80 | # f1 = self.att_features[np.array(self.att_ids)==i1] 81 | # f2 = self.att_features[np.array(self.att_ids)==i2] 82 | # print('-'*30) 83 | # print('1',self.att_features_text[np.array(self.att_ids)==i1]) 84 | # print('2',self.att_features_text[np.array(self.att_ids)==i2]) 85 | 86 | # for f1i in f1: 87 | # for f2i in f2: 88 | # print(f1i.dot(f2i)) 89 | # f1 = self.att_rel_features[self.att2rel[np.array(self.att_ids)==i1]] 90 | # f2 = self.att_rel_features[self.att2rel[np.array(self.att_ids)==i2]] 91 | # print() 92 | # for f1i in f1: 93 | # for f2i in f2: 94 | # print(f1i.dot(f2i)) 95 | 96 | # f1 = self.att_val_features[np.array(self.att_ids)==i1] 97 | # f2 = self.att_val_features[np.array(self.att_ids)==i2] 98 | # print() 99 | # for f1i in f1: 100 | # for f2i in f2: 101 | # print(f1i.dot(f2i)) 102 | 103 | 104 | 105 | 106 | 107 | self.n_ent = ent_num 108 | self.n_rel = rel_num 109 | 110 | self.filters = defaultdict(lambda: set()) 111 | 112 | self.fact_triple = triples 113 | 114 | self.train_triple = self.ill2triples(train_ill) 115 | self.valid_triple = eval_ill # None 116 | self.test_triple = self.ill2triples(test_ill) 117 | 118 | # add inverse 119 | self.fact_data = self.double_triple(self.fact_triple) 120 | # self.train_data = np.array(self.double_triple(self.train_triple)) 121 | # self.valid_data = self.double_triple(self.valid_triple) 122 | self.test_data = self.double_triple(self.test_triple, ill=True) 123 | self.test_data = np.array(self.test_data) 124 | self.train_data = self.double_triple(self.train_triple, ill=True) 125 | self.train_data = np.array(self.train_data) 126 | if self.use_img_ill: 127 | self.img_ill_triple = self.img_ill2triples(self.img_ill) 128 | self.img_ill_triple = self.double_triple(self.img_ill_triple, ill=True) 129 | self.img_ill_triple = 
np.array(self.img_ill_triple) 130 | self.img_ill_data = torch.LongTensor(self.img_ill_triple).cuda() 131 | 132 | # self.KG,self.M_sub = self.load_graph(self.fact_data) # do it in shuffle_train 133 | self.tKG = self.load_graph(self.fact_data + self.double_triple(self.train_triple, ill=True)) 134 | self.tKG = torch.LongTensor(self.tKG).cuda() 135 | 136 | # in torch 137 | idd = np.concatenate([np.expand_dims(np.arange(self.n_ent), 1), 2 * self.n_rel * np.ones((self.n_ent, 1)), 138 | np.expand_dims(np.arange(self.n_ent), 1)], 1) 139 | self.fact_data = np.concatenate([np.array(self.fact_data), idd], 0) 140 | self.fact_data = torch.LongTensor(self.fact_data).cuda() 141 | # self.node2index = {} 142 | # for i, triple in enumerate(self.train_triple): 143 | # h, r, t = triple 144 | # assert h not in self.node2index 145 | # assert t not in self.node2index 146 | # self.node2index[h] = i 147 | # self.node2index[t] = i 148 | # self.train_triple = torch.LongTensor(self.train_triple).cuda() 149 | 150 | 151 | self.n_test = len(self.test_data) 152 | self.n_train = len(self.train_data) 153 | self.shuffle_train() 154 | 155 | # if os.path.exists(self.test_cache_url): 156 | # self.test_env = lmdb.open(self.test_cache_url) 157 | # else: 158 | # self.test_env = lmdb.open(self.test_cache_url, map_size=200*1024 * 1024 * 1024, max_dbs=1) 159 | # self.preprocess_test() 160 | def process_rels(self, atts): 161 | rels = [] 162 | rels2index = {} 163 | cur = 0 164 | att2rel = [] 165 | for i,att in enumerate(atts): 166 | if att[1] not in rels2index: 167 | rels2index[att[1]] = cur 168 | rels.append(att[1]) 169 | cur += 1 170 | att2rel.append(rels2index[att[1]]) 171 | return np.array(att2rel),rels 172 | 173 | 174 | 175 | def bert_feature(self, ): 176 | from sentence_transformers import SentenceTransformer 177 | from transformers import BertTokenizer, BertModel 178 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 179 | # model = BertModel.from_pretrained("bert-base-uncased").cuda() 180 
| # model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2').cuda() 181 | model = SentenceTransformer('sentence-transformers/LaBSE').cuda() 182 | 183 | outputs = [] 184 | texts = [a + ' ' + str(v) for i,a,v in self.att_features] 185 | batch_size = 2048 186 | sent_batch = [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)] 187 | for sent in sent_batch: 188 | 189 | # encoded_input = tokenizer(sent, return_tensors='pt', padding=True, truncation=True, max_length=512) 190 | # #cuda 191 | # encoded_input.data['input_ids'] = encoded_input.data['input_ids'].cuda() 192 | # encoded_input.data['attention_mask'] = encoded_input.data['attention_mask'].cuda() 193 | # encoded_input.data['token_type_ids'] = encoded_input.data['token_type_ids'].cuda() 194 | with torch.no_grad(): 195 | # output = model(**encoded_input) 196 | output = model.encode(sent) 197 | outputs.append(output) 198 | outputs = np.concatenate(outputs) 199 | 200 | # batch_size = 512 201 | sent_batch = [self.rels[i:i + batch_size] for i in range(0, len(self.rels), batch_size)] 202 | rel_outputs = [] 203 | for sent in sent_batch: 204 | # encoded_input = tokenizer(sent, return_tensors='pt', padding=True, truncation=True, max_length=512) 205 | # #cuda 206 | # encoded_input.data['input_ids'] = encoded_input.data['input_ids'].cuda() 207 | # encoded_input.data['attention_mask'] = encoded_input.data['attention_mask'].cuda() 208 | # encoded_input.data['token_type_ids'] = encoded_input.data['token_type_ids'].cuda() 209 | with torch.no_grad(): 210 | # output = model(**encoded_input) 211 | output = model.encode(sent) 212 | rel_outputs.append(output) 213 | rel_outputs = np.concatenate(rel_outputs) 214 | 215 | vals = [str(i[2]) for i in self.att_features] 216 | # batch_size = 512 217 | sent_batch = [vals[i:i + batch_size] for i in range(0, len(vals), batch_size)] 218 | val_outputs = [] 219 | for sent in sent_batch: 220 | # encoded_input = tokenizer(sent, return_tensors='pt', padding=True, 
truncation=True, max_length=512) 221 | # #cuda 222 | # encoded_input.data['input_ids'] = encoded_input.data['input_ids'].cuda() 223 | # encoded_input.data['attention_mask'] = encoded_input.data['attention_mask'].cuda() 224 | # encoded_input.data['token_type_ids'] = encoded_input.data['token_type_ids'].cuda() 225 | with torch.no_grad(): 226 | # output = model(**encoded_input) 227 | output = model.encode(sent) 228 | val_outputs.append(output) 229 | val_outputs = np.concatenate(val_outputs) 230 | del model 231 | return outputs, rel_outputs, val_outputs 232 | 233 | 234 | 235 | def ill2triples(self, ill): 236 | return [(i[0], self.n_rel * 2 + 1, i[1]) for i in ill] 237 | 238 | def img_ill2triples(self, ill): 239 | return [(i[0], self.n_rel * 2 + 3, i[1]) for i in ill] 240 | 241 | # def read_triples(self, filename): 242 | # triples = [] 243 | # with open(os.path.join(self.task_dir, filename)) as f: 244 | # for line in f: 245 | # h, r, t = line.strip().split() 246 | # h, r, t = self.entity2id[h], self.relation2id[r], self.entity2id[t] 247 | # triples.append([h, r, t]) 248 | # self.filters[(h, r)].add(t) 249 | # self.filters[(t, r + self.n_rel)].add(h) 250 | # return triples 251 | 252 | def double_triple(self, triples, ill=False): 253 | new_triples = [] 254 | for triple in triples: 255 | h, r, t = triple 256 | new_triples.append([t, r + self.n_rel if not ill else r+1, h]) 257 | return triples + new_triples 258 | 259 | def load_graph(self, triples): 260 | idd = np.concatenate([np.expand_dims(np.arange(self.n_ent), 1), 2 * self.n_rel * np.ones((self.n_ent, 1)), 261 | np.expand_dims(np.arange(self.n_ent), 1)], 1) 262 | 263 | KG = np.concatenate([np.array(triples), idd], 0) 264 | # n_fact = len(KG) 265 | # M_sub = csr_matrix((np.ones((n_fact,)), (np.arange(n_fact), KG[:, 0])), 266 | # shape=(n_fact, self.n_ent)) 267 | return KG 268 | 269 | 270 | def get_subgraphs(self, head_nodes, layer=3,mode='train',sim=None): 271 | all_edges = [] 272 | for index,head_node in 
enumerate(head_nodes): 273 | all_edge = self.get_subgraph(head_node, index, layer, mode,sim=sim) 274 | all_edges.append(all_edge) 275 | all_nodes = [] 276 | layer_edges = [] 277 | old_nodes_new_idxs = [] 278 | old_nodes = [] 279 | for i in range(layer): 280 | edges = [] 281 | for j in range(len(all_edges)): 282 | edges.append(all_edges[j][i]) 283 | edges = torch.cat(edges, dim=0) 284 | edges = edges.long() 285 | 286 | head_nodes, head_index = torch.unique(edges[:, [0, 1]], dim=0, sorted=True, return_inverse=True) 287 | tail_nodes, tail_index = torch.unique(edges[:, [0, 3]], dim=0, sorted=True, return_inverse=True) 288 | sampled_edges = torch.cat([edges, head_index.unsqueeze(1), tail_index.unsqueeze(1)], 1) 289 | 290 | 291 | mask = sampled_edges[:, 2] == (self.n_rel * 2) 292 | old_node, old_idx = head_index[mask].sort() 293 | old_nodes_new_idx = tail_index[mask][old_idx] 294 | all_nodes.append(tail_nodes) 295 | layer_edges.append(sampled_edges) 296 | old_nodes_new_idxs.append(old_nodes_new_idx) 297 | old_nodes.append(old_node) 298 | 299 | 300 | return all_nodes, layer_edges, old_nodes_new_idxs, old_nodes 301 | # 302 | def get_subgraph(self, head_node, index, layer, mode, max_size=500, sim=None): 303 | if mode == 'train': 304 | # # set false to self.node2index[node] 305 | # mask = torch.ones(len(self.train_triple), dtype=torch.bool).cuda() 306 | # mask[self.node2index[head_node.item()]] = False 307 | # support = self.train_triple[mask] 308 | # reverse_support = support[:, [2, 1, 0]] 309 | # reverse_support[:, 1] += 1 310 | # support = torch.cat((support, reverse_support), dim=0) 311 | # KG = torch.cat((support,self.fact_data),dim=0) 312 | KG=self.KG 313 | else: 314 | KG = self.tKG 315 | if sim is not None: 316 | KG = torch.cat((KG, sim), dim=0) 317 | if self.use_img_ill: 318 | KG = torch.cat((KG, self.img_ill_data), dim=0) 319 | row, col = KG[:, 0], KG[:, 2] 320 | node_mask = row.new_empty(self.n_ent, dtype=torch.bool) 321 | # edge_mask = row.new_empty(row.size(0), 
dtype=torch.bool) 322 | subsets = [torch.LongTensor([head_node]).cuda()] 323 | raw_layer_edges = [] 324 | for i in range(layer): 325 | node_mask.fill_(False) 326 | node_mask[subsets[-1]] = True 327 | edge_mask = torch.index_select(node_mask, 0, row) 328 | subsets.append(torch.unique(col[edge_mask])) 329 | raw_layer_edges.append(edge_mask) 330 | # nodes, edges, old_nodes_new_idx = self.get_neighbors(nodes.data.cpu().numpy()) 331 | # delete target not in the other KG 332 | tail_node = self.left_ents if head_node.item() >= len(self.left_ents) else self.right_ents 333 | tail_node = torch.LongTensor(tail_node).cuda() 334 | node_mask_ = row.new_empty(self.n_ent, dtype=torch.bool) 335 | node_mask_.fill_(False) 336 | node_mask_[tail_node] = True 337 | tail_set = subsets[-1] 338 | node_mask.fill_(False) 339 | node_mask[tail_set] = True 340 | node_mask = node_mask & node_mask_ 341 | layer_edges = [] 342 | for i in reversed(range(layer)): 343 | edge_mask = torch.index_select(node_mask, 0, col) 344 | edge_mask = edge_mask & raw_layer_edges[i] 345 | node_mask_.fill_(False) 346 | node_mask_[row[edge_mask]] = True 347 | node_mask = node_mask | node_mask_ 348 | layer_edges.append(KG[edge_mask]) 349 | layer_edges = layer_edges[::-1] 350 | batched_edges = [] 351 | for i in range(layer): 352 | layer_edges[i] = torch.unique(layer_edges[i], dim=0) 353 | batched_edges.append(torch.cat([torch.ones(len(layer_edges[i])).unsqueeze(1).cuda() * index, layer_edges[i]], 1)) 354 | return batched_edges 355 | 356 | def get_vis_subgraph(self, head_node, tail_node, layer, max_size=500, sim=None): 357 | 358 | KG = self.tKG 359 | if sim is not None: 360 | KG = torch.cat((KG, sim), dim=0) 361 | row, col = KG[:, 0], KG[:, 2] 362 | node_mask = row.new_empty(self.n_ent, dtype=torch.bool) 363 | # edge_mask = row.new_empty(row.size(0), dtype=torch.bool) 364 | subsets = [torch.LongTensor([head_node]).cuda()] 365 | raw_layer_edges = [] 366 | for i in range(layer): 367 | node_mask.fill_(False) 368 | 
node_mask[subsets[-1]] = True 369 | edge_mask = torch.index_select(node_mask, 0, row) 370 | subsets.append(torch.unique(col[edge_mask])) 371 | raw_layer_edges.append(edge_mask) 372 | # nodes, edges, old_nodes_new_idx = self.get_neighbors(nodes.data.cpu().numpy()) 373 | # delete target not in the other KG 374 | # tail_node = self.left_ents if head_node.item() >= len(self.left_ents) else self.right_ents 375 | tail_node = torch.LongTensor([tail_node]).cuda() 376 | node_mask_ = row.new_empty(self.n_ent, dtype=torch.bool) 377 | node_mask_.fill_(False) 378 | node_mask_[tail_node] = True 379 | tail_set = subsets[-1] 380 | node_mask.fill_(False) 381 | node_mask[tail_set] = True 382 | node_mask = node_mask & node_mask_ 383 | layer_edges = [] 384 | for i in reversed(range(layer)): 385 | edge_mask = torch.index_select(node_mask, 0, col) 386 | edge_mask = edge_mask & raw_layer_edges[i] 387 | node_mask_.fill_(False) 388 | node_mask_[row[edge_mask]] = True 389 | node_mask = node_mask | node_mask_ 390 | layer_edges.append(KG[edge_mask]) 391 | layer_edges = layer_edges[::-1] 392 | batched_edges = [] 393 | for i in range(layer): 394 | layer_edges[i] = torch.unique(layer_edges[i], dim=0) 395 | batched_edges.append(layer_edges[i]) 396 | return batched_edges 397 | 398 | # def get_neighbors(self, nodes, mode='train', n_hop=0): 399 | # if mode == 'train': 400 | # KG = self.KG 401 | # M_sub = self.M_sub 402 | # else: 403 | # KG = self.tKG 404 | # M_sub = self.tM_sub 405 | # # if self.test_cache 406 | # 407 | # # nodes: n_node x 2 with (batch_idx, node_idx) 408 | # node_1hot = csr_matrix((np.ones(len(nodes)), (nodes[:, 1], nodes[:, 0])), shape=(self.n_ent, nodes.shape[0])) # (n_ent, batch_size) 409 | # edge_1hot = M_sub.dot(node_1hot) 410 | # edges = np.nonzero(edge_1hot) 411 | # sampled_edges = np.concatenate([np.expand_dims(edges[1], 1), KG[edges[0]]], 412 | # axis=1) # (batch_idx, head, rela, tail) 413 | # sampled_edges = torch.LongTensor(sampled_edges).cuda() 414 | # 415 | # # index 
to nodes 416 | # head_nodes, head_index = torch.unique(sampled_edges[:, [0, 1]], dim=0, sorted=True, return_inverse=True) 417 | # tail_nodes, tail_index = torch.unique(sampled_edges[:, [0, 3]], dim=0, sorted=True, return_inverse=True) 418 | # 419 | # sampled_edges = torch.cat([sampled_edges, head_index.unsqueeze(1), tail_index.unsqueeze(1)], 1) 420 | # 421 | # mask = sampled_edges[:, 2] == (self.n_rel * 2) 422 | # _, old_idx = head_index[mask].sort() 423 | # old_nodes_new_idx = tail_index[mask][old_idx] 424 | # 425 | # return tail_nodes, sampled_edges, old_nodes_new_idx 426 | 427 | # def get_neighbor(self, node, mode='train', n_hop=0): 428 | # if mode == 'train': 429 | # # set false to self.node2index[node] 430 | # mask = torch.ones(len(self.train_triple), dtype=torch.bool) 431 | # mask[self.node2index[node]] = False 432 | # KG = torch.cat(self.train_triple[mask],self.fact_data) 433 | # 434 | # else: 435 | # KG = self.tKG 436 | # # if self.test_cache 437 | # 438 | # # nodes: n_node x 2 with (batch_idx, node_idx) 439 | # # node_1hot = csr_matrix((np.ones(len(nodes)), (nodes[:, 1], nodes[:, 0])), shape=(self.n_ent, nodes.shape[0])) # (n_ent, batch_size) 440 | # # edge_1hot = M_sub.dot(node_1hot) 441 | # edges = KG[:, 0]==node 442 | # edges = np.nonzero(edges) 443 | # sampled_edges = KG[edges[0]] # (head, rela, tail) 444 | # sampled_edges = torch.LongTensor(sampled_edges).cuda() 445 | # 446 | # # index to nodes 447 | # head_nodes, head_index = torch.unique(sampled_edges[:, 1], dim=0, sorted=True, return_inverse=True) 448 | # tail_nodes, tail_index = torch.unique(sampled_edges[:, 3], dim=0, sorted=True, return_inverse=True) 449 | # 450 | # sampled_edges = torch.cat([sampled_edges, head_index.unsqueeze(1), tail_index.unsqueeze(1)], 1) 451 | # 452 | # # mask = sampled_edges[:, 2] == (self.n_rel * 2) 453 | # # _, old_idx = head_index[mask].sort() 454 | # # old_nodes_new_idx = tail_index[mask][old_idx] 455 | # 456 | # return tail_nodes, sampled_edges 457 | 458 | def 
get_batch(self, batch_idx, steps=2, data='train'): 459 | if data == 'train': 460 | return self.train_data[batch_idx] 461 | if data == 'valid': 462 | return None 463 | if data == 'test': 464 | return self.test_data[batch_idx] 465 | 466 | # subs = [] 467 | # rels = [] 468 | # objs = [] 469 | # 470 | # subs = query[batch_idx, 0] 471 | # rels = query[batch_idx, 1] 472 | # objs = np.zeros((len(batch_idx), self.n_ent)) 473 | # for i in range(len(batch_idx)): 474 | # objs[i][answer[batch_idx[i]]] = 1 475 | # return subs, rels, objs 476 | 477 | def shuffle_train(self, ): 478 | # fact_triple = np.array(self.fact_triple) 479 | # train_triple = np.array(self.train_triple) 480 | # all_triple = np.concatenate([fact_triple, train_triple], axis=0) 481 | # n_all = len(all_triple) 482 | # rand_idx = np.random.permutation(n_all) 483 | # all_triple = all_triple[rand_idx] 484 | 485 | # random shuffle train_triples 486 | random.shuffle(self.train_triple) 487 | # support/query split 3/1 488 | support_triple = self.train_triple[:len(self.train_triple) * 3 // 4] 489 | query_triple = self.train_triple[len(self.train_triple) * 3 // 4:] 490 | # add inverse triples 491 | support_triple = self.double_triple(support_triple, ill=True) 492 | query_triple = self.double_triple(query_triple, ill=True) 493 | support = torch.LongTensor(support_triple).cuda() 494 | self.KG = torch.cat((support,self.fact_data),dim=0) 495 | # now the fact triples are fact_triple + support_triple 496 | # self.KG, self.M_sub = self.load_graph(self.fact_data + support_triple) 497 | self.n_train = len(query_triple) 498 | self.train_data = np.array(query_triple) 499 | 500 | # # increase the ratio of fact_data, e.g., 3/4->4/5, can increase the performance 501 | # self.fact_data = self.double_triple(all_triple[:n_all * 3 // 4].tolist()) 502 | # self.train_data = np.array(self.double_triple(all_triple[n_all * 3 // 4:].tolist())) 503 | # self.n_train = len(self.train_data) 504 | # self.KG,self.M_sub = 
self.load_graph(self.fact_data) 505 | 506 | print('n_train:', self.n_train, 'n_test:', self.n_test) 507 | 508 | def preprocess_test(self, ): # NOTE: relies on self.get_neighbors and self.test_env, both commented out in this file, so calling it raises AttributeError 509 | batch_size = 4 510 | n_data = self.n_test 511 | n_batch = n_data // batch_size + (n_data % batch_size > 0) 512 | for i in tqdm(range(n_batch)): 513 | start = i * batch_size 514 | end = min(n_data, (i + 1) * batch_size) 515 | batch_idx = np.arange(start, end) 516 | triple = self.get_batch(batch_idx, data='test') 517 | subs, rels, objs = triple[:, 0], triple[:, 1], triple[:, 2] 518 | print(subs, rels, objs) 519 | n = len(subs) 520 | q_sub = torch.LongTensor(subs).cuda() 521 | nodes = torch.cat([torch.arange(n).unsqueeze(1).cuda(), q_sub.unsqueeze(1)], 1) 522 | for h in range(5): 523 | nodes, edges, old_nodes_new_idx = self.get_neighbors(nodes.data.cpu().numpy(), mode='test', 524 | n_hop=h) 525 | # to np 526 | # self.test_cache[(i, h)] = (nodes.cpu().numpy(), edges.cpu().numpy(), old_nodes_new_idx.cpu().numpy()) 527 | # use lmdb write 528 | with self.test_env.begin(write=True) as txn: 529 | txn.put(f'{i}_{h}'.encode(), pickle.dumps((nodes.cpu().numpy(), edges.cpu().numpy(), old_nodes_new_idx.cpu().numpy()))) 530 | # pickle.dump(self.test_cache, open(self.test_cache_url, 'wb')) 531 | 532 | def get_test_cache(self, batch_idx, h): # NOTE: self.test_env is commented out in __init__, so this raises AttributeError as written 533 | #use lmdb read 534 | with self.test_env.begin(write=False) as txn: 535 | nodes, edges, old_nodes_new_idx = pickle.loads(txn.get(f'{batch_idx}_{h}'.encode())) 536 | return nodes, edges, old_nodes_new_idx 537 | # return self.test_cache[(batch_idx, h)] 538 | 539 | 540 | # def save_cache(self): 541 | # with open(self.cache_path, 'wb') as f: 542 | # pickle.dump(self.edge_cache, f) 543 | # 544 | # def load_cache(self): 545 | # with open(self.cache_path, 'rb') as f: 546 | # self.edge_cache = pickle.load(f) 547 | # print("load cache from {}".format(self.cache_path)) 548 | -------------------------------------------------------------------------------- /models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch_scatter import scatter
4 | import torch.nn.functional as F
5 | from torch_geometric.utils import softmax
6 | from torch_geometric.nn.models import MLP
7 | class Text_enc(nn.Module):
8 |     def __init__(self, params):
9 |         super().__init__()
10 |         self.hidden_dim = params.text_dim
11 |         self.u = nn.Linear(params.text_dim, 1)
12 |         self.W = nn.Linear(2*params.text_dim , params.text_dim)
13 |
14 |     def forward(self, ent_num, Textid, Text, Text_rel):
15 |         # print(edge_index.device)
16 |
17 |         a_v = torch.cat((Text_rel,Text),-1)
18 |         o = self.u(Text_rel)
19 |         alpha = softmax(o, Textid, None, ent_num)
20 |         text = scatter(alpha * a_v, index=Textid, dim=0, dim_size=ent_num, reduce='sum')
21 |
22 |         return text
23 |
24 |
25 | # class FeatureMapping(nn.Module):
26 | #     def __init__(self, params):
27 | #         super().__init__()
28 | #         self.params = params
29 | #         self.in_dims = {'Stru': params.stru_dim, 'Text': params.text_dim, 'IMG': params.hidden_dim,
30 | #                         'Temporal': params.time_dim, 'Numerical': params.time_dim}
31 | #         self.out_dim = params.hidden_dim
32 | #         modals = ['Stru', 'Text', 'IMG', 'Temporal', 'Numerical']
33 | #         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
34 | #         if self.device == 'cuda':
35 |
36 | #             self.W_list = {
37 | #                 modal: MLP(in_channels=self.in_dims[modal], out_channels=self.out_dim,
38 | #                            hidden_channels=params.MLP_hidden_dim, num_layers=params.MLP_num_layers,
39 | #                            dropout=params.MLP_dropout, norm=None).cuda() for modal in modals
40 | #             }
41 | #         else:
42 | #             self.W_list = {
43 | #                 modal: MLP(in_channels=self.in_dims[modal], out_channels=self.out_dim,
44 | #                            hidden_channels=params.MLP_hidden_dim, num_layers=params.MLP_num_layers,
45 | #                            dropout=params.MLP_dropout, norm=None) for modal in modals
46 | #             }
47 | #         self.W_list = nn.ModuleDict(self.W_list)
48 |
49 | #     def forward(self, features):
50 | #         new_features = {}
51 | #         modals = ['Text']
52 |
53 | #         for modal, feature in features.items():
54 | #             if modal not in modals:
55 | #                 continue
56 | #             # print(modal,feature.device)
57 | #             new_features[modal] = self.W_list[modal](feature)
58 | #         mean_feature = torch.mean(torch.stack(list(new_features.values())), dim=0)
59 | #         return new_features, mean_feature
60 |
61 |
62 | class MMFeature(nn.Module):
63 |     def __init__(self, n_ent, params):
64 |         super().__init__()
65 |         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
66 |         self.params = params
67 |         self.n_ent = n_ent
68 |         # self.feature_mapping = FeatureMapping(params)
69 |         self.text_model = Text_enc(params)
70 |         self.in_dims = {'Stru': params.stru_dim, 'Text': params.text_dim, 'IMG': params.img_dim}
71 |         self.out_dim = params.hidden_dim
72 |         modals = ['Text', 'IMG']
73 |         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
74 |         self.W_list = {
75 |             modal: MLP(in_channels=self.in_dims[modal], out_channels=self.out_dim,
76 |                        hidden_channels=params.MLP_hidden_dim, num_layers=params.MLP_num_layers,
77 |                        dropout=params.MLP_dropout, norm=None).to(self.device) for modal in modals
78 |         }
79 |         self.W_list = nn.ModuleDict(self.W_list)
80 |
81 |     def forward(self, img_features = None,att_features= None,att_rel_features= None, att_ids=None):
82 |         # features = {'IMG': self.W_list['IMG'](img_features),
83 |         #             'Text': self.W_list['Text'](self.text_model(self.n_ent, att_ids, att_features, att_rel_features))}
84 |         features = {'IMG': img_features,
85 |                     'Text': self.text_model(self.n_ent, att_ids, att_features, att_rel_features)}
86 |         # mean_feature = torch.mean(torch.stack(list(features.values())), dim=0)
87 |         mean_feature = None
88 |         return features, mean_feature
89 |
90 |
91 | class GNNLayer(torch.nn.Module):
92 |     def __init__(self, in_dim, out_dim, attn_dim, n_rel, act=lambda x:x):
93 |         super(GNNLayer, self).__init__()
94 |         self.n_rel = n_rel
95 |         self.in_dim = in_dim
96 |         self.out_dim = out_dim
97 |         self.attn_dim = attn_dim
98 |         self.act = act
99 |
100 |         # +3 for self-loop, alignment and alignment-inverse
101 |         self.rela_embed = nn.Embedding(2 * n_rel + 5, in_dim)
102 |
103 |         self.Ws_attn = nn.Linear(in_dim, attn_dim, bias=False)
104 |         self.Wr_attn = nn.Linear(in_dim, attn_dim, bias=False)
105 |         self.Wkg_attn = nn.Linear(2*in_dim, attn_dim)
106 |         self.w_alpha = nn.Linear(attn_dim, 1)
107 |
108 |         self.W_h = nn.Linear(in_dim, out_dim, bias=False)
109 |
110 |     def forward(self, hidden, edges, n_node, kgemb, left_num):
111 |         # edges: [batch_idx, head, rela, tail, old_idx, new_idx]
112 |         sub = edges[:, 4]
113 |         rel = edges[:, 2]
114 |         obj = edges[:, 5]
115 |
116 |         hs = hidden[sub]
117 |         hr = self.rela_embed(rel)
118 |
119 |         head = edges[:, 1]
120 |         tail = edges[:, 3]
121 |
122 |         kg_h = kgemb((head>=left_num).long())
123 |         kg_t = kgemb((tail>=left_num).long())
124 |         kg = torch.cat([kg_h, kg_t], dim=1)
125 |
126 |         message = hs + hr
127 |         alpha = torch.sigmoid(self.w_alpha(nn.ReLU()(self.Ws_attn(hs) + self.Wr_attn(hr) + self.Wkg_attn(kg))))
128 |         message = alpha * message
129 |
130 |         message_agg = scatter(message, index=obj, dim=0, dim_size=n_node, reduce='sum')
131 |
132 |         hidden_new = self.act(self.W_h(message_agg))
133 |
134 |         return hidden_new
135 |
136 |
137 | class MASGNN(torch.nn.Module):
138 |     def __init__(self, params, loader):
139 |         super(MASGNN, self).__init__()
140 |         self.n_layer = params.n_layer
141 |         self.hidden_dim = params.hidden_dim
142 |         self.attn_dim = params.attn_dim
143 |         self.mm = params.mm
144 |         self.n_rel = loader.n_rel
145 |         self.n_ent = loader.n_ent
146 |         self.loader = loader
147 |         self.left_num = len(self.loader.left_ents)
148 |         acts = {'relu': nn.ReLU(), 'tanh': torch.tanh, 'idd': lambda x: x}
149 |         act = acts[params.act]
150 |
151 |         self.gnn_layers = []
152 |         for i in range(self.n_layer):
153 |             self.gnn_layers.append(GNNLayer(self.hidden_dim, self.hidden_dim, self.attn_dim, self.n_rel, act=act))
154 |         self.gnn_layers = nn.ModuleList(self.gnn_layers)
155 |
156 |         self.dropout = nn.Dropout(params.dropout)
157 |         self.W_final = nn.Linear(self.hidden_dim if self.mm else self.hidden_dim, 1, bias=False)  # get score todo: try to use mlp
158 |         self.gate = nn.GRU(self.hidden_dim, self.hidden_dim)
159 |         self.kgemb = nn.Embedding(2, self.hidden_dim)
160 |         if self.mm:
161 |             self.img_features = F.normalize(torch.FloatTensor(self.loader.images_list)).cuda()
162 |             self.att_features = torch.FloatTensor(self.loader.att_features).cuda()
163 |             self.num_att_left = self.loader.num_att_left
164 |             self.num_att_right = self.loader.num_att_right
165 |             self.att_val_features = torch.FloatTensor(self.loader.att_val_features).cuda()
166 |             self.att_rel_features = torch.nn.Embedding(self.loader.att_rel_features.shape[0], self.loader.att_rel_features.shape[1])
167 |             self.att_rel_features.weight.data = torch.FloatTensor(self.loader.att_rel_features).cuda()
168 |             self.att_ids = torch.LongTensor(self.loader.att_ids).cuda()
169 |             self.ids_att = self.loader.ids_att
170 |             self.ids_att = {k:torch.LongTensor(v).cuda() for k,v in self.loader.ids_att.items()}
171 |             self.att2rel = torch.LongTensor(self.loader.att2rel).cuda()
172 |             self.mmfeature = MMFeature(self.n_ent, params)
173 |             self.textMLP = MLP(in_channels=params.hidden_dim, out_channels=1,
174 |                                hidden_channels=params.MLP_hidden_dim, num_layers=params.MLP_num_layers,
175 |                                dropout=[params.MLP_dropout]*params.MLP_num_layers, norm=None)
176 |             self.textW = nn.Linear(2*params.text_dim, params.hidden_dim)
177 |             self.ImgMLP = MLP(in_channels=params.img_dim, out_channels=1,
178 |                               hidden_channels=params.MLP_hidden_dim, num_layers=params.MLP_num_layers,
179 |                               dropout=[params.MLP_dropout]*params.MLP_num_layers, norm=None)
180 |
181 |
182 |     def forward(self, subs, mode='train',batch_idx=None):
183 |         # if self.mm:
184 |         #     features, mean_feature = self.mmfeature(img_features=self.img_features, att_features=self.att_val_features,
185 |         #                                             att_rel_features=self.att_rel_features(self.att2rel), att_ids=self.att_ids)
186 |         # similarity of att_rel_features use cosine shape (n_rel, n_rel)
187 |         # use self.att2rel to get similarity from rel_sim , self.att2rel shape is n_att , attention shape is (n_att, n_att)
188 |         # attention = rel_sim[torch.meshgrid(self.att2rel[:self.num_att_left], self.att2rel[self.num_att_left:])]
189 |         # attention_l2r = scatter(attention, index=self.att_ids[self.num_att_left:]-self.left_num, dim=1, dim_size=self.n_ent-self.left_num, reduce='sum')
190 |         # attention_r2l = scatter(attention, index=self.att_ids[:self.num_att_left], dim=0, dim_size=self.left_num, reduce='sum')
191 |         # alpha_l2r = softmax(attention_l2r, self.att_ids[:self.num_att_left], None, self.left_num,0)
192 |         # alpha_r2l = softmax(attention_r2l, self.att_ids[self.num_att_left:]-self.left_num, None, self.n_ent-self.left_num,-1)
193 |         # get att_features (n1,n2,dim)
194 |
195 |
196 |
197 |
198 |
199 |         # features['IMG'] = features['IMG'] / torch.norm(features['IMG'], dim=-1, keepdim=True)
200 |         # features['Text'] = features['Text'] / torch.norm(features['Text'], dim=-1, keepdim=True)
201 |
202 |         # img_features = self.ImgMLP(self.img_features)
203 |         # img_features = F.normalize(img_features)
204 |         # sim_i = torch.mm(img_features[:self.left_num], img_features[self.left_num:].T)
205 |         # sim_t = torch.mm(features['Text'][:self.left_num], features['Text'][self.left_num:].T)
206 |         # sim_m = sim_i+sim_t
207 |         # select sim > 0.9 index
208 |         # sim = torch.nonzero(sim_m > 0.8).squeeze(1)
209 |         # # add rels = (2 * n_rel + 3) and inverse rels = (2 * n_rel + 4)
210 |         # sim_ = torch.cat([sim[:,[0]],torch.ones(sim.shape[0],1).long().cuda() * (2 * self.n_rel + 3), sim[:,[1]] + self.left_num], -1)
211 |         # rev_sim = torch.cat([sim[:,[1]] + self.left_num,torch.ones(sim.shape[0],1).long().cuda() * (2 * self.n_rel + 4),sim[:,[0]]], -1)
212 |         # sim = torch.cat([sim_, rev_sim], 0)
213 |
214 |
215 |         q_sub = torch.LongTensor(subs).cuda()
216 |         n = q_sub.shape[0]
217 |         nodes = torch.cat([torch.arange(n).unsqueeze(1).cuda(), q_sub.unsqueeze(1)], 1)
218 |         nodess, edgess, old_nodes_new_idxs,old_nodes = self.loader.get_subgraphs(q_sub, layer=self.n_layer,mode=mode,sim=None)
219 |
220 |
221 |
222 |
223 |
224 |         # hidden = mean_feature[nodes[:, 1]]
225 |         # h0 = mean_feature[nodes[:, 1]].unsqueeze(0)
226 |         # else:
227 |         h0 = torch.zeros((1, n, self.hidden_dim)).cuda()
228 |         hidden = torch.zeros(n, self.hidden_dim).cuda()
229 |
230 |
231 |
232 |
233 |         scores_all = []
234 |         for i in range(self.n_layer):
235 |             nodes = nodess[i]
236 |             edges = edgess[i]
237 |             old_nodes_new_idx = old_nodes_new_idxs[i]
238 |             old_node = old_nodes[i]
239 |             # if mode == 'train':
240 |             #     nodes, edges, old_nodes_new_idx = self.loader.get_neighbors(nodes.data.cpu().numpy(), mode=mode,n_hop=i)
241 |             # else:
242 |             #     nodes, edges, old_nodes_new_idx = self.loader.get_test_cache(batch_idx,i)
243 |             #     # np to tensor
244 |             #     nodes = torch.LongTensor(nodes).cuda()
245 |             #     edges = torch.LongTensor(edges).cuda()
246 |             #     old_nodes_new_idx = torch.LongTensor(old_nodes_new_idx).cuda()
247 |             # print(nodes)
248 |             # print(edges)
249 |             # print(old_nodes_new_idx)
250 |             # print(hidden)
251 |             # print(h0)
252 |             hidden = self.gnn_layers[i](hidden, edges, nodes.size(0), self.kgemb, self.left_num)
253 |             # print(hidden)
254 |
255 |             # if self.mm:
256 |             #     h0 = mean_feature[nodes[:, 1]].unsqueeze(0).cuda().index_copy_(1, old_nodes_new_idx, h0[:,old_node])
257 |             # else:
258 |             h0 = torch.zeros(1, nodes.size(0), hidden.size(1)).cuda().index_copy_(1, old_nodes_new_idx, h0[:, old_node])
259 |             hidden = self.dropout(hidden)
260 |             hidden, h0 = self.gate(hidden.unsqueeze(0), h0)
261 |             hidden = hidden.squeeze(0)
262 |         # hidden -> (len(nodes), hidden_dim)
263 |         # if self.mm:
264 |         #     mm_hidden = torch.cat((hidden, features['IMG'][nodes[:, 1]] - features['IMG'][q_sub[nodes[:, 0]]],
265 |         #                            features['Text'][nodes[:, 1]] - features['Text'][q_sub[nodes[:, 0]]]), dim=-1)
266 |         #     scores = self.W_final(mm_hidden).squeeze(-1)
267 |         # else:
268 |         scores = self.W_final(hidden).squeeze(-1)
269 |
270 |         scores_all = torch.zeros((len(subs), self.loader.n_ent)).cuda()  # non_visited entities have 0 scores
271 |         scores_all[[nodes[:, 0], nodes[:, 1]]] = scores
272 |
273 |
274 |         if self.mm:
275 |             source,target = torch.meshgrid(q_sub, torch.arange(self.n_ent).cuda())
276 |             hidden = self.img_features[source] * self.img_features[target]
277 |             b,_ = torch.meshgrid(torch.arange(n).cuda(), torch.arange(self.n_ent).cuda())
278 |             img_scores = self.ImgMLP(hidden).squeeze(-1)
279 |             scores_all[[b, target]] += img_scores
280 |
281 |             rel_sim = torch.mm(self.att_rel_features.weight, self.att_rel_features.weight.T)
282 |
283 |             for i,sub in enumerate(subs):
284 |                 if sub not in self.ids_att:
285 |                     continue
286 |                 if sub 0:
221 |         epoch += 1
222 |         mrr,t_h1, t_h3, t_h5, t_h10, out_str = model.train_batch()
223 |         if args.nni:
224 |             nni.report_intermediate_result({'default':mrr,'h1':t_h1,'h3':t_h3,'h5':t_h5,'h10':t_h10})
225 |         with open(args.perf_file, 'a+') as f:
226 |             f.write(out_str)
227 |         if mrr > best_mrr:
228 |             best_mrr = mrr
229 |             best_h1 = t_h1
230 |             best_h3 = t_h3
231 |             best_h5 = t_h5
232 |             best_h10 = t_h10
233 |             best_str = out_str
234 |             print(str(epoch) + '\t' + best_str)
235 |             with open(args.perf_file,'a+') as f:
236 |                 f.write("best at "+ str(epoch) + '\t' + best_str)
237 |             wait_patient = 10
238 |         else:
239 |             wait_patient -= 1
240 |
241 |     if args.nni:
242 |         nni.report_final_result({'default':best_mrr,'h1':best_h1,'h3':best_h3,'h5':best_h5,'h10':best_h10})
243 |     print(best_str)
244 |
245 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import errno
4 | import torch
5 | import sys
6 | import logging
7 | import json
8 | from pathlib import Path
9 | import torch.optim as optim
10 | from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
11 | import torch.distributed as dist
12 | import csv
13 | import os.path as osp
14 | import time
15 | import re
16 | import pdb
17 | from torch import nn
18 | from numpy import mean
19 | import multiprocessing
20 | import math
21 | import random
22 | import numpy as np
23 | import scipy
24 | import scipy.sparse as sp
25 | from scipy.stats import rankdata
26 |
27 |
28 | def set_optim(opt, model_list, freeze_part=[], accumulation_step=None):
29 |     named_parameters = []
30 |     param_name = []
31 |     for model in model_list:
32 |         model_para_train, freeze_layer = [], []
33 |         model_para = list(model.named_parameters())
34 |
35 |         for n, p in model_para:
36 |             if not any(nd in n for nd in freeze_part):
37 |                 model_para_train.append((n, p))
38 |                 param_name.append(n)
39 |             else:
40 |                 p.requires_grad = False
41 |                 freeze_layer.append((n, p))
42 |         # pdb.set_trace()
43 |         named_parameters.extend(model_para_train)
44 |
45 |     parameters = [
46 |         {'params': [p for n, p in named_parameters], "lr": opt.lr, 'weight_decay': opt.weight_decay}
47 |     ]
48 |
49 |     if opt.optim == 'adamw':
50 |         # optimizer = optim.AdamW(model.parameters(), lr=opt.lr, eps=opt.adam_epsilon)
51 |         optimizer = optim.AdamW(parameters, lr=opt.lr, eps=opt.adam_epsilon)
52 |         # optimizer = AdamW(parameters, lr=opt.lr, eps=opt.adam_epsilon)
53 |     elif opt.optim == 'adam':
54 |         optimizer = optim.Adam(parameters, lr=opt.lr)
55 |
56 |     if accumulation_step is None:
57 |         accumulation_step = opt.accumulation_steps
58 |     if opt.scheduler == 'fixed':
59 |         scheduler = FixedScheduler(optimizer)
60 |     elif opt.scheduler == 'linear':
61 |         scheduler_steps = opt.total_steps
62 |         # scheduler = WarmupLinearScheduler(optimizer, warmup_steps=opt.warmup_steps, scheduler_steps=scheduler_steps, min_ratio=0.)
63 |         scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(opt.warmup_steps / accumulation_step), num_training_steps=int(opt.total_steps / accumulation_step))
64 |     elif opt.scheduler == 'cos':
65 |         scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=int(opt.warmup_steps / accumulation_step), num_training_steps=int(opt.total_steps / accumulation_step))
66 |
67 |     return optimizer, scheduler
68 |
69 |
70 | class FixedScheduler(torch.optim.lr_scheduler.LambdaLR):
71 |     def __init__(self, optimizer, last_epoch=-1):
72 |         super(FixedScheduler, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
73 |
74 |     def lr_lambda(self, step):
75 |         return 1.0
76 |
77 |
78 | class WarmupLinearScheduler(torch.optim.lr_scheduler.LambdaLR):
79 |     def __init__(self, optimizer, warmup_steps, scheduler_steps, min_ratio, last_epoch=-1):
80 |         self.warmup_steps = warmup_steps
81 |         self.scheduler_steps = scheduler_steps
82 |         self.min_ratio = min_ratio
83 |         # self.fixed_lr = fixed_lr
84 |         super(WarmupLinearScheduler, self).__init__(
85 |             optimizer, self.lr_lambda, last_epoch=last_epoch
86 |         )
87 |
88 |     def lr_lambda(self, step):
89 |         if step < self.warmup_steps:
90 |             return (1 - self.min_ratio) * step / float(max(1, self.warmup_steps)) + self.min_ratio
91 |
92 |         # if self.fixed_lr:
93 |         #     return 1.0
94 |
95 |         return max(0.0,
96 |                    1.0 + (self.min_ratio - 1) * (step - self.warmup_steps) / float(max(1.0, self.scheduler_steps - self.warmup_steps)),
97 |                    )
98 |
99 |
100 | class Loss_log():
101 |     def __init__(self):
102 |         self.loss = [999999.]
103 |         self.acc = [0.]
104 |         self.flag = 0
105 |         self.token_right_num = []
106 |         self.token_all_num = []
107 |         self.use_top_k_acc = 0
108 |
109 |     def acc_init(self, topn=[1]):
110 |         self.loss = []
111 |         self.token_right_num = []
112 |         self.token_all_num = []
113 |         self.topn = topn
114 |         self.use_top_k_acc = 1
115 |         self.top_k_word_right = {}
116 |         for n in topn:
117 |             self.top_k_word_right[n] = []
118 |
119 |     def get_token_acc(self):
120 |         if len(self.token_all_num) == 0:
121 |             return 0.
122 |         elif self.use_top_k_acc == 1:
123 |             res = []
124 |             for n in self.topn:
125 |                 res.append(round((sum(self.top_k_word_right[n]) / sum(self.token_all_num)) * 100, 3))
126 |             return res
127 |         else:
128 |             return [sum(self.token_right_num) / sum(self.token_all_num)]
129 |
130 |     def update_token(self, token_num, token_right):
131 |         self.token_all_num.append(token_num)
132 |         if isinstance(token_right, list):
133 |             for i, n in enumerate(self.topn):
134 |                 self.top_k_word_right[n].append(token_right[i])
135 |         self.token_right_num.append(token_right)
136 |
137 |     def update(self, case):
138 |         self.loss.append(case)
139 |
140 |     def update_acc(self, case):
141 |         self.acc.append(case)
142 |
143 |     def get_acc(self):
144 |         return self.acc[-1]
145 |
146 |     def get_min_loss(self):
147 |         return min(self.loss)
148 |
149 |     def get_loss(self):
150 |         if len(self.loss) == 0:
151 |             return 500.
152 |         return mean(self.loss)
153 |
154 |     def early_stop(self):
155 |         # min_loss = min(self.loss)
156 |         if self.loss[-1] > min(self.loss):
157 |             self.flag += 1
158 |         else:
159 |             self.flag = 0
160 |
161 |         if self.flag > 1000:
162 |             return True
163 |         else:
164 |             return False
165 |
166 | def torch_accuracy(output, target, topk=(1,)):
167 |     '''
168 |     param output, target: should be torch Variable
169 |     '''
170 |     # assert isinstance(output, torch.cuda.Tensor), 'expecting Torch Tensor'
171 |     # assert isinstance(target, torch.Tensor), 'expecting Torch Tensor'
172 |     # print(type(output))
173 |
174 |     topn = max(topk)
175 |     batch_size = output.size(0)
176 |
177 |     _, pred = output.topk(topn, 1, True, True)
178 |     pred = pred.t()
179 |
180 |     is_correct = pred.eq(target.view(1, -1).expand_as(pred))
181 |
182 |     ans = []
183 |     ans_num = []
184 |     for i in topk:
185 |         # is_correct_i = is_correct[:i].view(-1).float().sum(0, keepdim=True)
186 |         is_correct_i = is_correct[:i].contiguous().view(-1).float().sum(0, keepdim=True)
187 |         ans_num.append(int(is_correct_i.item()))
188 |         ans.append(is_correct_i.mul_(100.0 / batch_size))
189 |
190 |     return ans, ans_num
191 |
192 |
193 | def pairwise_distances(x, y=None):
194 |     '''
195 |     Input: x is a Nxd matrix
196 |            y is an optional Mxd matrix
197 |     Output: dist is a NxM matrix where dist[i,j] is the squared Euclidean distance between x[i,:] and y[j,:]
198 |             if y is not given then use 'y=x'.
199 |     i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2
200 |     '''
201 |     x_norm = (x**2).sum(1).view(-1, 1)
202 |     if y is not None:
203 |         y_norm = (y**2).sum(1).view(1, -1)
204 |     else:
205 |         y = x
206 |         y_norm = x_norm.view(1, -1)
207 |
208 |     distance = x_norm + y_norm - 2.0 * torch.mm(x, torch.transpose(y, 0, 1))
209 |     return torch.clamp(distance, 0.0, np.inf)
210 |
211 |
212 | def normalize_adj(mx):
213 |     """Symmetrically normalize a sparse adjacency matrix: D^-1/2 * A * D^-1/2 (A assumed symmetric)"""
214 |     rowsum = np.array(mx.sum(1))
215 |     r_inv_sqrt = np.power(rowsum, -0.5).flatten()
216 |     r_inv_sqrt[np.isinf(r_inv_sqrt)] = 0.
217 |     r_mat_inv_sqrt = sp.diags(r_inv_sqrt)
218 |     return mx.dot(r_mat_inv_sqrt).transpose().dot(r_mat_inv_sqrt)
219 |
220 |
221 | def normalize_features(mx):
222 |     """Row-normalize sparse matrix"""
223 |     rowsum = np.array(mx.sum(1))
224 |     r_inv = np.power(rowsum, -1).flatten()
225 |     r_inv[np.isinf(r_inv)] = 0.
226 |     r_mat_inv = sp.diags(r_inv)
227 |     mx = r_mat_inv.dot(mx)
228 |     return mx
229 |
230 |
231 | def sparse_mx_to_torch_sparse_tensor(sparse_mx):
232 |     """Convert a scipy sparse matrix to a torch sparse tensor."""
233 |     sparse_mx = sparse_mx.tocoo().astype(np.float32)
234 |     indices = torch.from_numpy(
235 |         np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
236 |     values = torch.FloatTensor(sparse_mx.data)
237 |     shape = torch.Size(sparse_mx.shape)
238 |     return torch.sparse.FloatTensor(indices, values, shape)
239 |
240 |
241 | def div_list(ls, n):
242 |     ls_len = len(ls)
243 |     if n <= 0 or 0 == ls_len:
244 |         return []
245 |     if n > ls_len:
246 |         return []
247 |     elif n == ls_len:
248 |         return [[i] for i in ls]
249 |     else:
250 |         j = ls_len // n
251 |         k = ls_len % n
252 |         ls_return = []
253 |         for i in range(0, (n - 1) * j, j):
254 |             ls_return.append(ls[i:i + j])
255 |         ls_return.append(ls[(n - 1) * j:])
256 |         return ls_return
257 |
258 |
259 | def multi_cal_neg(pos_triples, task, triples, r_hs_dict, r_ts_dict, ids, neg_scope):
260 |     neg_triples = list()
261 |     for idx, tas in enumerate(task):
262 |         (h, r, t) = pos_triples[tas]
263 |         h2, r2, t2 = h, r, t
264 |         temp_scope, num = neg_scope, 0
265 |         while True:
266 |             choice = random.randint(0, 999)
267 |             if choice < 500:
268 |                 if temp_scope:
269 |                     h2 = random.sample(r_hs_dict[r], 1)[0]
270 |                 else:
271 |                     for id in ids:
272 |                         if h2 in id:
273 |                             h2 = random.sample(id, 1)[0]
274 |                             break
275 |             else:
276 |                 if temp_scope:
277 |                     t2 = random.sample(r_ts_dict[r], 1)[0]
278 |                 else:
279 |                     for id in ids:
280 |                         if t2 in id:
281 |                             t2 = random.sample(id, 1)[0]
282 |                             break
283 |             if (h2, r2, t2) not in triples:
284 |                 break
285 |             else:
286 |                 num += 1
287 |                 if num > 10:
288 |                     temp_scope = False
289 |         neg_triples.append((h2, r2, t2))
290 |     return neg_triples
291 |
292 |
293 | def multi_typed_sampling(pos_triples, triples, r_hs_dict, r_ts_dict, ids, neg_scope):
294 |     t_ = time.time()
295 |     triples = set(triples)
296 |     tasks = div_list(np.array(range(len(pos_triples)), dtype=np.int32), 10)
297 |     pool = multiprocessing.Pool(processes=len(tasks))
298 |     reses = list()
299 |     for task in tasks:
300 |         reses.append(pool.apply_async(multi_cal_neg, (pos_triples, task, triples, r_hs_dict, r_ts_dict, ids, neg_scope)))
301 |     pool.close()
302 |     pool.join()
303 |     neg_triples = []
304 |     for res in reses:
305 |         neg_triples.extend(res.get())
306 |     return neg_triples
307 |
308 |
309 | def nearest_neighbor_sampling(emb, left, right, K):
310 |     t = time.time()
311 |     neg_left = []
312 |     distance = pairwise_distances(emb[right], emb[right])
313 |     for idx in range(right.shape[0]):
314 |         _, indices = torch.sort(distance[idx, :], descending=False)
315 |         neg_left.append(right[indices[1:K + 1]])
316 |     neg_left = torch.cat(tuple(neg_left), dim=0)
317 |     neg_right = []
318 |     distance = pairwise_distances(emb[left], emb[left])
319 |     for idx in range(left.shape[0]):
320 |         _, indices = torch.sort(distance[idx, :], descending=False)
321 |         neg_right.append(left[indices[1:K + 1]])
322 |     neg_right = torch.cat(tuple(neg_right), dim=0)
323 |     return neg_left, neg_right
324 |
325 |
326 | def get_adjr(ent_size, triples, norm=False):
327 |     print('getting a sparse tensor r_adj...')
328 |     M = {}
329 |     for tri in triples:
330 |         if tri[0] == tri[2]:
331 |             continue
332 |         if (tri[0], tri[2]) not in M:
333 |             M[(tri[0], tri[2])] = 0
334 |         M[(tri[0], tri[2])] += 1
335 |     ind, val = [], []
336 |     for (fir, sec) in M:
337 |         ind.append((fir, sec))
338 |         ind.append((sec, fir))
339 |         val.append(M[(fir, sec)])
340 |         val.append(M[(fir, sec)])
341 |
342 |     for i in range(ent_size):
343 |         ind.append((i, i))
344 |         val.append(1)
345 |
346 |     if norm:
347 |         ind = np.array(ind, dtype=np.int32)
348 |         val = np.array(val, dtype=np.float32)
349 |         adj = sp.coo_matrix((val, (ind[:, 0], ind[:, 1])), shape=(ent_size, ent_size), dtype=np.float32)
350 |         # 1. normalize_adj
351 |         # 2. Convert a scipy sparse matrix to a torch sparse tensor
352 |         # pdb.set_trace()
353 |         return sparse_mx_to_torch_sparse_tensor(normalize_adj(adj))
354 |     else:
355 |         M = torch.sparse_coo_tensor(torch.LongTensor(ind).t(), torch.FloatTensor(val), torch.Size([ent_size, ent_size]))
356 |         return M
357 |
358 |
359 | def cal_ranks(scores, labels, is_lefts, left_num):
360 |     ranks = []
361 |     for idx, score in enumerate(scores):
362 |         if not is_lefts[idx]:
363 |             real_score = - score[:left_num]
364 |             rank = real_score.argsort()
365 |             rank = np.where(rank == labels[idx])[0][0]
366 |         else:
367 |             real_score = - score[left_num:]
368 |             rank = real_score.argsort()
369 |             rank = np.where(rank == labels[idx]-left_num)[0][0]
370 |         ranks.append(rank+1)
371 |     return list(ranks)
372 |
373 |
374 | def cal_performance(ranks):
375 |     mrr = (1. / ranks).sum() / len(ranks)
376 |     h_1 = sum(ranks<=1) * 1.0 / len(ranks)
377 |     h_3 = sum(ranks<=3) * 1.0 / len(ranks)
378 |     h_5 = sum(ranks<=5) * 1.0 / len(ranks)
379 |     h_10 = sum(ranks<=10) * 1.0 / len(ranks)
380 |     return mrr, h_1,h_3,h_5, h_10
381 |
382 | def multi_cal_rank(task, sim, top_k, l_or_r):
383 |     mean = 0
384 |     mrr = 0
385 |     num = [0 for k in top_k]
386 |     for i in range(len(task)):
387 |         ref = task[i]
388 |         if l_or_r == 0:
389 |             rank = (sim[i, :]).argsort()
390 |         else:
391 |             rank = (sim[:, i]).argsort()
392 |         assert ref in rank
393 |         rank_index = np.where(rank == ref)[0][0]
394 |         mean += (rank_index + 1)
395 |         mrr += 1.0 / (rank_index + 1)
396 |         for j in range(len(top_k)):
397 |             if rank_index < top_k[j]:
398 |                 num[j] += 1
399 |     return mean, num, mrr
400 |
401 |
402 | def multi_get_hits(Lvec, Rvec, top_k=(1, 5, 10, 50, 100), args=None):
403 |     result = []
404 |     sim = pairwise_distances(torch.FloatTensor(Lvec), torch.FloatTensor(Rvec)).numpy()
405 |     if args is not None and args.csls:
406 |         # csls_sim uses torch.topk, so convert the NumPy distance matrix to a tensor first
407 |         sim = 1 - csls_sim(1 - torch.from_numpy(sim), args.csls_k).numpy()
408 |     for i in [0, 1]:
409 |         top_total = np.array([0] * len(top_k))
410 |         mean_total, mrr_total = 0.0, 0.0
411 |         s_len = Lvec.shape[0] if i == 0 else Rvec.shape[0]
412 |         tasks = div_list(np.array(range(s_len)), 10)
413 |         pool = multiprocessing.Pool(processes=len(tasks))
414 |         reses = list()
415 |         for task in tasks:
416 |             if i == 0:
417 |                 reses.append(pool.apply_async(multi_cal_rank, (task, sim[task, :], top_k, i)))
418 |             else:
419 |                 reses.append(pool.apply_async(multi_cal_rank, (task, sim[:, task], top_k, i)))
420 |         pool.close()
421 |         pool.join()
422 |         for res in reses:
423 |             mean, num, mrr = res.get()
424 |             mean_total += mean
425 |             mrr_total += mrr
426 |             top_total += np.array(num)
427 |         acc_total = top_total / s_len
428 |         for i in range(len(acc_total)):
429 |             acc_total[i] = round(acc_total[i], 4)
430 |         mean_total /= s_len
431 |         mrr_total /= s_len
432 |         result.append(acc_total)
433 |         result.append(mean_total)
434 |         result.append(mrr_total)
435 |     return result
436 |
437 |
438 | def csls_sim(sim_mat, k):
439 |     """
440 |     Compute pairwise csls similarity based on the input similarity matrix.
441 |     Parameters
442 |     ----------
443 |     sim_mat : torch.Tensor
444 |         A pairwise similarity matrix.
445 |     k : int
446 |         The number of nearest neighbors.
447 |     Returns
448 |     -------
449 |     csls_sim_mat : A csls similarity matrix of n1*n2.
450 |     """
451 |
452 |     nearest_values1 = torch.mean(torch.topk(sim_mat, k)[0], 1)
453 |     nearest_values2 = torch.mean(torch.topk(sim_mat.t(), k)[0], 1)
454 |     csls_sim_mat = 2 * sim_mat.t() - nearest_values1
455 |     csls_sim_mat = csls_sim_mat.t() - nearest_values2
456 |     return csls_sim_mat
457 |
458 |
459 | def get_topk_indices(M, K=1000):
460 |     H, W = M.shape
461 |     M_view = M.view(-1)
462 |     vals, indices = M_view.topk(K)
463 |     print("highest sim:", vals[0].item(), "lowest sim:", vals[-1].item())
464 |     two_d_indices = torch.cat(((indices // W).unsqueeze(1), (indices % W).unsqueeze(1)), dim=1)
465 |     return two_d_indices
466 |
467 |
468 | def normalize_zero_one(A):
469 |     A -= A.min(1, keepdim=True)[0]
470 |     A /= A.max(1, keepdim=True)[0]
471 |     return A
472 |
473 |
474 | def output_device(model):
475 |     sd = model.state_dict()
476 |     devices = []
477 |     for v in sd.values():
478 |         if v.device not in devices:
479 |             devices.append(v.device)
480 |     # for d in devices:
481 |     #     print(d)
482 |     print(devices)
483 |
484 |
485 | if __name__ == '__main__':
486 |     # test cal_ranks 9 nodes, 5left , 4right,2 seeds(3,7)(8,2)
487 |     scores = np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.31,0.9,0.8,0.7],
488 |                        [0.5, 0.4, 0.7, 0.2, 0.1,0.3,0.32,0.23,0.44]])
489 |     labels = np.array([7,2])
490 |     is_lefts = np.array([True,False])
491 |     left_num = 5
492 |     ranks = cal_ranks(scores, labels, is_lefts, left_num)
493 |     print(ranks)
494 |
--------------------------------------------------------------------------------
/vis.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
4 | import argparse
5 | import random
6 | import torch
7 | import numpy as np
8 | from load_data import DataLoader
9 | from base_model import BaseModel
10 | import time
11 | from collections import OrderedDict
12 | import networkx as nx
13 | import matplotlib.pyplot as plt
14 |
15 |
16 | parser = argparse.ArgumentParser(description="Parser for MASEA")
17 | parser.add_argument("--data_path", default="../data/mmkg", type=str, help="Experiment path")
18 | parser.add_argument("--data_choice", default="FBDB15K", type=str, choices=["DBP15K", "DWY", "FBYG15K", "FBDB15K"],
19 |                     help="Experiment path")
20 | parser.add_argument("--data_split", default="norm", type=str, help="Experiment split",
21 |                     choices=["dbp_wd_15k_V2", "dbp_wd_15k_V1", "zh_en", "ja_en", "fr_en", "norm"])
22 | parser.add_argument("--data_rate", type=float, default=0.8, choices=[0.2, 0.3, 0.5, 0.8], help="training set rate")
23 | parser.add_argument('--seed', type=str, default=1234)
24 | parser.add_argument('--gpu', type=int, default=0)
25 | parser.add_argument('--perf_file', type=str, default='perf.txt')
26 | parser.add_argument('--lr', type=float, default=0.001)
27 | parser.add_argument('--lamb', type=float, default=0.0002)
28 | parser.add_argument('--decay_rate', type=float, default=0.991)
29 | parser.add_argument('--hidden_dim', type=int, default=64)
30 | parser.add_argument('--attn_dim', type=int, default=5)
31 | parser.add_argument('--dropout', type=float, default=0.2)
32 | parser.add_argument('--act', type=str, default='relu')
33 | parser.add_argument('--n_layer', type=int, default=5)
34 | parser.add_argument('--n_batch', type=int, default=2)
35 | parser.add_argument("--lamda", type=float, default=0.5)
36 | parser.add_argument("--exp_name", default="EA_exp", type=str, help="Experiment name")
37 | parser.add_argument("--MLP_hidden_dim", type=int, default=64)
38 | parser.add_argument("--MLP_num_layers", type=int, default=3)
39 | parser.add_argument("--MLP_dropout", type=float, default=0.2)
40 |
41 | parser.add_argument("--n_ent", type=int, default=0)
42 | parser.add_argument("--n_rel", type=int, default=0)
43 |
44 | parser.add_argument("--stru_dim", type=int, default=16)
45 | parser.add_argument("--text_dim", type=int, default=768)
46 | parser.add_argument("--img_dim", type=int, default=2048)
47 | parser.add_argument("--time_dim", type=int, default=32)
48 | parser.add_argument("--out_dim", type=int, default=32)
49 | parser.add_argument("--train_support", type=int, default=0)
50 | parser.add_argument("--gnn_model", type=str, default='RS_GNN')
51 | parser.add_argument("--mm", type=int, default=0)
52 | parser.add_argument("--shuffle", type=int, default=1)
53 | parser.add_argument("--meta", type=int, default=1)
54 | parser.add_argument("--temperature", type=float, default=0.5)
55 | parser.add_argument("--premm", type=int, default=0)
56 | parser.add_argument("--withmm", type=int, default=1)
57 | parser.add_argument("--update_step", type=int, default=20)
58 | parser.add_argument("--update_step_test", type=int, default=20)
59 | parser.add_argument("--update_lr", type=float, default=0.001)
60 |
61 |
62 | # base
63 |
64 | parser.add_argument('--batch_size', default=128, type=int)
65 | parser.add_argument('--epoch', default=100, type=int)
66 | parser.add_argument("--save_model", default=0, type=int, choices=[0, 1])
67 | parser.add_argument("--only_test", default=0, type=int, choices=[0, 1])
68 |
69 | # torchlight
70 | parser.add_argument("--no_tensorboard", default=False, action="store_true")
71 |
72 | parser.add_argument("--dump_path", default="dump/", type=str, help="Experiment dump path")
73 | parser.add_argument("--exp_id", default="001", type=str, help="Experiment ID")
74 | parser.add_argument("--random_seed", default=42, type=int)
75 |
76 |
77 | # --------- EA -----------
78 |
79 | # parser.add_argument("--data_rate", type=float, default=0.3, help="training set rate")
80 | #
81 |
82 | # TODO: add some dynamic variable
83 | parser.add_argument("--model_name", default="MEAformer", type=str, choices=["EVA", "MCLEA", "MSNEA", "MEAformer"],
84 |                     help="model name")
85 | parser.add_argument("--model_name_save", default="", type=str, help="model name for model load")
86 |
87 | parser.add_argument('--workers', type=int, default=8)
88 | parser.add_argument('--accumulation_steps', type=int, default=1)
89 | parser.add_argument("--scheduler", default="linear", type=str, choices=["linear", "cos", "fixed"])
90 | parser.add_argument("--optim", default="adamw", type=str, choices=["adamw", "adam"])
91 | parser.add_argument('--weight_decay', type=float, default=0.0001)
92 | parser.add_argument("--adam_epsilon", default=1e-8, type=float)
93 | parser.add_argument('--eval_epoch', default=100, type=int, help='evaluate each n epoch')
94 | parser.add_argument("--enable_sota", action="store_true", default=False)
95 |
96 | parser.add_argument('--margin', default=1, type=float, help='The fixed margin in loss function. ')
97 | parser.add_argument('--emb_dim', default=1000, type=int, help='The embedding dimension in KGE model.')
98 | parser.add_argument('--adv_temp', default=1.0, type=float,
99 |                     help='The temperature of sampling in self-adversarial negative sampling.')
100 | parser.add_argument("--contrastive_loss", default=0, type=int, choices=[0, 1])
101 | parser.add_argument('--clip', type=float, default=1., help='gradient clipping')
102 |
103 | # --------- EVA -----------
104 |
105 | parser.add_argument("--hidden_units", type=str, default="128,128,128",
106 |                     help="hidden units in each hidden layer(including in_dim and out_dim), splitted with comma")
107 | parser.add_argument("--attn_dropout", type=float, default=0.0, help="dropout rate for gat layers")
108 | parser.add_argument("--distance", type=int, default=2, help="L1 distance or L2 distance. ('1', '2')", choices=[1, 2])
109 | parser.add_argument("--csls", action="store_true", default=False, help="use CSLS for inference")
110 | parser.add_argument("--csls_k", type=int, default=10, help="top k for csls")
111 | parser.add_argument("--il", action="store_true", default=False, help="Iterative learning?")
112 | parser.add_argument("--semi_learn_step", type=int, default=10, help="If IL, what's the update step?")
113 | parser.add_argument("--il_start", type=int, default=500, help="If Il, when to start?")
114 | parser.add_argument("--unsup", action="store_true", default=False)
115 | parser.add_argument("--unsup_k", type=int, default=1000, help="|visual seed|")
116 |
117 | # --------- MCLEA -----------
118 | parser.add_argument("--unsup_mode", type=str, default="img", help="unsup mode", choices=["img", "name", "char"])
119 | parser.add_argument("--tau", type=float, default=0.1, help="the temperature factor of contrastive loss")
120 | parser.add_argument("--alpha", type=float, default=0.2, help="the margin of InfoMaxNCE loss")
121 | parser.add_argument("--with_weight", type=int, default=1, help="Whether to weight the fusion of different ")
122 | parser.add_argument("--structure_encoder", type=str, default="gat", help="the encoder of structure view",
123 |                     choices=["gat", "gcn"])
124 | parser.add_argument("--ab_weight", type=float, default=0.5, help="the weight of NTXent Loss")
125 |
126 | parser.add_argument("--projection", action="store_true", default=False, help="add projection for model")
127 | parser.add_argument("--heads", type=str, default="2,2", help="heads in each gat layer, splitted with comma")
128 | parser.add_argument("--instance_normalization", action="store_true", default=False,
129 |                     help="enable instance normalization")
130 | parser.add_argument("--attr_dim", type=int, default=100, help="the hidden size of attr and rel features")
131 | parser.add_argument("--name_dim", type=int, default=100, help="the hidden size of name feature")
132 |
parser.add_argument("--char_dim", type=int, default=100, help="the hidden size of char feature")

parser.add_argument("--w_gcn", action="store_false", default=True, help="with gcn features")
parser.add_argument("--w_rel", action="store_false", default=True, help="with rel features")
parser.add_argument("--w_attr", action="store_false", default=True, help="with attr features")
parser.add_argument("--w_name", action="store_false", default=True, help="with name features")
parser.add_argument("--w_char", action="store_false", default=True, help="with char features")
parser.add_argument("--w_img", action="store_false", default=True, help="with img features")
parser.add_argument("--use_surface", type=int, default=0, help="whether to use the surface")

parser.add_argument("--inner_view_num", type=int, default=6, help="the number of inner views")
parser.add_argument("--word_embedding", type=str, default="glove", help="the type of word embedding, [glove|bert]",
                    choices=["glove", "bert"])
# projection head
parser.add_argument("--use_project_head", action="store_true", default=False, help="use projection head")
parser.add_argument("--zoom", type=float, default=0.1, help="narrow the range of losses")
parser.add_argument("--reduction", type=str, default="mean", help="[sum|mean]", choices=["sum", "mean"])

# --------- MEAformer -----------
parser.add_argument("--hidden_size", type=int, default=100, help="the hidden size of MEAformer")
parser.add_argument("--intermediate_size", type=int, default=400, help="the intermediate size of MEAformer")
parser.add_argument("--num_attention_heads", type=int, default=5, help="the number of attention heads in MEAformer")
parser.add_argument("--num_hidden_layers", type=int, default=2, help="the number of hidden layers in MEAformer")
parser.add_argument("--position_embedding_type", default="absolute", type=str)
parser.add_argument("--use_intermediate", type=int, default=1, help="whether to use_intermediate")
parser.add_argument("--replay", type=int, default=0, help="whether to use replay strategy")
parser.add_argument("--neg_cross_kg", type=int, default=0,
                    help="whether to force the negative samples in the opposite KG")

# --------- MSNEA -----------
parser.add_argument("--dim", type=int, default=100, help="the hidden size of MSNEA")
parser.add_argument("--neg_triple_num", type=int, default=1, help="neg triple num")
parser.add_argument("--use_bert", type=int, default=0)
parser.add_argument("--use_attr_value", type=int, default=0)
# parser.add_argument("--learning_rate", type=int, default=0.001)
# parser.add_argument("--optimizer", type=str, default="Adam")
# parser.add_argument("--max_epoch", type=int, default=200)

# parser.add_argument("--save_path", type=str, default="save_pkl", help="save path")

# ------------ Para ------------
parser.add_argument('--rank', type=int, default=0, help='rank for distributed training')
parser.add_argument('--dist', type=int, default=0, help='whether to use distributed training')
parser.add_argument('--device', default='cuda', help='device id (e.g. 0, 0,1 or cpu)')
parser.add_argument('--world-size', default=3, type=int,
                    help='number of distributed processes')
parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
parser.add_argument("--local_rank", default=-1, type=int)

parser.add_argument("--nni", default=0, type=int)
args = parser.parse_args()
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# use the GPU selected by args.gpu
torch.cuda.set_device(args.gpu)


if __name__ == '__main__':
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    results_dir = 'results'
    args_str = f'{args.data_choice}_{args.data_split}_{args.data_rate}_lr{args.lr}_bs{args.n_batch}_hidden_dim{args.hidden_dim}_lamb{args.lamb}_dropout{args.dropout}_act{args.act}_decay_rate{args.decay_rate}'
    args.perf_file = os.path.join(results_dir, args.exp_name, args_str + time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time())) + '.txt')
    # exist_ok=True also creates results_dir itself when it is missing
    os.makedirs(os.path.join(results_dir, args.exp_name), exist_ok=True)
    if args.nni:
        import nni
        from nni.utils import merge_parameter
        nni_params = nni.get_next_parameter()
        args = merge_parameter(args, nni_params)
    print(args)
    with open(args.perf_file, 'a') as f:
        print(args, file=f)
    loader = DataLoader(args)
    id2name = loader.id2name
    id2rel = loader.id2rel
    n_rel = loader.n_rel
    # relation ids: [0, n_rel) original, [n_rel, 2*n_rel) reversed,
    # then self_loop, anchor and anchor_reverse
    id2rel_reverse = {}
    for k, v in id2rel.items():
        id2rel_reverse[k + n_rel] = v + '_reverse'
    id2rel = {**id2rel, **id2rel_reverse}
    id2rel[2*n_rel] = 'self_loop'
    id2rel[2*n_rel+1] = 'anchor'
    id2rel[2*n_rel+2] = 'anchor_reverse'
    left_entity = len(loader.left_ents)

    # visualize one test pair at a time
    batch_size = 1
    n_data = loader.n_test
    n_batch = n_data // batch_size + (n_data % batch_size > 0)

    for i in range(n_batch):
        start = i*batch_size
        end = min(n_data, (i+1)*batch_size)
        batch_idx = np.arange(start, end)
        triple = loader.get_batch(batch_idx, data='test')
        subs, rels, objs = triple[:, 0], triple[:, 1], triple[:, 2]
        sub = subs[0]
        rel = rels[0]
        obj = objs[0]
        edges = loader.get_vis_subgraph(sub, obj, 5)
        all_edges_size = sum([len(edge) for edge in edges])
        print(all_edges_size)
        # skip empty subgraphs and subgraphs too large to draw legibly
        if all_edges_size > 100 or all_edges_size == 0:
            continue
        pos = {}
        x_pos = [-5, -3, -1, 1, 3, 5]
        g = {'nodes': [], 'edges': []}
        G = nx.DiGraph()
        # node ids are '<entity id>_<layer>'; class 1 if id < len(loader.left_ents), else class 2
        for node in edges[0][:, 0].unique():
            G.add_node(str(node.item()) + '_' + str(0), desc=id2name[node.item()] + '_' + str(0), layer=0)
            g['nodes'].append({'id': str(node.item()) + '_' + str(0), 'name': id2name[node.item()] + '_' + str(0), "class": 1 if node.item() < left_entity else 2, "imgsrc": "None", "content": "None"})
            pos[str(node.item()) + '_' + str(0)] = (x_pos[0], 0)
        for idx, edge in enumerate(edges):
            # node_1 = edge[:,0].unique()
            node_2 = edge[:, 2].unique()
            size = len(node_2)

            for y, node in enumerate(node_2):
                G.add_node(str(node.item())+'_'+str(idx+1), desc=id2name[node.item()]+'_'+str(idx+1), layer=idx+1)
                g['nodes'].append({'id': str(node.item())+'_'+str(idx+1), 'name': id2name[node.item()]+'_'+str(idx+1), "class": 1 if node.item() < left_entity else 2, "imgsrc": "None", "content": "None"})
                pos[str(node.item())+'_'+str(idx+1)] = (x_pos[idx+1], 10/(size+1) * (y+1) - 5)
            for e in edge:
                g['edges'].append({'source': str(e[0].item())+'_'+str(idx), 'target': str(e[2].item())+'_'+str(idx+1), 'name': id2rel[e[1].item()]})
                G.add_edge(str(e[0].item())+'_'+str(idx), str(e[2].item())+'_'+str(idx+1), name=id2rel[e[1].item()])


        # nodes = torch.cat([edges[:,0], edges[:,2]]).unique()
        # for node in nodes:
        #     G.add_node(node.item(), desc=id2name[node.item()])
        # for edge in edges:
        #     G.add_edge(edge[0].item(), edge[2].item(), name=id2rel[edge[1].item()])

        # draw graph with labels
        plt.figure(figsize=(16, 16), dpi=80)
        # pos = nx.kamada_kawai_layout(G)
        # NOTE: this replaces the hand-built layered positions stored in `pos` above
        pos = nx.spring_layout(G)
        nx.draw(G, pos)
        # highlight the source entity (layer 0) and the target entity (layer 5) in red
        nx.draw_networkx_nodes(G, pos=pos, nodelist=[str(sub.item()) + '_' + str(0), str(obj.item()) + '_' + str(5)], node_color='red', node_size=1000)
        node_labels = nx.get_node_attributes(G, 'desc')
        nx.draw_networkx_labels(G, pos, labels=node_labels)
        edge_labels = nx.get_edge_attributes(G, 'name')
        nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)

        plt.savefig(f'FBDB_{sub}_{rel}_{obj}.png', dpi=100)
        plt.close()
        with open(f'FBDB_{sub}_{rel}_{obj}.json', 'w', encoding='utf-8') as f:
            json.dump(g, f, indent=4)

--------------------------------------------------------------------------------
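`vis.py` tags every node with its hop index (`layer=`) and builds a layered `pos` by hand (columns at `x_pos = [-5, -3, -1, 1, 3, 5]`, nodes spread evenly over y in (-5, 5)), but then replaces it with `nx.spring_layout(G)`. Below is a minimal, stdlib-only sketch of that hand-built layout as a standalone function, so the layered drawing could be reused for any number of hops. `layered_layout` and its `x_step` parameter are hypothetical names, not part of the repository. The sketch assumes, as the `.unique()`/`.item()` calls suggest, that each element of `edges` is an (m, 3) tensor of `(head, relation, tail)` ids that can be converted with `.tolist()`.

```python
from collections import defaultdict


def layered_layout(edges_per_hop, x_step=2.0):
    """Hypothetical helper mirroring the hand-built `pos` in vis.py.

    edges_per_hop[k] holds (head, rel, tail) triples linking layer k to layer k + 1.
    Node ids follow vis.py's '<entity>_<layer>' convention. Returns {node_id: (x, y)}
    with one column per layer (centred on x = 0) and nodes spread evenly over y in (-5, 5).
    """
    # layer index -> insertion-ordered set of node ids (dict keys keep first-seen order)
    layers = defaultdict(dict)
    for k, hop in enumerate(edges_per_hop):
        for head, _rel, tail in hop:
            # register both endpoints so every edge has positioned nodes
            layers[k][f'{head}_{k}'] = None
            layers[k + 1][f'{tail}_{k + 1}'] = None

    n_layers = len(edges_per_hop) + 1
    pos = {}
    for layer, nodes in layers.items():
        x = (layer - (n_layers - 1) / 2) * x_step
        size = len(nodes)
        for y, node in enumerate(nodes):
            # same vertical spacing formula as vis.py: 10 / (size + 1) * (y + 1) - 5
            pos[node] = (x, 10 / (size + 1) * (y + 1) - 5)
    return pos


if __name__ == '__main__':
    # toy 2-hop subgraph with made-up entity ids
    toy = [[(0, 'r1', 7), (0, 'r2', 8)], [(7, 'r3', 9), (8, 'r3', 9)]]
    for node, xy in layered_layout(toy).items():
        print(node, xy)
```

In `vis.py` this could stand in for the spring layout, e.g. `pos = layered_layout([edge.tolist() for edge in edges])`; with five hops and the default `x_step=2.0` the columns land exactly on `x_pos`. One deliberate difference: layer-0 nodes are spread over y like every other layer instead of all sitting at y = 0.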