├── .gitignore
├── IMG
├── logo.png
├── logo1.png
├── logo2.png
└── model.jpg
├── LICENSE
├── README.md
├── base_model.py
├── data.py
├── load_data.py
├── models.py
├── opt.py
├── run.sh
├── run.slurm
├── run_dbp.sh
├── run_dbp.slurm
├── run_oea.sh
├── run_oea.slurm
├── train.py
├── utils.py
└── vis.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | __pycache__
3 | results
4 | *.log
5 | vis
--------------------------------------------------------------------------------
/IMG/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyyf2002/ASGEA/57eb3216f273db03e2aab5db1e38de65917b9f37/IMG/logo.png
--------------------------------------------------------------------------------
/IMG/logo1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyyf2002/ASGEA/57eb3216f273db03e2aab5db1e38de65917b9f37/IMG/logo1.png
--------------------------------------------------------------------------------
/IMG/logo2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyyf2002/ASGEA/57eb3216f273db03e2aab5db1e38de65917b9f37/IMG/logo2.png
--------------------------------------------------------------------------------
/IMG/model.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyyf2002/ASGEA/57eb3216f273db03e2aab5db1e38de65917b9f37/IMG/model.jpg
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Luo Yangyifei
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |

3 |
4 |
5 | # 🏕️ [ASGEA: Exploiting Logic Rules from Align-Subgraphs for Entity Alignment](https://arxiv.org/abs/2402.11000)
6 |
7 | [](https://github.com/zjukg/MEAformer/blob/main/licence)
8 | [](https://arxiv.org/abs/2402.11000)
9 | [](https://pytorch.org/)
10 |
11 |
12 | >This paper proposes the Align-Subgraph Entity Alignment (ASGEA) framework to exploit logic rules from Align-Subgraphs. ASGEA uses anchor links as bridges to construct Align-Subgraphs and spreads along the paths across KGs, which distinguishes it from the embedding-based methods.
13 |
14 |
15 |

16 |
17 |
18 |
19 | ## 🔬 Dependencies
20 | ```
21 | pytorch 1.12.0
22 | torch_geometric 2.2.0
23 | torch_scatter 2.0.9
24 | transformers 4.26.1
25 | ```
26 |
27 | ## 🚀 Train
28 |
29 | - **Quick start**: Using script file for ASGEA-MM.
30 |
31 | ```bash
32 | # FBDB15K & FBYG15K
33 | >> bash run.sh FB
34 | # DBP15K
35 | >> bash run_dbp.sh DBP
36 | # Multi OpenEA
37 | >> bash run_oea.sh OEA
38 | ```
39 |
40 | - **❗tips**: If you are using slurm, you can change the `.sh` file from
41 |
42 | ```bash
43 | datas="FBDB15K FBYG15K"
44 | rates="0.2 0.5 0.8"
45 | expn=$1
46 | if [ ! -d "results/${expn}" ]; then
47 | mkdir results/${expn}
48 | fi
49 | if [ ! -d "results/${expn}/backup" ]; then
50 | mkdir results/${expn}/backup
51 | fi
52 | cp *.py results/${expn}/backup/
53 | for data in $datas ; do
54 | for rate in $rates ; do
55 | python train.py --data_split norm --n_batch 4 --n_layer 5 --lr 0.001 --data_choice ${data} --data_rate ${rate} --exp_name ${expn} --mm 1 --img_dim 4096
56 | # echo "sbatch -o ${data}_${rate}.log run.slurm 4 5 0.001 ${data} ${rate}"
57 | # sbatch -o ${expn}_${data}_${rate}.log run.slurm 4 5 0.001 ${data} ${rate} ${expn}
58 | done
59 | done
60 | ```
61 |
62 | to
63 |
64 | ```bash
65 | datas="FBDB15K FBYG15K"
66 | rates="0.2 0.5 0.8"
67 | expn=$1
68 | if [ ! -d "results/${expn}" ]; then
69 | mkdir results/${expn}
70 | fi
71 | if [ ! -d "results/${expn}/backup" ]; then
72 | mkdir results/${expn}/backup
73 | fi
74 | cp *.py results/${expn}/backup/
75 | for data in $datas ; do
76 | for rate in $rates ; do
77 | echo "sbatch -o ${data}_${rate}.log run.slurm 4 5 0.001 ${data} ${rate}"
78 | sbatch -o ${expn}_${data}_${rate}.log run.slurm 4 5 0.001 ${data} ${rate} ${expn}
79 | done
80 | done
81 | ```
82 |
83 | - **for ASGEA-Stru**: Just set `mm=0`.
84 |
85 |
86 | ## 📚 Dataset
87 | ❗NOTE: Download from [ufile](https://ufile.io/kzkkfayd) (1.69G) and unzip it to make those files satisfy the following file hierarchy:
88 |
89 | ```
90 | ROOT
91 | ├── data
92 | │ └── mmkg
93 | └── ASGEA
94 | ```
95 |
96 | #### Code Path
97 |
98 |
99 | 👈 🔎 Click
100 |
101 | ```
102 | ASGEA
103 | ├── base_model.py
104 | ├── data.py
105 | ├── load_data.py
106 | ├── models.py
107 | ├── opt.py
108 | ├── README.md
109 | ├── run.sh
110 | ├── run.slurm
111 | ├── run_dbp.sh
112 | ├── run_dbp.slurm
113 | ├── run_oea.sh
114 | ├── run_oea.slurm
115 | ├── train.py
116 | ├── utils.py
117 | └── vis.py
118 | ```
119 |
120 |
121 |
122 | #### Data Path
123 |
124 | 👈 🔎 Click
125 |
126 | ```
127 | mmkg
128 | ├─ DBP15K
129 | │ ├─ fr_en
130 | │ │ ├─ att_features100.npy
131 | │ │ ├─ att_features500.npy
132 | │ │ ├─ att_rel_features100.npy
133 | │ │ ├─ att_rel_features500.npy
134 | │ │ ├─ att_val_features100.npy
135 | │ │ ├─ att_val_features500.npy
136 | │ │ ├─ en_att_triples
137 | │ │ ├─ ent_ids_1
138 | │ │ ├─ ent_ids_2
139 | │ │ ├─ fr_att_triples
140 | │ │ ├─ ill_ent_ids
141 | │ │ ├─ training_attrs_1
142 | │ │ ├─ training_attrs_2
143 | │ │ ├─ triples_1
144 | │ │ └─ triples_2
145 | │ ├─ ja_en
146 | │ │ ├─ att_features100.npy
147 | │ │ ├─ att_features500.npy
148 | │ │ ├─ att_rel_features100.npy
149 | │ │ ├─ att_rel_features500.npy
150 | │ │ ├─ att_val_features100.npy
151 | │ │ ├─ att_val_features500.npy
152 | │ │ ├─ en_att_triples
153 | │ │ ├─ ent_ids_1
154 | │ │ ├─ ent_ids_2
155 | │ │ ├─ ill_ent_ids
156 | │ │ ├─ ja_att_triples
157 | │ │ ├─ training_attrs_1
158 | │ │ ├─ training_attrs_2
159 | │ │ ├─ triples_1
160 | │ │ └─ triples_2
161 | │ ├─ translated_ent_name
162 | │ │ ├─ dbp_fr_en.json
163 | │ │ ├─ dbp_ja_en.json
164 | │ │ └─ dbp_zh_en.json
165 | │ └─ zh_en
166 | │ ├─ att_features100.npy
167 | │ ├─ att_features500.npy
168 | │ ├─ att_rel_features100.npy
169 | │ ├─ att_rel_features500.npy
170 | │ ├─ att_val_features100.npy
171 | │ ├─ att_val_features500.npy
172 | │ ├─ en_att_triples
173 | │ ├─ ent_ids_1
174 | │ ├─ ent_ids_2
175 | │ ├─ ill_ent_ids
176 | │ ├─ rule_test.txt
177 | │ ├─ rule_train.txt
178 | │ ├─ training_attrs_1
179 | │ ├─ training_attrs_2
180 | │ ├─ triples_1
181 | │ ├─ triples_2
182 | │ └─ zh_att_triples
183 | ├─ FBDB15K
184 | │ └─ norm
185 | │ ├─ DB15K_NumericalTriples.txt
186 | │ ├─ FB15K_NumericalTriples.txt
187 | │ ├─ att_features.npy
188 | │ ├─ att_rel_features.npy
189 | │ ├─ att_val_features.npy
190 | │ ├─ ent_ids_1
191 | │ ├─ ent_ids_2
192 | │ ├─ fbid2name.txt
193 | │ ├─ id2relation.txt
194 | │ ├─ ill_ent_ids
195 | │ ├─ training_attrs_1
196 | │ ├─ training_attrs_2
197 | │ ├─ triples_1
198 | │ └─ triples_2
199 | ├─ FBYG15K
200 | │ └─ norm
201 | │ ├─ FB15K_NumericalTriples.txt
202 | │ ├─ YAGO15K_NumericalTriples.txt
203 | │ ├─ att_features.npy
204 | │ ├─ att_rel_features.npy
205 | │ ├─ att_val_features.npy
206 | │ ├─ ent_ids_1
207 | │ ├─ ent_ids_2
208 | │ ├─ fbid2name.txt
209 | │ ├─ id2relation.txt
210 | │ ├─ ill_ent_ids
211 | │ ├─ training_attrs_1
212 | │ ├─ training_attrs_2
213 | │ ├─ triples_1
214 | │ └─ triples_2
215 | ├─ MEAformer
216 | ├─ OpenEA
217 | │ ├─ OEA_D_W_15K_V1
218 | │ │ ├─ att_features.npy
219 | │ │ ├─ att_features500.npy
220 | │ │ ├─ att_rel_features.npy
221 | │ │ ├─ att_rel_features500.npy
222 | │ │ ├─ att_val_features.npy
223 | │ │ ├─ att_val_features500.npy
224 | │ │ ├─ attr_triples_1
225 | │ │ ├─ attr_triples_2
226 | │ │ ├─ ent_ids_1
227 | │ │ ├─ ent_ids_2
228 | │ │ ├─ ill_ent_ids
229 | │ │ ├─ rel_ids
230 | │ │ ├─ training_attrs_1
231 | │ │ ├─ training_attrs_2
232 | │ │ ├─ triples_1
233 | │ │ └─ triples_2
234 | │ ├─ OEA_D_W_15K_V2
235 | │ │ ├─ att_features.npy
236 | │ │ ├─ att_features500.npy
237 | │ │ ├─ att_rel_features.npy
238 | │ │ ├─ att_rel_features500.npy
239 | │ │ ├─ att_val_features.npy
240 | │ │ ├─ att_val_features500.npy
241 | │ │ ├─ attr_triples_1
242 | │ │ ├─ attr_triples_2
243 | │ │ ├─ ent_ids_1
244 | │ │ ├─ ent_ids_2
245 | │ │ ├─ ill_ent_ids
246 | │ │ ├─ rel_ids
247 | │ │ ├─ training_attrs_1
248 | │ │ ├─ training_attrs_2
249 | │ │ ├─ triples_1
250 | │ │ └─ triples_2
251 | │ ├─ OEA_D_Y_15K_V1
252 | │ │ ├─ 721_5fold
253 | │ │ │ ├─ 1
254 | │ │ │ │ ├─ test_links
255 | │ │ │ │ ├─ train_links
256 | │ │ │ │ └─ valid_links
257 | │ │ │ ├─ 2
258 | │ │ │ │ ├─ test_links
259 | │ │ │ │ ├─ train_links
260 | │ │ │ │ └─ valid_links
261 | │ │ │ ├─ 3
262 | │ │ │ │ ├─ test_links
263 | │ │ │ │ ├─ train_links
264 | │ │ │ │ └─ valid_links
265 | │ │ │ ├─ 4
266 | │ │ │ │ ├─ test_links
267 | │ │ │ │ ├─ train_links
268 | │ │ │ │ └─ valid_links
269 | │ │ │ └─ 5
270 | │ │ │ ├─ test_links
271 | │ │ │ ├─ train_links
272 | │ │ │ └─ valid_links
273 | │ │ ├─ attr_triples_1
274 | │ │ ├─ attr_triples_2
275 | │ │ ├─ ent_ids_1
276 | │ │ ├─ ent_ids_2
277 | │ │ ├─ ent_links
278 | │ │ ├─ ill_ent_ids
279 | │ │ ├─ rel_ids
280 | │ │ ├─ rel_triples_1
281 | │ │ ├─ rel_triples_2
282 | │ │ ├─ triples_1
283 | │ │ └─ triples_2
284 | │ ├─ OEA_D_Y_15K_V2
285 | │ │ ├─ 721_5fold
286 | │ │ │ ├─ 1
287 | │ │ │ │ ├─ test_links
288 | │ │ │ │ ├─ train_links
289 | │ │ │ │ └─ valid_links
290 | │ │ │ ├─ 2
291 | │ │ │ │ ├─ test_links
292 | │ │ │ │ ├─ train_links
293 | │ │ │ │ └─ valid_links
294 | │ │ │ ├─ 3
295 | │ │ │ │ ├─ test_links
296 | │ │ │ │ ├─ train_links
297 | │ │ │ │ └─ valid_links
298 | │ │ │ ├─ 4
299 | │ │ │ │ ├─ test_links
300 | │ │ │ │ ├─ train_links
301 | │ │ │ │ └─ valid_links
302 | │ │ │ └─ 5
303 | │ │ │ ├─ test_links
304 | │ │ │ ├─ train_links
305 | │ │ │ └─ valid_links
306 | │ │ ├─ attr_triples_1
307 | │ │ ├─ attr_triples_2
308 | │ │ ├─ ent_ids_1
309 | │ │ ├─ ent_ids_2
310 | │ │ ├─ ent_links
311 | │ │ ├─ ill_ent_ids
312 | │ │ ├─ rel_ids
313 | │ │ ├─ rel_triples_1
314 | │ │ ├─ rel_triples_2
315 | │ │ ├─ triples_1
316 | │ │ └─ triples_2
317 | │ ├─ OEA_EN_DE_15K_V1
318 | │ │ ├─ att_features.npy
319 | │ │ ├─ att_features500.npy
320 | │ │ ├─ att_rel_features.npy
321 | │ │ ├─ att_rel_features500.npy
322 | │ │ ├─ att_val_features.npy
323 | │ │ ├─ att_val_features500.npy
324 | │ │ ├─ attr_triples_1
325 | │ │ ├─ attr_triples_2
326 | │ │ ├─ ent_ids_1
327 | │ │ ├─ ent_ids_2
328 | │ │ ├─ ill_ent_ids
329 | │ │ ├─ rel_ids
330 | │ │ ├─ training_attrs_1
331 | │ │ ├─ training_attrs_2
332 | │ │ ├─ triples_1
333 | │ │ └─ triples_2
334 | │ ├─ OEA_EN_DE_15K_V2
335 | │ │ ├─ 721_5fold
336 | │ │ │ ├─ 1
337 | │ │ │ │ ├─ test_links
338 | │ │ │ │ ├─ train_links
339 | │ │ │ │ └─ valid_links
340 | │ │ │ ├─ 2
341 | │ │ │ │ ├─ test_links
342 | │ │ │ │ ├─ train_links
343 | │ │ │ │ └─ valid_links
344 | │ │ │ ├─ 3
345 | │ │ │ │ ├─ test_links
346 | │ │ │ │ ├─ train_links
347 | │ │ │ │ └─ valid_links
348 | │ │ │ ├─ 4
349 | │ │ │ │ ├─ test_links
350 | │ │ │ │ ├─ train_links
351 | │ │ │ │ └─ valid_links
352 | │ │ │ └─ 5
353 | │ │ │ ├─ test_links
354 | │ │ │ ├─ train_links
355 | │ │ │ └─ valid_links
356 | │ │ ├─ attr_triples_1
357 | │ │ ├─ attr_triples_2
358 | │ │ ├─ ent_ids_1
359 | │ │ ├─ ent_ids_2
360 | │ │ ├─ ent_links
361 | │ │ ├─ ill_ent_ids
362 | │ │ ├─ rel_ids
363 | │ │ ├─ rel_triples_1
364 | │ │ ├─ rel_triples_2
365 | │ │ ├─ triples_1
366 | │ │ └─ triples_2
367 | │ ├─ OEA_EN_FR_15K_V1
368 | │ │ ├─ att_features.npy
369 | │ │ ├─ att_rel_features.npy
370 | │ │ ├─ att_val_features.npy
371 | │ │ ├─ attr_triples_1
372 | │ │ ├─ attr_triples_2
373 | │ │ ├─ ent_ids_1
374 | │ │ ├─ ent_ids_2
375 | │ │ ├─ ill_ent_ids
376 | │ │ ├─ rel_ids
377 | │ │ ├─ training_attrs_1
378 | │ │ ├─ training_attrs_2
379 | │ │ ├─ triples_1
380 | │ │ └─ triples_2
381 | │ ├─ OEA_EN_FR_15K_V2
382 | │ │ ├─ 721_5fold
383 | │ │ │ ├─ 1
384 | │ │ │ │ ├─ test_links
385 | │ │ │ │ ├─ train_links
386 | │ │ │ │ └─ valid_links
387 | │ │ │ ├─ 2
388 | │ │ │ │ ├─ test_links
389 | │ │ │ │ ├─ train_links
390 | │ │ │ │ └─ valid_links
391 | │ │ │ ├─ 3
392 | │ │ │ │ ├─ test_links
393 | │ │ │ │ ├─ train_links
394 | │ │ │ │ └─ valid_links
395 | │ │ │ ├─ 4
396 | │ │ │ │ ├─ test_links
397 | │ │ │ │ ├─ train_links
398 | │ │ │ │ └─ valid_links
399 | │ │ │ └─ 5
400 | │ │ │ ├─ test_links
401 | │ │ │ ├─ train_links
402 | │ │ │ └─ valid_links
403 | │ │ ├─ attr_triples_1
404 | │ │ ├─ attr_triples_2
405 | │ │ ├─ ent_ids_1
406 | │ │ ├─ ent_ids_2
407 | │ │ ├─ ent_links
408 | │ │ ├─ ill_ent_ids
409 | │ │ ├─ rel_ids
410 | │ │ ├─ rel_triples_1
411 | │ │ ├─ rel_triples_2
412 | │ │ ├─ triples_1
413 | │ │ └─ triples_2
414 | │ ├─ pkl
415 | │ │ ├─ OEA_D_W_15K_V1_id_img_feature_dict.pkl
416 | │ │ ├─ OEA_D_W_15K_V2_id_img_feature_dict.pkl
417 | │ │ ├─ OEA_EN_DE_15K_V1_id_img_feature_dict.pkl
418 | │ │ └─ OEA_EN_FR_15K_V1_id_img_feature_dict.pkl
419 | │ └─ data.py
420 | ├─ dump
421 | ├─ embedding
422 | │ ├─ dbp_fr_en_char.pkl
423 | │ ├─ dbp_fr_en_name.pkl
424 | │ ├─ dbp_ja_en_char.pkl
425 | │ ├─ dbp_ja_en_name.pkl
426 | │ ├─ dbp_zh_en_char.pkl
427 | │ ├─ dbp_zh_en_name.pkl
428 | │ └─ glove.6B.300d.txt
429 | └─ pkls
430 | ├─ FBDB15K_id_img_feature_dict.pkl
431 | ├─ FBYG15K_id_img_feature_dict.pkl
432 | ├─ dbpedia_wikidata_15k_dense_GA_id_img_feature_dict.pkl
433 | ├─ dbpedia_wikidata_15k_norm_GA_id_img_feature_dict.pkl
434 | ├─ fr_en_GA_id_img_feature_dict.pkl
435 | ├─ ja_en_GA_id_img_feature_dict.pkl
436 | └─ zh_en_GA_id_img_feature_dict.pkl
437 | ```
438 |
439 |
440 |
441 | ## 🤝 Cite:
442 |
443 | Please consider citing this paper if you use the ```code``` or ```data``` from our work.
444 | Thanks a lot :)
445 | ```
446 | @article{DBLP:journals/corr/abs-2402-11000,
447 | author = {Yangyifei Luo and
448 | Zhuo Chen and
449 | Lingbing Guo and
450 | Qian Li and
451 | Wenxuan Zeng and
452 | Zhixin Cai and
453 | Jianxin Li},
454 | title = {{ASGEA:} Exploiting Logic Rules from Align-Subgraphs for Entity Alignment},
455 | journal = {CoRR},
456 | volume = {abs/2402.11000},
457 | year = {2024}
458 | }
459 | ```
460 |
--------------------------------------------------------------------------------
/base_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import time
4 | from tqdm import tqdm
5 | from torch.optim import Adam
6 | from torch.optim.lr_scheduler import ExponentialLR
7 | from models import MASGNN
8 | from utils import cal_ranks, cal_performance
9 |
class BaseModel(object):
    """Training/evaluation driver wrapping the MASGNN entity-alignment model.

    Owns the optimizer, LR scheduler and the batching loops; all data
    access goes through ``loader``. Assumes a CUDA device is available.
    """

    def __init__(self, args, loader):
        # Build the model and move it to the GPU (single-device setup).
        self.model = MASGNN(args, loader)
        self.model.cuda()

        self.loader = loader
        self.n_ent = loader.n_ent            # total entity count
        self.n_batch = args.n_batch          # batch size
        self.n_rel = loader.n_rel            # relation count (single direction)
        self.left_ents = loader.left_ents    # entity ids of KG1
        self.right_ents = loader.right_ents  # entity ids of KG2
        self.shuffle = args.shuffle          # reshuffle train pairs each epoch

        self.n_train = loader.n_train        # number of training alignment pairs
        # self.n_valid = loader.n_valid
        self.n_test = loader.n_test          # number of test alignment pairs
        self.n_layer = args.n_layer          # propagation depth

        self.optimizer = Adam(self.model.parameters(), lr=args.lr, weight_decay=args.lamb)
        self.scheduler = ExponentialLR(self.optimizer, args.decay_rate)
        self.t_time = 0                      # cumulative training wall-clock time

    def train_batch(self,):
        """Run one full training epoch, then evaluate on the test split.

        Returns
        -------
        (mrr, h@1, h@3, h@5, h@10, out_str) as produced by ``evaluate``.
        """
        epoch_loss = 0
        i = 0

        batch_size = self.n_batch
        # ceil(n_train / batch_size) without importing math
        n_batch = self.n_train // batch_size + (self.n_train % batch_size > 0)
        if self.shuffle:
            self.loader.shuffle_train()

        t_time = time.time()
        self.model.train()
        for i in tqdm(range(n_batch)):
            start = i*batch_size
            end = min(self.n_train, (i+1)*batch_size)
            batch_idx = np.arange(start, end)
            triple = self.loader.get_batch(batch_idx)

            self.model.zero_grad()
            # Score all candidate entities for each source entity in the batch.
            scores = self.model(triple[:,0])

            # Softmax cross-entropy in numerically stable log-sum-exp form;
            # triple[:,2] holds the gold target entity per row.
            pos_scores = scores[[torch.arange(len(scores)).cuda(),torch.LongTensor(triple[:,2]).cuda()]]
            max_n = torch.max(scores, 1, keepdim=True)[0]
            loss = torch.sum(- pos_scores + max_n + torch.log(torch.sum(torch.exp(scores - max_n),1)))
            # gamma = 0.1
            # lambd = 1
            # tau = 1
            # max_n = torch.max(scores, 1, keepdim=True)[0]
            # scores = max_n - scores
            # pos_scores = scores[[torch.arange(len(scores)).cuda(), torch.LongTensor(triple[:, 2]).cuda()]]
            # # extend pos_scores to scores
            # pos_scores = pos_scores.unsqueeze(-1)
            # l = gamma + pos_scores - scores
            # ln = (l - l.mean(dim=-1, keepdim=True).detach()) / l.std(dim=-1, keepdim=True).detach()
            # # ln = (l - mu) / torch.sqrt(sig + 1e-6)
            # loss = torch.sum(torch.log(1 + torch.sum(torch.exp(lambd * ln + tau), 1)))

            loss.backward()
            self.optimizer.step()

            # avoid NaN: any parameter entry with x != x (NaN) is replaced
            # by one fresh random scalar so training can continue
            for p in self.model.parameters():
                X = p.data.clone()
                flag = X != X
                X[flag] = np.random.random()
                p.data.copy_(X)
            epoch_loss += loss.item()
        self.scheduler.step()
        self.t_time += time.time() - t_time

        t_mrr,t_h1, t_h3, t_h5, t_h10, out_str = self.evaluate()
        return t_mrr,t_h1, t_h3, t_h5, t_h10, out_str

    def evaluate(self, ):
        """Score the test split and compute ranking metrics.

        Returns
        -------
        (mrr, h@1, h@3, h@5, h@10, out_str) where out_str is a formatted
        one-line summary including inference time.
        """
        batch_size = self.n_batch
        i_time = time.time()
        n_data = self.n_test
        n_batch = n_data // batch_size + (n_data % batch_size > 0)
        ranking = []
        self.model.eval()
        # NOTE(review): no torch.no_grad() here; scores are detached via
        # .data below, but wrapping the loop would also save activations.
        for i in range(n_batch):
            start = i*batch_size
            end = min(n_data, (i+1)*batch_size)
            batch_idx = np.arange(start, end)
            triple = self.loader.get_batch(batch_idx, data='test')
            subs, rels, objs = triple[:,0],triple[:,1],triple[:,2]
            # presumably rel id n_rel*2+1 marks left->right queries -- set up
            # by the loader; verify against load_data.py
            is_lefts = rels == self.n_rel*2+1
            scores = self.model(subs,'test').data.cpu().numpy()

            ranks = cal_ranks(scores, objs, is_lefts, len(self.left_ents))
            ranking += ranks
        ranking = np.array(ranking)
        t_mrr, t_h1, t_h3, t_h5, t_h10 = cal_performance(ranking)
        i_time = time.time() - i_time

        out_str = '[TEST] MRR:%.4f H@1:%.4f H@3:%.4f H@5:%.4f H@10:%.4f \t[TIME] inference:%.4f\n' % (t_mrr, t_h1, t_h3, t_h5, t_h10, i_time)
        return t_mrr,t_h1, t_h3, t_h5, t_h10, out_str
108 |
109 |
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import json
4 | import numpy as np
5 | import pdb
6 | import torch.distributed as dist
7 | import os
8 | import os.path as osp
9 | from collections import Counter
10 | import pickle
11 | import torch.nn.functional as F
12 | from transformers import BertTokenizer
13 | import torch.distributed
14 | from tqdm import tqdm
15 | import re
16 |
17 | from utils import get_topk_indices, get_adjr
18 |
19 |
class EADataset(torch.utils.data.Dataset):
    """Minimal map-style dataset over an in-memory sequence of EA samples."""

    def __init__(self, data):
        """Wrap *data* (any indexable sequence) without copying it."""
        self.data = data

    def __len__(self):
        """Number of stored samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Fetch the sample at position *index*."""
        return self.data[index]
29 |
30 |
class Collator_base(object):
    """Default collate function: stacks a list of samples into one ndarray."""

    def __init__(self, args):
        """Keep *args* around; the collation itself is argument-free."""
        self.args = args

    def __call__(self, batch):
        """Convert the sample list *batch* into a numpy array."""
        # pdb.set_trace()

        return np.array(batch)
39 |
40 |
41 | # def load_data(logger, args):
42 | # assert args.data_choice in ["DWY", "DBP15K", "FBYG15K", "FBDB15K"]
43 | # if args.data_choice in ["DWY", "DBP15K", "FBYG15K", "FBDB15K"]:
44 | # KGs, non_train, train_ill, test_ill, eval_ill, test_ill_ = load_eva_data(logger, args)
45 | #
46 | # elif args.data_choice in ["FBYG15K_attr", "FBDB15K_attr"]:
47 | # pass
48 | #
49 | # return KGs, non_train, train_ill, test_ill, eval_ill, test_ill_
50 | #
51 |
52 |
def load_eva_data(args):
    """Load one EA benchmark (DBP15K / FBDB15K / FBYG15K / OpenEA) end to end.

    Reads id mappings, relation triples and alignment links, optionally
    loads image features (``args.mm``) and induces visual seed links, then
    splits the gold links into train/test by ``args.data_rate``.

    Returns
    -------
    (kg_dict, non_train, left_ents, right_ents, train_ill, test_ill,
     eval_ill, test_ill_) -- kg_dict bundles counts, feature tensors and
    id->name/relation maps; non_train lists entities absent from train.
    """
    # data_choice like "OEA_*" points directly at an OpenEA sub-folder;
    # otherwise the layout is <data_path>/<data_choice>/<data_split>.
    if "OEA" in args.data_choice:
        file_dir = osp.join(args.data_path, "OpenEA", args.data_choice)
    else:
        file_dir = osp.join(args.data_path, args.data_choice, args.data_split)
    lang_list = [1, 2]
    ent2id_dict, ills, triples, r_hs, r_ts, ids = read_raw_data(file_dir, lang_list)
    e1 = os.path.join(file_dir, 'ent_ids_1')
    e2 = os.path.join(file_dir, 'ent_ids_2')
    left_ents,left_id2name = get_ids(e1,file_dir)
    right_ents,right_id2name = get_ids(e2,file_dir)
    id2name = {**left_id2name, **right_id2name}
    # Relation-name map: FB* datasets ship id2relation.txt, OpenEA ships
    # rel_ids; DBP15K has no readable relation names.
    if not args.data_choice == "DBP15K" and not args.data_choice == "OpenEA":
        id2rel = get_id2rel(os.path.join(file_dir, 'id2relation.txt'))
    elif args.data_choice == "OpenEA":
        id2rel = get_id2rel(os.path.join(file_dir, 'rel_ids'))
    else:
        id2rel = None
    ENT_NUM = len(ent2id_dict)
    REL_NUM = len(r_hs)
    # Shuffle gold links in place before the train/test split below.
    np.random.shuffle(ills)
    if args.mm:
        # Pick the pre-extracted image-feature pickle for this dataset.
        if args.data_choice == "OpenEA":
            img_vec_path = osp.join(args.data_path, f"OpenEA/pkl/{args.data_split}_id_img_feature_dict.pkl")
        elif "FB" in file_dir:
            img_vec_path = osp.join(args.data_path, f"pkls/{args.data_choice}_id_img_feature_dict.pkl")
        else:
            # fr_en
            split = file_dir.split("/")[-1]  # NOTE(review): unused local
            img_vec_path = osp.join(args.data_path, "pkls", args.data_split + "_GA_id_img_feature_dict.pkl")

        assert osp.exists(img_vec_path)
        img_features = load_img(ENT_NUM, img_vec_path)
        print(f"image feature shape:{img_features.shape}")

        # NOTE(review): word2vec_path is computed but only consumed by the
        # commented-out name/char feature block below.
        if args.word_embedding == "glove":
            word2vec_path = os.path.join(args.data_path, "embedding", "glove.6B.300d.txt")
        elif args.word_embedding == 'bert':
            pass
        else:
            raise Exception("error word embedding")
    else:
        img_features = None

    name_features = None
    char_features = None
    # if args.data_choice == "DBP15K" and (args.w_name or args.w_char):

    # assert osp.exists(word2vec_path)
    # ent_vec, char_features = load_word_char_features(ENT_NUM, word2vec_path, args)
    # name_features = F.normalize(torch.Tensor(ent_vec))
    # char_features = F.normalize(torch.Tensor(char_features))
    # print(f"name feature shape:{name_features.shape}")
    # print(f"char feature shape:{char_features.shape}")
    img_ill = None
    if args.mm:
        # L2-normalize image features and induce visual pseudo seed links.
        input_features = F.normalize(torch.Tensor(img_features))
        img_ill = visual_pivot_induction(args, left_ents, right_ents, input_features, ills)

    # First data_rate fraction of the shuffled gold links trains; the rest tests.
    train_ill = np.array(ills[:int(len(ills) // 1 * args.data_rate)], dtype=np.int32)

    test_ill_ = ills[int(len(ills) // 1 * args.data_rate):]
    test_ill = np.array(test_ill_, dtype=np.int32)

    # NOTE(review): test_left/test_right are computed but never used here.
    test_left = torch.LongTensor(test_ill[:, 0].squeeze())
    test_right = torch.LongTensor(test_ill[:, 1].squeeze())

    left_non_train = list(set(left_ents) - set(train_ill[:, 0].tolist()))

    right_non_train = list(set(right_ents) - set(train_ill[:, 1].tolist()))

    print(f"#left entity : {len(left_ents)}, #right entity: {len(right_ents)}")
    print(f"#left entity not in train set: {len(left_non_train)}, #right entity not in train set: {len(right_non_train)}")

    rel_features = load_relation(ENT_NUM, triples, 1000)
    print(f"relation feature shape:{rel_features.shape}")
    # Attribute triples live in dataset-specific files/formats.
    if 'OpenEA' in args.data_choice:
        a1 = os.path.join(file_dir, f'attr_triples_1')
        a2 = os.path.join(file_dir, f'attr_triples_2')
        att_features, num_att_left, num_att_right = load_attr_withNums(['oea', 'oea'], [a1, a2], ent2id_dict, file_dir,
                                                                       topk=args.topk)
    elif 'FB' in args.data_choice:
        a1 = os.path.join(file_dir, 'FB15K_NumericalTriples.txt')
        a2 = os.path.join(file_dir, 'DB15K_NumericalTriples.txt') if 'DB' in args.data_choice else os.path.join(file_dir, 'YAGO15K_NumericalTriples.txt')
        att_features, num_att_left, num_att_right = load_attr_withNums(['FB15K','DB15K'] if 'DB' in args.data_choice else ['FB15K','YAGO15K'],[a1, a2], ent2id_dict, file_dir, topk=0)
    else:
        # DBP15K: attribute files are named by language pair, e.g. zh_att_triples.
        att1,att2 = args.data_split.split('_')
        a1 = os.path.join(file_dir, f'{att1}_att_triples')
        a2 = os.path.join(file_dir, f'{att2}_att_triples')
        att_features, num_att_left, num_att_right = load_attr_withNums([att1,att2],[a1, a2], ent2id_dict, file_dir, topk=args.topk)
    print(f"attribute feature shape:{len(att_features)}")
    print("-----dataset summary-----")
    print(f"dataset:\t\t {file_dir}")
    print(f"triple num:\t {len(triples)}")
    print(f"entity num:\t {ENT_NUM}")
    print(f"relation num:\t {REL_NUM}")
    print(f"train ill num:\t {train_ill.shape[0]} \t test ill num:\t {test_ill.shape[0]}")
    print("-------------------------")

    eval_ill = None
    input_idx = torch.LongTensor(np.arange(ENT_NUM))

    # pdb.set_trace()
    # train_ill = EADataset(train_ill)
    # test_ill = EADataset(test_ill)

    return {
        'ent_num': ENT_NUM,
        'rel_num': REL_NUM,
        'images_list': img_features,
        'rel_features': rel_features,
        'att_features': att_features,
        'num_att_left': num_att_left,
        'num_att_right': num_att_right,
        'name_features': name_features,
        'char_features': char_features,
        'input_idx': input_idx,
        'triples': triples,
        'id2name':id2name,
        'id2rel':id2rel,
        'img_ill':img_ill
    }, {"left": left_non_train, "right": right_non_train},left_ents,right_ents, train_ill, test_ill, eval_ill, test_ill_
175 |
176 |
def load_word2vec(path, dim=300):
    """Load a glove/fasttext-style text embedding file.

    Parameters
    ----------
    path : str
        Embedding file; each line is ``word v1 ... v_dim``.
    dim : int
        Expected vector dimensionality; lines with a different number of
        fields are silently skipped.

    Returns
    -------
    dict mapping lower-cased word -> np.ndarray of shape (dim,).
    """
    word2vec = dict()
    err_num = 0
    err_list = []

    with open(path, 'r', encoding='utf-8') as file:
        # Stream line by line instead of readlines(): these files can be
        # hundreds of MB and need not be materialized at once.
        for line in tqdm(file, desc="load word embedding"):
            line = line.strip('\n').split(' ')
            if len(line) != dim + 1:
                continue
            try:
                v = np.array(list(map(float, line[1:])), dtype=np.float64)
                word2vec[line[0].lower()] = v
            except ValueError:
                # Malformed numeric field: record the word for the report
                # (was a bare except, which also swallowed KeyboardInterrupt).
                err_num += 1
                err_list.append(line[0])
                continue
    print("err list ", err_list)
    print("err num ", err_num)
    return word2vec
202 |
203 |
def load_char_bigram(path):
    """Build the character-bigram vocabulary of translated entity names.

    Parameters
    ----------
    path : str
        JSON file holding a list of ``[ent_id, word_list]`` pairs.

    Returns
    -------
    (ent_names, char2id) : the parsed name list and a dict assigning a
    dense id to every distinct lower-cased 2-character window, in order
    of first appearance.
    """
    # load the translated entity names; use a context manager so the file
    # handle is closed deterministically (it was previously leaked)
    with open(path, "r") as f:
        ent_names = json.load(f)
    # generate the bigram dictionary
    char2id = {}
    count = 0
    for _, name in ent_names:
        for word in name:
            word = word.lower()
            for idx in range(len(word) - 1):
                bigram = word[idx:idx + 2]
                if bigram not in char2id:
                    char2id[bigram] = count
                    count += 1
    return ent_names, char2id
221 |
222 |
def load_word_char_features(node_size, word2vec_path, args):
    """Build (or load cached) word-level and char-bigram entity features.

    Parameters
    ----------
    node_size : total entity count (rows of both feature matrices).
    word2vec_path : path to a glove-style embedding file.
    args : needs ``data_path`` and ``data_split``.

    Returns
    -------
    (ent_vec, char_vec) : np arrays of shape (node_size, 300) and
    (node_size, n_bigrams), row-normalized; results are pickled so later
    calls return the cache.
    """
    name_path = os.path.join(args.data_path, "DBP15K", "translated_ent_name", "dbp_" + args.data_split + ".json")
    assert osp.exists(name_path)
    save_path_name = os.path.join(args.data_path, "embedding", f"dbp_{args.data_split}_name.pkl")
    save_path_char = os.path.join(args.data_path, "embedding", f"dbp_{args.data_split}_char.pkl")
    # Fast path: both caches exist, skip the expensive embedding load.
    if osp.exists(save_path_name) and osp.exists(save_path_char):
        print(f"load entity name emb from {save_path_name} ... ")
        ent_vec = pickle.load(open(save_path_name, "rb"))
        print(f"load entity char emb from {save_path_char} ... ")
        char_vec = pickle.load(open(save_path_char, "rb"))
        return ent_vec, char_vec

    word_vecs = load_word2vec(word2vec_path)
    ent_names, char2id = load_char_bigram(name_path)

    # generate the word-level features and char-level features

    ent_vec = np.zeros((node_size, 300))
    char_vec = np.zeros((node_size, len(char2id)))
    for i, name in ent_names:
        k = 0
        # Average the word vectors of the entity's (lower-cased) name and
        # count its character bigrams.
        for word in name:
            word = word.lower()
            if word in word_vecs:
                ent_vec[i] += word_vecs[word]
                k += 1
            for idx in range(len(word) - 1):
                char_vec[i, char2id[word[idx:idx + 2]]] += 1
        if k:
            ent_vec[i] /= k
        else:
            # No word had an embedding: fall back to random noise in [-0.5, 0.5).
            ent_vec[i] = np.random.random(300) - 0.5

        if np.sum(char_vec[i]) == 0:
            char_vec[i] = np.random.random(len(char2id)) - 0.5
        # L2-normalize both rows.
        ent_vec[i] = ent_vec[i] / np.linalg.norm(ent_vec[i])
        char_vec[i] = char_vec[i] / np.linalg.norm(char_vec[i])

    with open(save_path_name, 'wb') as f:
        pickle.dump(ent_vec, f)
    with open(save_path_char, 'wb') as f:
        pickle.dump(char_vec, f)
    print("save entity emb done. ")
    return ent_vec, char_vec
270 |
271 |
def visual_pivot_induction(args, left_ents, right_ents, img_features, ills):
    """Induce pseudo alignment seeds from image-feature similarity.

    Computes pairwise similarity between every left/right image embedding
    (features are expected to be normalized by the caller), then greedily
    keeps the ``args.img_ill_k`` best-scoring one-to-one matches.

    Parameters
    ----------
    args : needs ``img_ill_k``, the number of seed links to keep.
    left_ents, right_ents : lists of global entity ids.
    img_features : tensor of per-entity image embeddings.
    ills : gold links; only used to report the precision of the seeds.

    Returns
    -------
    np.ndarray of shape (k, 2) with the induced (left, right) id pairs.
    """
    l_img_f = img_features[left_ents]  # left images
    r_img_f = img_features[right_ents]  # right images

    img_sim = l_img_f.mm(r_img_f.t())
    topk = args.img_ill_k
    # Over-fetch candidates so the greedy one-to-one filter can still fill topk.
    two_d_indices = get_topk_indices(img_sim, topk * 100)
    del l_img_f, r_img_f, img_sim

    visual_links = []
    used_inds = set()  # set instead of list: O(1) membership vs O(n) per check
    count = 0
    for ind in two_d_indices:
        # Enforce a one-to-one matching greedily, best scores first.
        if left_ents[ind[0]] in used_inds:
            continue
        if right_ents[ind[1]] in used_inds:
            continue
        used_inds.add(left_ents[ind[0]])
        used_inds.add(right_ents[ind[1]])
        visual_links.append((left_ents[ind[0]], right_ents[ind[1]]))
        count += 1
        if count == topk:
            break

    # Report how many induced links coincide with the gold alignment;
    # hashing the gold pairs avoids a linear scan per link.
    true_links = set(ills)
    count = 0.0
    for link in visual_links:
        if link in true_links:
            count = count + 1
    print(f"{(count / len(visual_links) * 100):.2f}% in true links")
    print(f"visual links length: {(len(visual_links))}")
    train_ill = np.array(visual_links, dtype=np.int32)
    return train_ill
305 |
306 |
def read_raw_data(file_dir, lang=[1, 2]):
    """
    Read DBP15k/DWY15k dataset.
    Parameters
    ----------
    file_dir: root of the dataset.
    lang: suffixes of the per-KG files (read-only; the default is never mutated).
    Returns
    -------
    ent2id_dict : A dict mapping from entity name to ids
    ills: inter-lingual links (specified by ids)
    triples: a list of tuples (ent_id_1, relation_id, ent_id_2)
    r_hs: a dictionary containing mappings of relations to a list of entities that are head entities of the relation
    r_ts: a dictionary containing mappings of relations to a list of entities that are tail entities of the relation
    ids: all ids as a list
    """
    print('loading raw data...')

    def read_file(file_paths):
        # Each line is a tab-separated tuple of ints (e.g. "h\tr\tt").
        tups = []
        for file_path in file_paths:
            with open(file_path, "r", encoding="utf-8") as fr:
                for line in fr:
                    params = line.strip("\n").split("\t")
                    tups.append(tuple(int(x) for x in params))
        return tups

    def read_dict(file_paths):
        # Lines are "<int id>\t<entity name>"; one id set per input file.
        ent2id_dict = {}
        ids = []
        for file_path in file_paths:
            kg_ids = set()  # renamed from `id` to stop shadowing the builtin
            with open(file_path, "r", encoding="utf-8") as fr:
                for line in fr:
                    params = line.strip("\n").split("\t")
                    ent2id_dict[params[1]] = int(params[0])
                    kg_ids.add(int(params[0]))
            ids.append(kg_ids)
        return ent2id_dict, ids

    ent2id_dict, ids = read_dict([file_dir + "/ent_ids_" + str(i) for i in lang])
    ills = read_file([file_dir + "/ill_ent_ids"])
    triples = read_file([file_dir + "/triples_" + str(i) for i in lang])
    # Group head and tail entities by relation.
    r_hs, r_ts = {}, {}
    for (h, r, t) in triples:
        r_hs.setdefault(r, set()).add(h)
        r_ts.setdefault(r, set()).add(t)
    assert len(r_hs) == len(r_ts)
    return ent2id_dict, ills, triples, r_hs, r_ts, ids
358 |
359 |
def loadfile(fn, num=1):
    """Read a tab-separated file and return the first *num* integer columns
    of each line as a list of tuples."""
    print('loading a file...' + fn)
    rows = []
    with open(fn, encoding='utf-8') as fin:
        for raw in fin:
            fields = raw[:-1].split('\t')
            rows.append(tuple(int(fields[col]) for col in range(num)))
    return rows
371 |
372 |
373 | def get_ids(fn,file_dir):
374 | ids = []
375 | id2name = {}
376 | fbid2name = {}
377 | if 'FB' in fn:
378 | with open(os.path.join(file_dir, 'fbid2name.txt'), encoding='utf-8') as f:
379 | for line in f:
380 | th = line[:-1].split('\t')
381 | fbid2name[th[0]] = th[1]
382 | with open(fn, encoding='utf-8') as f:
383 | for line in f:
384 | th = line[:-1].split('\t')
385 | ids.append(int(th[0]))
386 | name = th[1]
387 | if ''==s[-1]:
436 | s = s[1:-1]
437 | t = s.split('/')[-1].replace('_',' ')
438 | t_ = ' '.join(split_camel_case(t))
439 | if t_ == '':
440 | return t
441 | return t_
442 |
443 |
def dbp_value(s):
    """Normalise a DBpedia-style literal string.

    Strips typed-literal ("...^^<type>") and language-tag ("..."@lang)
    markup plus surrounding <...> / "..." quotes. If the remainder looks
    like a Y-M-D date, returns it as a fractional year (float); otherwise
    returns the cleaned string. NOTE(review): assumes s is non-empty —
    several branches index s[0]/s[-1] without a length check.
    """
    # print(s)
    if '^^' in s:
        # Typed literal, e.g. "1984-06-02"^^<xsd:date>: keep the lexical part.
        s = s.split("^^")[0]
        if ('<' == s[0] and '>' == s[-1]) or ('\"' == s[0] and '\"' == s[-1]):
            s = s[1:-1]
    elif '@' in s and s.index('@')>0:
        # Language-tagged literal, e.g. "Paris"@fr: drop the trailing tag.
        s = '@'.join(s.split('@')[:-1])
        if ('<' == s[0] and '>' == s[-1]) or ('\"' == s[0] and '\"' == s[-1]):
            s = s[1:-1]
        # print(s)
        # A stray closing quote can survive the split; strip it.
        if s[-1]=='\"':
            s = s[:-1]
    else:
        # Plain literal: strip surrounding brackets/quotes and return as-is
        # (no date conversion is attempted on this branch).
        if ('<' == s[0] and '>' == s[-1]) or ('\"' == s[0] and '\"' == s[-1]):
            s = s[1:-1]
        return s
    # Values containing the letter 'e' (e.g. scientific notation, words)
    # are returned verbatim — presumably to avoid misparsing as a date;
    # TODO confirm intent.
    if 'e' in s:
        return s

    # No '-' past a possible leading minus sign => not a Y-M-D date.
    if '-' not in s[1:]:
        return s
    try:
        # '#' is a placeholder digit: '####' year digits become 0,
        # '##' month/day default to 1. Result is a fractional year.
        s_ = s.split('-')
        y = int(s_[0].replace('#','0'))
        m = int(s_[1]) if s_[1]!='##'else 1
        d = int(s_[2]) if s_[2]!='##' and s_[2]!='' else 1
        return y + (m-1)/12 +(d-1)/30/12
    except:
        # Anything that fails to parse as a date is returned unchanged.
        return s
474 |
475 |
476 |
def load_attr_withNums(datas, fns, ent2id_dict, file_dir, topk=0):
    """
    Load the attribute triples of both KGs via load_attr_withNum.

    Parameters
    ----------
    datas: the two dataset names (e.g. 'FB15K', 'DB15K'), one per KG.
    fns: the two attribute files, matched with datas.
    ent2id_dict: entity name -> id mapping, passed through.
    file_dir: unused here; kept for interface compatibility.
    topk: if non-zero, keep only the topk most frequent attribute names
        that occur in BOTH KGs (falling back to all names when the
        intersection is empty).

    Returns
    -------
    (attributes, num_left, num_right): the concatenated left+right
    (ent_id, attribute_name, value) triples and the size of each part.
    """
    ans = [load_attr_withNum(data, fn, ent2id_dict) for data, fn in zip(datas, fns)]
    if topk != 0:
        # Attribute names in first-seen order plus their global frequency.
        rels = []
        rels2times = {}
        for att in ans[0] + ans[1]:
            if att[1] not in rels2times:
                rels.append(att[1])
                rels2times[att[1]] = 0
            rels2times[att[1]] += 1

        rels_left = {att[1] for att in ans[0]}
        rels_right = {att[1] for att in ans[1]}
        rels_inter = rels_left.intersection(rels_right)
        if len(rels_inter) == 0:
            # No attribute name shared by both KGs: fall back to all names.
            rels_inter = rels
        # Keep only the topk most frequent candidate names.
        rels_inter = sorted(rels_inter, key=lambda x: rels2times[x], reverse=True)[:topk]

        # Filter both sides, left first so num_left/num_right stay correct.
        ans_ = [att for att in ans[0] if att[1] in rels_inter]
        num_left = len(ans_)
        ans_ += [att for att in ans[1] if att[1] in rels_inter]
        num_right = len(ans_) - num_left
        return ans_, num_left, num_right

    return ans[0] + ans[1], len(ans[0]), len(ans[1])
def load_attr_withNum(data, fn, ent2id):
    """
    Read one attribute file and normalise each line into an
    (ent_id, attribute_name, value) triple.

    data: dataset name, selects the per-dataset line format
        ('FB15K', 'DB15K', 'YAGO15K', 'oea', or anything else for the
        default DBP-style format).
    fn: path of the attribute file.
    ent2id: entity name -> id mapping.
    """
    with open(fn, 'r', encoding='utf-8') as f:
        Numericals = f.readlines()
    if data == 'FB15K' or data == 'DB15K' or data == 'YAGO15K':
        # De-duplicate while keeping first-occurrence order.
        # dict.fromkeys is O(n); the previous set() + sort(key=list.index)
        # was O(n^2) with identical output.
        Numericals = list(dict.fromkeys(Numericals))

    if data == 'FB15K':
        Numericals = [i[:-1].split('\t') for i in Numericals]
        # Strip <...>, drop the freebase namespace, keep the last dotted
        # segment as the attribute name, e.g. ".../people.person.height" -> "height".
        Numericals = [(ent2id[i[0]], i[1][1:-1].replace('http://rdf.freebase.com/ns/', '').split('.')[-1].replace('_', ' '), i[2]) for i in
                      Numericals]
    elif data == 'DB15K':
        # Lines may be space- or tab-separated depending on the dump.
        Numericals = [i[:-1].split(' ') if '\t' not in i else i[:-1].split('\t') for i in Numericals]
        Numericals = [(ent2id[i[0]], db_str(i[1]), db_time(i[2])) for i in Numericals]
    elif data == 'YAGO15K':
        Numericals = [i[:-1].split(' ') if '\t' not in i else i[:-1].split('\t') for i in Numericals]
        Numericals = [(ent2id[i[0]], db_str(i[1]), db_time(i[2])) for i in Numericals]
    elif data == 'oea':
        Numericals = [i[:-1].split('\t') for i in Numericals]
        Numericals = [(ent2id[i[0]], dbp_str(i[1]), dbp_value(i[2])) for i in Numericals]
    else:
        # Default DBP-style: entity is wrapped in <...>; the value may
        # contain spaces, so re-join everything past the second field.
        Numericals = [i[:-1].split(' ') if '\t' not in i else i[:-1].split('\t') for i in Numericals]
        Numericals = [(ent2id[i[0][1:-1]], dbp_str(i[1]), dbp_value(' '.join(i[2:]))) for i in Numericals]

    return Numericals
619 |
620 |
# The most frequent attributes are selected to save space
def load_attr(fns, e, ent2id, topA=1000):
    """
    Build a binary entity-attribute matrix over the most frequent attributes.

    fns: attribute files, each line "<entity>\\t<attr_1>\\t...\\t<attr_k>".
    e: total number of entities (rows of the output).
    ent2id: entity name -> row index; lines with unknown entities are skipped.
    topA: maximum number of attribute columns. Bug fix: previously the
        parameter was ignored and 1000 was always used.

    Returns a float32 array of shape (e, min(topA, #distinct attrs)) with
    attr[ent][a] == 1.0 iff the entity has attribute a.
    """
    cnt = {}
    for fn in fns:
        with open(fn, 'r', encoding='utf-8') as f:
            for line in f:
                th = line[:-1].split('\t')
                if th[0] not in ent2id:
                    continue
                for i in range(1, len(th)):
                    if th[i] not in cnt:
                        cnt[th[i]] = 1
                    else:
                        cnt[th[i]] += 1
    # Attributes ranked by frequency, most common first.
    fre = [(k, cnt[k]) for k in sorted(cnt, key=cnt.get, reverse=True)]
    attr2id = {}
    # Honour the caller-supplied budget (was hard-coded min(1000, ...)).
    topA = min(topA, len(fre))
    for i in range(topA):
        attr2id[fre[i][0]] = i
    attr = np.zeros((e, topA), dtype=np.float32)
    for fn in fns:
        with open(fn, 'r', encoding='utf-8') as f:
            for line in f:
                th = line[:-1].split('\t')
                if th[0] in ent2id:
                    for i in range(1, len(th)):
                        if th[i] in attr2id:
                            attr[ent2id[th[0]]][attr2id[th[i]]] = 1.0
    return attr
651 |
652 |
def load_relation(e, KG, topR=1000):
    """
    Build an entity-relation count matrix over the topR most frequent relations.

    e: number of entities (rows). KG: iterable of (h, r, t) triples.
    Returns a float32 array of shape (e, topR); cell [ent, j] counts how
    often ent appears as head or tail of the j-th most frequent relation.
    """
    # (39654, 1000)
    rel_mat = np.zeros((e, topR), dtype=np.float32)
    rels = np.array(KG)[:, 1]
    top_rels = Counter(rels).most_common(topR)
    rel_index_dict = {r: i for i, (r, cnt) in enumerate(top_rels)}
    for h, r, t in KG:
        if r in rel_index_dict:
            rel_mat[h][rel_index_dict[r]] += 1.
            rel_mat[t][rel_index_dict[r]] += 1.
    # Fix: return directly instead of the former redundant np.array() copy.
    return rel_mat
667 |
668 |
def load_json_embd(path):
    """Read a JSON-lines file and return {guid (int): feature vector (np.ndarray)}."""
    embd_dict = {}
    with open(path) as fin:
        for raw in fin:
            record = json.loads(raw.strip())
            # 'feature' is a whitespace-separated string of floats.
            embd_dict[int(record['guid'])] = np.array(record['feature'].split(), dtype=float)
    return embd_dict
677 |
678 |
def load_img(e_num, path):
    """
    Load entity image embeddings from a pickled {ent_id: vector} dict.

    e_num: total number of entities; path: pickle file location.
    Entities without an image get a random vector drawn from the
    per-dimension mean/std of the known embeddings.
    Returns an array of shape (e_num, dim).
    """
    # Fix: the file handle was previously leaked (pickle.load(open(...))).
    with open(path, "rb") as fin:
        img_dict = pickle.load(fin)
    # init unknown img vector with mean and std deviation of the known's
    imgs_np = np.array(list(img_dict.values()))
    mean = np.mean(imgs_np, axis=0)
    std = np.std(imgs_np, axis=0)
    # img_embd = np.array([np.zeros_like(img_dict[0]) for i in range(e_num)]) # no image
    # img_embd = np.array([img_dict[i] if i in img_dict else np.zeros_like(img_dict[0]) for i in range(e_num)])

    img_embd = np.array([img_dict[i] if i in img_dict else np.random.normal(mean, std, mean.shape[0]) for i in range(e_num)])
    print(f"{(100 * len(img_dict) / e_num):.2f}% entities have images")
    return img_embd
691 |
--------------------------------------------------------------------------------
/load_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | import torch
5 | from scipy.sparse import csr_matrix
6 | import numpy as np
7 | from collections import defaultdict
8 | from data import load_eva_data
9 | import pickle
10 | from tqdm import tqdm
11 | # import lmdb
# import lmdb
class DataLoader:
    """Entity-alignment data loader.

    Wraps the output of load_eva_data: re-indexes entities so the left KG
    occupies ids [0, len(left_ents)) and the right KG follows contiguously,
    builds inverse-augmented fact/train/test triple sets as CUDA LongTensors,
    and extracts layer-wise subgraphs for training and testing.

    Relation-id conventions used throughout (n_rel = #original relations):
      r + n_rel      : inverse of relation r
      2 * n_rel      : identity self-loop added per entity
      2 * n_rel + 1  : alignment link (and +2 its inverse)
      2 * n_rel + 3  : image-similarity link (and +4 its inverse)
    """
    def __init__(self, args):
        """Load the dataset selected by args, remap ids, and build the
        CUDA-resident graphs. Requires a CUDA device."""

        # non_train and test_ill_ are returned but not used here.
        KGs, non_train, left_ents, right_ents, train_ill, test_ill, eval_ill, test_ill_ = load_eva_data(args)
        ent_num = KGs['ent_num']
        rel_num = KGs['rel_num']
        self.img_ill = KGs['img_ill']
        self.use_img_ill = args.use_img_ill
        self.images_list = KGs['images_list']
        self.rel_features = KGs['rel_features']
        self.att_features = KGs['att_features']
        self.num_att_left = KGs['num_att_left']
        self.num_att_right = KGs['num_att_right']
        self.id2name = KGs['id2name']
        self.id2rel = KGs['id2rel']
        # New id space: left KG -> [0, len(left)), right KG follows.
        self.left_ents = [i for i in range(len(left_ents))]
        self.right_ents = [len(left_ents) + i for i in range(len(right_ents))]
        old_ids = np.array(left_ents+right_ents)
        # new_ids = torch.arange(len(self.left_ents+self.right_ents))
        # old2new = torch.zeros(len(self.left_ents+self.right_ents)).long()
        # old2new[old_ids] = new_ids
        # self.old2new = old2new
        self.old_ids = old_ids
        if args.mm:
            # Reorder image features to match the new id space.
            self.images_list = self.images_list[self.old_ids]
        self.old2new_dict = {oldid:newid for newid,oldid in enumerate(left_ents+right_ents)}
        # Remap every triple / alignment pair into the new id space.
        triples = KGs['triples']
        triples = [(self.old2new_dict[tri[0]],tri[1],self.old2new_dict[tri[2]]) for tri in triples]
        train_ill = np.array([(self.old2new_dict[tri[0]],self.old2new_dict[tri[1]]) for tri in train_ill])
        test_ill = np.array([(self.old2new_dict[tri[0]],self.old2new_dict[tri[1]]) for tri in test_ill])
        if args.mm:
            self.img_ill = np.array([(self.old2new_dict[tri[0]],self.old2new_dict[tri[1]]) for tri in self.img_ill])


        # self.att_features_text = np.array(KGs['att_features'])
        self.att2rel ,self.rels = self.process_rels(self.att_features)
        self.att_ids = [self.old2new_dict[i[0]] for i in self.att_features]

        # ids_att: new entity id -> indices of its attribute triples.
        self.ids_att = {}
        for att_index,ids in enumerate(self.att_ids):
            if ids not in self.ids_att:
                self.ids_att[ids] = []
            self.ids_att[ids].append(att_index)
        # self.test_cache_url = os.path.join(args.data_path, args.data_choice, args.data_split, f'test_{args.data_rate}')
        # self.test_cache = {}

        # Attribute embeddings are expensive to compute; cache them on disk
        # (one cache per topk value when topk != 0).
        if args.mm:
            if args.topk == 0:
                if os.path.exists(os.path.join(args.data_path, args.data_choice, args.data_split, 'att_features.npy')):
                    self.att_features = np.load(os.path.join(args.data_path, args.data_choice, args.data_split, 'att_features.npy'), allow_pickle=True)
                    self.att_rel_features = np.load(os.path.join(args.data_path, args.data_choice, args.data_split, 'att_rel_features.npy'), allow_pickle=True)
                    self.att_val_features = np.load(os.path.join(args.data_path, args.data_choice, args.data_split, 'att_val_features.npy'), allow_pickle=True)
                else:
                    self.att_features, self.att_rel_features,self.att_val_features = self.bert_feature()
                    np.save(os.path.join(args.data_path, args.data_choice, args.data_split, 'att_features.npy'), self.att_features)
                    np.save(os.path.join(args.data_path, args.data_choice, args.data_split, 'att_rel_features.npy'), self.att_rel_features)
                    np.save(os.path.join(args.data_path, args.data_choice, args.data_split, 'att_val_features.npy'), self.att_val_features)
            else:
                if os.path.exists(os.path.join(args.data_path, args.data_choice, args.data_split, f'att_features{args.topk}.npy')):
                    self.att_features = np.load(os.path.join(args.data_path, args.data_choice, args.data_split, f'att_features{args.topk}.npy'), allow_pickle=True)
                    self.att_rel_features = np.load(os.path.join(args.data_path, args.data_choice, args.data_split, f'att_rel_features{args.topk}.npy'), allow_pickle=True)
                    self.att_val_features = np.load(os.path.join(args.data_path, args.data_choice, args.data_split, f'att_val_features{args.topk}.npy'), allow_pickle=True)
                else:
                    self.att_features, self.att_rel_features,self.att_val_features = self.bert_feature()
                    np.save(os.path.join(args.data_path, args.data_choice, args.data_split, f'att_features{args.topk}.npy'), self.att_features)
                    np.save(os.path.join(args.data_path, args.data_choice, args.data_split, f'att_rel_features{args.topk}.npy'), self.att_rel_features)
                    np.save(os.path.join(args.data_path, args.data_choice, args.data_split, f'att_val_features{args.topk}.npy'), self.att_val_features)
            # for i1,i2 in train_ill:
            #     f1 = self.att_features[np.array(self.att_ids)==i1]
            #     f2 = self.att_features[np.array(self.att_ids)==i2]
            #     print('-'*30)
            #     print('1',self.att_features_text[np.array(self.att_ids)==i1])
            #     print('2',self.att_features_text[np.array(self.att_ids)==i2])

            #     for f1i in f1:
            #         for f2i in f2:
            #             print(f1i.dot(f2i))
            #     f1 = self.att_rel_features[self.att2rel[np.array(self.att_ids)==i1]]
            #     f2 = self.att_rel_features[self.att2rel[np.array(self.att_ids)==i2]]
            #     print()
            #     for f1i in f1:
            #         for f2i in f2:
            #             print(f1i.dot(f2i))

            #     f1 = self.att_val_features[np.array(self.att_ids)==i1]
            #     f2 = self.att_val_features[np.array(self.att_ids)==i2]
            #     print()
            #     for f1i in f1:
            #         for f2i in f2:
            #             print(f1i.dot(f2i))





        self.n_ent = ent_num
        self.n_rel = rel_num

        self.filters = defaultdict(lambda: set())

        self.fact_triple = triples

        # Alignment pairs become triples over the dedicated alignment relation.
        self.train_triple = self.ill2triples(train_ill)
        self.valid_triple = eval_ill # None
        self.test_triple = self.ill2triples(test_ill)

        # add inverse
        self.fact_data = self.double_triple(self.fact_triple)
        # self.train_data = np.array(self.double_triple(self.train_triple))
        # self.valid_data = self.double_triple(self.valid_triple)
        self.test_data = self.double_triple(self.test_triple, ill=True)
        self.test_data = np.array(self.test_data)
        self.train_data = self.double_triple(self.train_triple, ill=True)
        self.train_data = np.array(self.train_data)
        if self.use_img_ill:
            self.img_ill_triple = self.img_ill2triples(self.img_ill)
            self.img_ill_triple = self.double_triple(self.img_ill_triple, ill=True)
            self.img_ill_triple = np.array(self.img_ill_triple)
            self.img_ill_data = torch.LongTensor(self.img_ill_triple).cuda()

        # self.KG,self.M_sub = self.load_graph(self.fact_data) # do it in shuffle_train
        # tKG: test-time graph = facts + all train alignment links (+ self-loops).
        self.tKG = self.load_graph(self.fact_data + self.double_triple(self.train_triple, ill=True))
        self.tKG = torch.LongTensor(self.tKG).cuda()

        # in torch
        # One identity self-loop (relation id 2*n_rel) per entity.
        idd = np.concatenate([np.expand_dims(np.arange(self.n_ent), 1), 2 * self.n_rel * np.ones((self.n_ent, 1)),
                              np.expand_dims(np.arange(self.n_ent), 1)], 1)
        self.fact_data = np.concatenate([np.array(self.fact_data), idd], 0)
        self.fact_data = torch.LongTensor(self.fact_data).cuda()
        # self.node2index = {}
        # for i, triple in enumerate(self.train_triple):
        #     h, r, t = triple
        #     assert h not in self.node2index
        #     assert t not in self.node2index
        #     self.node2index[h] = i
        #     self.node2index[t] = i
        # self.train_triple = torch.LongTensor(self.train_triple).cuda()


        self.n_test = len(self.test_data)
        self.n_train = len(self.train_data)
        # Builds self.KG and the support/query split used for training.
        self.shuffle_train()

        # if os.path.exists(self.test_cache_url):
        #     self.test_env = lmdb.open(self.test_cache_url)
        # else:
        #     self.test_env = lmdb.open(self.test_cache_url, map_size=200*1024 * 1024 * 1024, max_dbs=1)
        #     self.preprocess_test()
    def process_rels(self, atts):
        """Index attribute names: returns (att2rel, rels) where att2rel[i] is
        the dense index of the i-th attribute triple's name and rels lists the
        unique names in first-seen order."""
        rels = []
        rels2index = {}
        cur = 0
        att2rel = []
        for i,att in enumerate(atts):
            if att[1] not in rels2index:
                rels2index[att[1]] = cur
                rels.append(att[1])
                cur += 1
            att2rel.append(rels2index[att[1]])
        return np.array(att2rel),rels



    def bert_feature(self, ):
        """Encode attribute texts with LaBSE sentence embeddings.

        Returns three numpy arrays: embeddings of "name value" strings,
        of the unique attribute names (self.rels), and of the raw values.
        Requires sentence_transformers/transformers and a CUDA device.
        """
        from sentence_transformers import SentenceTransformer
        from transformers import BertTokenizer, BertModel
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        # model = BertModel.from_pretrained("bert-base-uncased").cuda()
        # model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2').cuda()
        model = SentenceTransformer('sentence-transformers/LaBSE').cuda()

        outputs = []
        texts = [a + ' ' + str(v) for i,a,v in self.att_features]
        batch_size = 2048
        sent_batch = [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]
        for sent in sent_batch:

            # encoded_input = tokenizer(sent, return_tensors='pt', padding=True, truncation=True, max_length=512)
            # #cuda
            # encoded_input.data['input_ids'] = encoded_input.data['input_ids'].cuda()
            # encoded_input.data['attention_mask'] = encoded_input.data['attention_mask'].cuda()
            # encoded_input.data['token_type_ids'] = encoded_input.data['token_type_ids'].cuda()
            with torch.no_grad():
                # output = model(**encoded_input)
                output = model.encode(sent)
            outputs.append(output)
        outputs = np.concatenate(outputs)

        # batch_size = 512
        sent_batch = [self.rels[i:i + batch_size] for i in range(0, len(self.rels), batch_size)]
        rel_outputs = []
        for sent in sent_batch:
            # encoded_input = tokenizer(sent, return_tensors='pt', padding=True, truncation=True, max_length=512)
            # #cuda
            # encoded_input.data['input_ids'] = encoded_input.data['input_ids'].cuda()
            # encoded_input.data['attention_mask'] = encoded_input.data['attention_mask'].cuda()
            # encoded_input.data['token_type_ids'] = encoded_input.data['token_type_ids'].cuda()
            with torch.no_grad():
                # output = model(**encoded_input)
                output = model.encode(sent)
            rel_outputs.append(output)
        rel_outputs = np.concatenate(rel_outputs)

        vals = [str(i[2]) for i in self.att_features]
        # batch_size = 512
        sent_batch = [vals[i:i + batch_size] for i in range(0, len(vals), batch_size)]
        val_outputs = []
        for sent in sent_batch:
            # encoded_input = tokenizer(sent, return_tensors='pt', padding=True, truncation=True, max_length=512)
            # #cuda
            # encoded_input.data['input_ids'] = encoded_input.data['input_ids'].cuda()
            # encoded_input.data['attention_mask'] = encoded_input.data['attention_mask'].cuda()
            # encoded_input.data['token_type_ids'] = encoded_input.data['token_type_ids'].cuda()
            with torch.no_grad():
                # output = model(**encoded_input)
                output = model.encode(sent)
            val_outputs.append(output)
        val_outputs = np.concatenate(val_outputs)
        del model
        return outputs, rel_outputs, val_outputs



    def ill2triples(self, ill):
        """Turn alignment pairs into triples over the alignment relation id 2*n_rel+1."""
        return [(i[0], self.n_rel * 2 + 1, i[1]) for i in ill]

    def img_ill2triples(self, ill):
        """Turn image-similarity pairs into triples over relation id 2*n_rel+3."""
        return [(i[0], self.n_rel * 2 + 3, i[1]) for i in ill]

    # def read_triples(self, filename):
    #     triples = []
    #     with open(os.path.join(self.task_dir, filename)) as f:
    #         for line in f:
    #             h, r, t = line.strip().split()
    #             h, r, t = self.entity2id[h], self.relation2id[r], self.entity2id[t]
    #             triples.append([h, r, t])
    #             self.filters[(h, r)].add(t)
    #             self.filters[(t, r + self.n_rel)].add(h)
    #     return triples

    def double_triple(self, triples, ill=False):
        """Append the inverse of every triple. For KG triples the inverse
        relation is r + n_rel; for alignment ('ill') triples it is r + 1."""
        new_triples = []
        for triple in triples:
            h, r, t = triple
            new_triples.append([t, r + self.n_rel if not ill else r+1, h])
        return triples + new_triples

    def load_graph(self, triples):
        """Append one identity self-loop edge (relation id 2*n_rel) per entity
        and return the edges as a single (n_edges, 3) numpy array."""
        idd = np.concatenate([np.expand_dims(np.arange(self.n_ent), 1), 2 * self.n_rel * np.ones((self.n_ent, 1)),
                              np.expand_dims(np.arange(self.n_ent), 1)], 1)

        KG = np.concatenate([np.array(triples), idd], 0)
        # n_fact = len(KG)
        # M_sub = csr_matrix((np.ones((n_fact,)), (np.arange(n_fact), KG[:, 0])),
        #                    shape=(n_fact, self.n_ent))
        return KG


    def get_subgraphs(self, head_nodes, layer=3,mode='train',sim=None):
        """Extract per-head subgraphs, then merge them layer by layer into
        batched edge tensors with per-layer node re-indexing.

        Returns (all_nodes, layer_edges, old_nodes_new_idxs, old_nodes); each
        layer's edges are rows (batch_idx, head, rel, tail, head_idx, tail_idx).
        """
        all_edges = []
        for index,head_node in enumerate(head_nodes):
            all_edge = self.get_subgraph(head_node, index, layer, mode,sim=sim)
            all_edges.append(all_edge)
        all_nodes = []
        layer_edges = []
        old_nodes_new_idxs = []
        old_nodes = []
        for i in range(layer):
            # Concatenate layer i of every sample in the batch.
            edges = []
            for j in range(len(all_edges)):
                edges.append(all_edges[j][i])
            edges = torch.cat(edges, dim=0)
            edges = edges.long()

            # NOTE: head_nodes (the parameter) is re-bound here on purpose.
            head_nodes, head_index = torch.unique(edges[:, [0, 1]], dim=0, sorted=True, return_inverse=True)
            tail_nodes, tail_index = torch.unique(edges[:, [0, 3]], dim=0, sorted=True, return_inverse=True)
            sampled_edges = torch.cat([edges, head_index.unsqueeze(1), tail_index.unsqueeze(1)], 1)


            # Self-loop edges (relation 2*n_rel) mark nodes carried over
            # from the previous layer.
            mask = sampled_edges[:, 2] == (self.n_rel * 2)
            old_node, old_idx = head_index[mask].sort()
            old_nodes_new_idx = tail_index[mask][old_idx]
            all_nodes.append(tail_nodes)
            layer_edges.append(sampled_edges)
            old_nodes_new_idxs.append(old_nodes_new_idx)
            old_nodes.append(old_node)


        return all_nodes, layer_edges, old_nodes_new_idxs, old_nodes
    #
    def get_subgraph(self, head_node, index, layer, mode, max_size=500, sim=None):
        """Extract the edges of a `layer`-hop subgraph around head_node.

        Forward pass: BFS from head_node collecting per-layer edge masks.
        Backward pass: keep only edges on paths that end at candidate tail
        entities in the OTHER KG. Returns one (k_i, 4) tensor per layer with
        rows (index, head, rel, tail). max_size is currently unused.
        """
        if mode == 'train':
            # # set false to self.node2index[node]
            # mask = torch.ones(len(self.train_triple), dtype=torch.bool).cuda()
            # mask[self.node2index[head_node.item()]] = False
            # support = self.train_triple[mask]
            # reverse_support = support[:, [2, 1, 0]]
            # reverse_support[:, 1] += 1
            # support = torch.cat((support, reverse_support), dim=0)
            # KG = torch.cat((support,self.fact_data),dim=0)
            KG=self.KG
        else:
            KG = self.tKG
        if sim is not None:
            # sim: extra edges appended to the graph — presumably similarity
            # links shaped like triples; TODO confirm with caller.
            KG = torch.cat((KG, sim), dim=0)
        if self.use_img_ill:
            KG = torch.cat((KG, self.img_ill_data), dim=0)
        row, col = KG[:, 0], KG[:, 2]
        node_mask = row.new_empty(self.n_ent, dtype=torch.bool)
        # edge_mask = row.new_empty(row.size(0), dtype=torch.bool)
        subsets = [torch.LongTensor([head_node]).cuda()]
        raw_layer_edges = []
        # Forward: expand the frontier one hop per layer.
        for i in range(layer):
            node_mask.fill_(False)
            node_mask[subsets[-1]] = True
            edge_mask = torch.index_select(node_mask, 0, row)
            subsets.append(torch.unique(col[edge_mask]))
            raw_layer_edges.append(edge_mask)
        # nodes, edges, old_nodes_new_idx = self.get_neighbors(nodes.data.cpu().numpy())
        # delete target not in the other KG
        tail_node = self.left_ents if head_node.item() >= len(self.left_ents) else self.right_ents
        tail_node = torch.LongTensor(tail_node).cuda()
        node_mask_ = row.new_empty(self.n_ent, dtype=torch.bool)
        node_mask_.fill_(False)
        node_mask_[tail_node] = True
        tail_set = subsets[-1]
        node_mask.fill_(False)
        node_mask[tail_set] = True
        # Valid endpoints: last-hop nodes that live in the other KG.
        node_mask = node_mask & node_mask_
        layer_edges = []
        # Backward: keep edges whose tail can still reach a valid endpoint.
        for i in reversed(range(layer)):
            edge_mask = torch.index_select(node_mask, 0, col)
            edge_mask = edge_mask & raw_layer_edges[i]
            node_mask_.fill_(False)
            node_mask_[row[edge_mask]] = True
            node_mask = node_mask | node_mask_
            layer_edges.append(KG[edge_mask])
        layer_edges = layer_edges[::-1]
        batched_edges = []
        for i in range(layer):
            layer_edges[i] = torch.unique(layer_edges[i], dim=0)
            batched_edges.append(torch.cat([torch.ones(len(layer_edges[i])).unsqueeze(1).cuda() * index, layer_edges[i]], 1))
        return batched_edges

    def get_vis_subgraph(self, head_node, tail_node, layer, max_size=500, sim=None):
        """Like get_subgraph, but prunes towards one specific tail_node
        (for visualization) and returns un-batched per-layer edge tensors."""

        KG = self.tKG
        if sim is not None:
            KG = torch.cat((KG, sim), dim=0)
        row, col = KG[:, 0], KG[:, 2]
        node_mask = row.new_empty(self.n_ent, dtype=torch.bool)
        # edge_mask = row.new_empty(row.size(0), dtype=torch.bool)
        subsets = [torch.LongTensor([head_node]).cuda()]
        raw_layer_edges = []
        for i in range(layer):
            node_mask.fill_(False)
            node_mask[subsets[-1]] = True
            edge_mask = torch.index_select(node_mask, 0, row)
            subsets.append(torch.unique(col[edge_mask]))
            raw_layer_edges.append(edge_mask)
        # nodes, edges, old_nodes_new_idx = self.get_neighbors(nodes.data.cpu().numpy())
        # delete target not in the other KG
        # tail_node = self.left_ents if head_node.item() >= len(self.left_ents) else self.right_ents
        tail_node = torch.LongTensor([tail_node]).cuda()
        node_mask_ = row.new_empty(self.n_ent, dtype=torch.bool)
        node_mask_.fill_(False)
        node_mask_[tail_node] = True
        tail_set = subsets[-1]
        node_mask.fill_(False)
        node_mask[tail_set] = True
        node_mask = node_mask & node_mask_
        layer_edges = []
        for i in reversed(range(layer)):
            edge_mask = torch.index_select(node_mask, 0, col)
            edge_mask = edge_mask & raw_layer_edges[i]
            node_mask_.fill_(False)
            node_mask_[row[edge_mask]] = True
            node_mask = node_mask | node_mask_
            layer_edges.append(KG[edge_mask])
        layer_edges = layer_edges[::-1]
        batched_edges = []
        for i in range(layer):
            layer_edges[i] = torch.unique(layer_edges[i], dim=0)
            batched_edges.append(layer_edges[i])
        return batched_edges

    # def get_neighbors(self, nodes, mode='train', n_hop=0):
    #     if mode == 'train':
    #         KG = self.KG
    #         M_sub = self.M_sub
    #     else:
    #         KG = self.tKG
    #         M_sub = self.tM_sub
    #     # if self.test_cache
    #
    #     # nodes: n_node x 2 with (batch_idx, node_idx)
    #     node_1hot = csr_matrix((np.ones(len(nodes)), (nodes[:, 1], nodes[:, 0])), shape=(self.n_ent, nodes.shape[0])) # (n_ent, batch_size)
    #     edge_1hot = M_sub.dot(node_1hot)
    #     edges = np.nonzero(edge_1hot)
    #     sampled_edges = np.concatenate([np.expand_dims(edges[1], 1), KG[edges[0]]],
    #                                    axis=1) # (batch_idx, head, rela, tail)
    #     sampled_edges = torch.LongTensor(sampled_edges).cuda()
    #
    #     # index to nodes
    #     head_nodes, head_index = torch.unique(sampled_edges[:, [0, 1]], dim=0, sorted=True, return_inverse=True)
    #     tail_nodes, tail_index = torch.unique(sampled_edges[:, [0, 3]], dim=0, sorted=True, return_inverse=True)
    #
    #     sampled_edges = torch.cat([sampled_edges, head_index.unsqueeze(1), tail_index.unsqueeze(1)], 1)
    #
    #     mask = sampled_edges[:, 2] == (self.n_rel * 2)
    #     _, old_idx = head_index[mask].sort()
    #     old_nodes_new_idx = tail_index[mask][old_idx]
    #
    #     return tail_nodes, sampled_edges, old_nodes_new_idx

    # def get_neighbor(self, node, mode='train', n_hop=0):
    #     if mode == 'train':
    #         # set false to self.node2index[node]
    #         mask = torch.ones(len(self.train_triple), dtype=torch.bool)
    #         mask[self.node2index[node]] = False
    #         KG = torch.cat(self.train_triple[mask],self.fact_data)
    #
    #     else:
    #         KG = self.tKG
    #     # if self.test_cache
    #
    #     # nodes: n_node x 2 with (batch_idx, node_idx)
    #     # node_1hot = csr_matrix((np.ones(len(nodes)), (nodes[:, 1], nodes[:, 0])), shape=(self.n_ent, nodes.shape[0])) # (n_ent, batch_size)
    #     # edge_1hot = M_sub.dot(node_1hot)
    #     edges = KG[:, 0]==node
    #     edges = np.nonzero(edges)
    #     sampled_edges = KG[edges[0]] # (head, rela, tail)
    #     sampled_edges = torch.LongTensor(sampled_edges).cuda()
    #
    #     # index to nodes
    #     head_nodes, head_index = torch.unique(sampled_edges[:, 1], dim=0, sorted=True, return_inverse=True)
    #     tail_nodes, tail_index = torch.unique(sampled_edges[:, 3], dim=0, sorted=True, return_inverse=True)
    #
    #     sampled_edges = torch.cat([sampled_edges, head_index.unsqueeze(1), tail_index.unsqueeze(1)], 1)
    #
    #     # mask = sampled_edges[:, 2] == (self.n_rel * 2)
    #     # _, old_idx = head_index[mask].sort()
    #     # old_nodes_new_idx = tail_index[mask][old_idx]
    #
    #     return tail_nodes, sampled_edges

    def get_batch(self, batch_idx, steps=2, data='train'):
        """Return the rows of the selected split at batch_idx.
        'valid' has no data and returns None; steps is unused."""
        if data == 'train':
            return self.train_data[batch_idx]
        if data == 'valid':
            return None
        if data == 'test':
            return self.test_data[batch_idx]

        # subs = []
        # rels = []
        # objs = []
        #
        # subs = query[batch_idx, 0]
        # rels = query[batch_idx, 1]
        # objs = np.zeros((len(batch_idx), self.n_ent))
        # for i in range(len(batch_idx)):
        #     objs[i][answer[batch_idx[i]]] = 1
        # return subs, rels, objs

    def shuffle_train(self, ):
        """Reshuffle the train alignment triples into a 3/4 support (graph
        edges, stored in self.KG) / 1/4 query (supervision, self.train_data)
        split. Called once per epoch; updates self.n_train."""
        # fact_triple = np.array(self.fact_triple)
        # train_triple = np.array(self.train_triple)
        # all_triple = np.concatenate([fact_triple, train_triple], axis=0)
        # n_all = len(all_triple)
        # rand_idx = np.random.permutation(n_all)
        # all_triple = all_triple[rand_idx]

        # random shuffle train_triples
        random.shuffle(self.train_triple)
        # support/query split 3/1
        support_triple = self.train_triple[:len(self.train_triple) * 3 // 4]
        query_triple = self.train_triple[len(self.train_triple) * 3 // 4:]
        # add inverse triples
        support_triple = self.double_triple(support_triple, ill=True)
        query_triple = self.double_triple(query_triple, ill=True)
        support = torch.LongTensor(support_triple).cuda()
        self.KG = torch.cat((support,self.fact_data),dim=0)
        # now the fact triples are fact_triple + support_triple
        # self.KG, self.M_sub = self.load_graph(self.fact_data + support_triple)
        self.n_train = len(query_triple)
        self.train_data = np.array(query_triple)

        # # increase the ratio of fact_data, e.g., 3/4->4/5, can increase the performance
        # self.fact_data = self.double_triple(all_triple[:n_all * 3 // 4].tolist())
        # self.train_data = np.array(self.double_triple(all_triple[n_all * 3 // 4:].tolist()))
        # self.n_train = len(self.train_data)
        # self.KG,self.M_sub = self.load_graph(self.fact_data)

        print('n_train:', self.n_train, 'n_test:', self.n_test)

    def preprocess_test(self, ):
        """Precompute and cache test-time neighborhoods in LMDB.
        NOTE(review): relies on self.test_env and self.get_neighbors, both
        currently commented out — dead code unless the lmdb path is re-enabled.
        """
        batch_size = 4
        n_data = self.n_test
        n_batch = n_data // batch_size + (n_data % batch_size > 0)
        for i in tqdm(range(n_batch)):
            start = i * batch_size
            end = min(n_data, (i + 1) * batch_size)
            batch_idx = np.arange(start, end)
            triple = self.get_batch(batch_idx, data='test')
            subs, rels, objs = triple[:, 0], triple[:, 1], triple[:, 2]
            print(subs, rels, objs)
            n = len(subs)
            q_sub = torch.LongTensor(subs).cuda()
            nodes = torch.cat([torch.arange(n).unsqueeze(1).cuda(), q_sub.unsqueeze(1)], 1)
            for h in range(5):
                nodes, edges, old_nodes_new_idx = self.get_neighbors(nodes.data.cpu().numpy(), mode='test',
                                                                     n_hop=h)
                # to np
                # self.test_cache[(i, h)] = (nodes.cpu().numpy(), edges.cpu().numpy(), old_nodes_new_idx.cpu().numpy())
                # use lmdb write
                with self.test_env.begin(write=True) as txn:
                    txn.put(f'{i}_{h}'.encode(), pickle.dumps((nodes.cpu().numpy(), edges.cpu().numpy(), old_nodes_new_idx.cpu().numpy())))
        # pickle.dump(self.test_cache, open(self.test_cache_url, 'wb'))

    def get_test_cache(self, batch_idx, h):
        """Fetch the cached (nodes, edges, old_nodes_new_idx) written by
        preprocess_test for hop h of test batch batch_idx.
        NOTE(review): requires self.test_env (currently commented out)."""
        #use lmdb read
        with self.test_env.begin(write=False) as txn:
            nodes, edges, old_nodes_new_idx = pickle.loads(txn.get(f'{batch_idx}_{h}'.encode()))
        return nodes, edges, old_nodes_new_idx
        # return self.test_cache[(batch_idx, h)]


    # def save_cache(self):
    #     with open(self.cache_path, 'wb') as f:
    #         pickle.dump(self.edge_cache, f)
    #
    # def load_cache(self):
    #     with open(self.cache_path, 'rb') as f:
    #         self.edge_cache = pickle.load(f)
    #         print("load cache from {}".format(self.cache_path))
548 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch_scatter import scatter
4 | import torch.nn.functional as F
5 | from torch_geometric.utils import softmax
6 | from torch_geometric.nn.models import MLP
class Text_enc(nn.Module):
    """Attribute-text encoder.

    Aggregates per-attribute text features into one vector per entity,
    weighting each attribute by an attention score computed from its
    relation embedding.
    NOTE(review): self.W (2*text_dim -> text_dim) is declared but never
    applied in forward(); the output stays 2*text_dim wide.
    """

    def __init__(self, params):
        super().__init__()
        self.hidden_dim = params.text_dim
        self.u = nn.Linear(params.text_dim, 1)
        self.W = nn.Linear(2 * params.text_dim, params.text_dim)

    def forward(self, ent_num, Textid, Text, Text_rel):
        # Concatenate each attribute's relation and value embeddings.
        rel_and_val = torch.cat((Text_rel, Text), -1)
        # One attention logit per attribute, softmax-normalized over the
        # attributes that belong to the same entity (grouped by Textid).
        logits = self.u(Text_rel)
        weights = softmax(logits, Textid, None, ent_num)
        # Weighted sum of attribute vectors per entity.
        return scatter(weights * rel_and_val, index=Textid, dim=0,
                       dim_size=ent_num, reduce='sum')
23 |
24 |
25 | # class FeatureMapping(nn.Module):
26 | # def __init__(self, params):
27 | # super().__init__()
28 | # self.params = params
29 | # self.in_dims = {'Stru': params.stru_dim, 'Text': params.text_dim, 'IMG': params.hidden_dim,
30 | # 'Temporal': params.time_dim, 'Numerical': params.time_dim}
31 | # self.out_dim = params.hidden_dim
32 | # modals = ['Stru', 'Text', 'IMG', 'Temporal', 'Numerical']
33 | # self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
34 | # if self.device == 'cuda':
35 |
36 | # self.W_list = {
37 | # modal: MLP(in_channels=self.in_dims[modal], out_channels=self.out_dim,
38 | # hidden_channels=params.MLP_hidden_dim, num_layers=params.MLP_num_layers,
39 | # dropout=params.MLP_dropout, norm=None).cuda() for modal in modals
40 | # }
41 | # else:
42 | # self.W_list = {
43 | # modal: MLP(in_channels=self.in_dims[modal], out_channels=self.out_dim,
44 | # hidden_channels=params.MLP_hidden_dim, num_layers=params.MLP_num_layers,
45 | # dropout=params.MLP_dropout, norm=None) for modal in modals
46 | # }
47 | # self.W_list = nn.ModuleDict(self.W_list)
48 |
49 | # def forward(self, features):
50 | # new_features = {}
51 | # modals = ['Text']
52 |
53 | # for modal, feature in features.items():
54 | # if modal not in modals:
55 | # continue
56 | # # print(modal,feature.device)
57 | # new_features[modal] = self.W_list[modal](feature)
58 | # mean_feature = torch.mean(torch.stack(list(new_features.values())), dim=0)
59 | # return new_features, mean_feature
60 |
61 |
class MMFeature(nn.Module):
    """Builds the per-modality entity features (image + attribute text)."""

    def __init__(self, n_ent, params):
        super().__init__()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.params = params
        self.n_ent = n_ent
        self.text_model = Text_enc(params)
        self.in_dims = {'Stru': params.stru_dim, 'Text': params.text_dim,
                        'IMG': params.img_dim}
        self.out_dim = params.hidden_dim
        # One projection MLP per modality mapped to hidden_dim (currently
        # unused in forward(), kept for the commented-out projection path).
        self.W_list = nn.ModuleDict({
            modal: MLP(in_channels=self.in_dims[modal],
                       out_channels=self.out_dim,
                       hidden_channels=params.MLP_hidden_dim,
                       num_layers=params.MLP_num_layers,
                       dropout=params.MLP_dropout,
                       norm=None).to(self.device)
            for modal in ('Text', 'IMG')
        })

    def forward(self, img_features=None, att_features=None,
                att_rel_features=None, att_ids=None):
        # Image features pass through unchanged; text features are
        # aggregated from attribute triples by the text encoder.
        features = {
            'IMG': img_features,
            'Text': self.text_model(self.n_ent, att_ids, att_features,
                                    att_rel_features),
        }
        # The fused mean feature is currently disabled.
        return features, None
89 |
90 |
class GNNLayer(torch.nn.Module):
    """One message-passing layer with relation- and KG-side-aware attention."""

    def __init__(self, in_dim, out_dim, attn_dim, n_rel, act=lambda x: x):
        super(GNNLayer, self).__init__()
        self.n_rel = n_rel
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.attn_dim = attn_dim
        self.act = act

        # 2*n_rel directed relations plus extra ids for the self-loop and
        # the alignment / inverse-alignment edge types.
        self.rela_embed = nn.Embedding(2 * n_rel + 5, in_dim)

        self.Ws_attn = nn.Linear(in_dim, attn_dim, bias=False)
        self.Wr_attn = nn.Linear(in_dim, attn_dim, bias=False)
        self.Wkg_attn = nn.Linear(2 * in_dim, attn_dim)
        self.w_alpha = nn.Linear(attn_dim, 1)

        self.W_h = nn.Linear(in_dim, out_dim, bias=False)

    def forward(self, hidden, edges, n_node, kgemb, left_num):
        # edges columns: [batch_idx, head, rela, tail, old_idx, new_idx]
        src_idx, rel_idx, dst_idx = edges[:, 4], edges[:, 2], edges[:, 5]
        src_state = hidden[src_idx]
        rel_state = self.rela_embed(rel_idx)

        # Embed which side of the alignment graph (left/right KG, split at
        # left_num) the raw head/tail entity ids belong to.
        head_ent, tail_ent = edges[:, 1], edges[:, 3]
        side = torch.cat([kgemb((head_ent >= left_num).long()),
                          kgemb((tail_ent >= left_num).long())], dim=1)

        # Attention-gated message, aggregated per destination node.
        gate = torch.sigmoid(self.w_alpha(nn.ReLU()(
            self.Ws_attn(src_state) + self.Wr_attn(rel_state) + self.Wkg_attn(side))))
        msg = gate * (src_state + rel_state)
        agg = scatter(msg, index=dst_idx, dim=0, dim_size=n_node, reduce='sum')

        return self.act(self.W_h(agg))
135 |
136 |
137 | class MASGNN(torch.nn.Module):
138 | def __init__(self, params, loader):
139 | super(MASGNN, self).__init__()
140 | self.n_layer = params.n_layer
141 | self.hidden_dim = params.hidden_dim
142 | self.attn_dim = params.attn_dim
143 | self.mm = params.mm
144 | self.n_rel = loader.n_rel
145 | self.n_ent = loader.n_ent
146 | self.loader = loader
147 | self.left_num = len(self.loader.left_ents)
148 | acts = {'relu': nn.ReLU(), 'tanh': torch.tanh, 'idd': lambda x: x}
149 | act = acts[params.act]
150 |
151 | self.gnn_layers = []
152 | for i in range(self.n_layer):
153 | self.gnn_layers.append(GNNLayer(self.hidden_dim, self.hidden_dim, self.attn_dim, self.n_rel, act=act))
154 | self.gnn_layers = nn.ModuleList(self.gnn_layers)
155 |
156 | self.dropout = nn.Dropout(params.dropout)
157 | self.W_final = nn.Linear(self.hidden_dim if self.mm else self.hidden_dim, 1, bias=False) # get score todo: try to use mlp
158 | self.gate = nn.GRU(self.hidden_dim, self.hidden_dim)
159 | self.kgemb = nn.Embedding(2, self.hidden_dim)
160 | if self.mm:
161 | self.img_features = F.normalize(torch.FloatTensor(self.loader.images_list)).cuda()
162 | self.att_features = torch.FloatTensor(self.loader.att_features).cuda()
163 | self.num_att_left = self.loader.num_att_left
164 | self.num_att_right = self.loader.num_att_right
165 | self.att_val_features = torch.FloatTensor(self.loader.att_val_features).cuda()
166 | self.att_rel_features = torch.nn.Embedding(self.loader.att_rel_features.shape[0], self.loader.att_rel_features.shape[1])
167 | self.att_rel_features.weight.data = torch.FloatTensor(self.loader.att_rel_features).cuda()
168 | self.att_ids = torch.LongTensor(self.loader.att_ids).cuda()
169 | self.ids_att = self.loader.ids_att
170 | self.ids_att = {k:torch.LongTensor(v).cuda() for k,v in self.loader.ids_att.items()}
171 | self.att2rel = torch.LongTensor(self.loader.att2rel).cuda()
172 | self.mmfeature = MMFeature(self.n_ent, params)
173 | self.textMLP = MLP(in_channels=params.hidden_dim, out_channels=1,
174 | hidden_channels=params.MLP_hidden_dim, num_layers=params.MLP_num_layers,
175 | dropout=[params.MLP_dropout]*params.MLP_num_layers, norm=None)
176 | self.textW = nn.Linear(2*params.text_dim, params.hidden_dim)
177 | self.ImgMLP = MLP(in_channels=params.img_dim, out_channels=1,
178 | hidden_channels=params.MLP_hidden_dim, num_layers=params.MLP_num_layers,
179 | dropout=[params.MLP_dropout]*params.MLP_num_layers, norm=None)
180 |
181 |
182 | def forward(self, subs, mode='train',batch_idx=None):
183 | # if self.mm:
184 | # features, mean_feature = self.mmfeature(img_features=self.img_features, att_features=self.att_val_features,
185 | # att_rel_features=self.att_rel_features(self.att2rel), att_ids=self.att_ids)
186 | # simlarity of att_rel_features use cosine shape (n_rel, n_rel)
187 | # use self.att2rel to get simlarity from rel_sim , self.att2rel shape is n_att , attention shape is (n_att, n_att)
188 | # attention = rel_sim[torch.meshgrid(self.att2rel[:self.num_att_left], self.att2rel[self.num_att_left:])]
189 | # attention_l2r = scatter(attention, index=self.att_ids[self.num_att_left:]-self.left_num, dim=1, dim_size=self.n_ent-self.left_num, reduce='sum')
190 | # attention_r2l = scatter(attention, index=self.att_ids[:self.num_att_left], dim=0, dim_size=self.left_num, reduce='sum')
191 | # alpha_l2r = softmax(attention_l2r, self.att_ids[:self.num_att_left], None, self.left_num,0)
192 | # alpha_r2l = softmax(attention_r2l, self.att_ids[self.num_att_left:]-self.left_num, None, self.n_ent-self.left_num,-1)
193 | # get att_features (n1,n2,dim)
194 |
195 |
196 |
197 |
198 |
199 | # features['IMG'] = features['IMG'] / torch.norm(features['IMG'], dim=-1, keepdim=True)
200 | # features['Text'] = features['Text'] / torch.norm(features['Text'], dim=-1, keepdim=True)
201 |
202 | # img_features = self.ImgMLP(self.img_features)
203 | # img_features = F.normalize(img_features)
204 | # sim_i = torch.mm(img_features[:self.left_num], img_features[self.left_num:].T)
205 | # sim_t = torch.mm(features['Text'][:self.left_num], features['Text'][self.left_num:].T)
206 | # sim_m = sim_i+sim_t
207 | # select sim > 0.9 index
208 | # sim = torch.nonzero(sim_m > 0.8).squeeze(1)
209 | # # add rels = (2 * n_rel + 3) and inverse rels = (2 * n_rel + 4)
210 | # sim_ = torch.cat([sim[:,[0]],torch.ones(sim.shape[0],1).long().cuda() * (2 * self.n_rel + 3), sim[:,[1]] + self.left_num], -1)
211 | # rev_sim = torch.cat([sim[:,[1]] + self.left_num,torch.ones(sim.shape[0],1).long().cuda() * (2 * self.n_rel + 4),sim[:,[0]]], -1)
212 | # sim = torch.cat([sim_, rev_sim], 0)
213 |
214 |
215 | q_sub = torch.LongTensor(subs).cuda()
216 | n = q_sub.shape[0]
217 | nodes = torch.cat([torch.arange(n).unsqueeze(1).cuda(), q_sub.unsqueeze(1)], 1)
218 | nodess, edgess, old_nodes_new_idxs,old_nodes = self.loader.get_subgraphs(q_sub, layer=self.n_layer,mode=mode,sim=None)
219 |
220 |
221 |
222 |
223 |
224 | # hidden = mean_feature[nodes[:, 1]]
225 | # h0 = mean_feature[nodes[:, 1]].unsqueeze(0)
226 | # else:
227 | h0 = torch.zeros((1, n, self.hidden_dim)).cuda()
228 | hidden = torch.zeros(n, self.hidden_dim).cuda()
229 |
230 |
231 |
232 |
233 | scores_all = []
234 | for i in range(self.n_layer):
235 | nodes = nodess[i]
236 | edges = edgess[i]
237 | old_nodes_new_idx = old_nodes_new_idxs[i]
238 | old_node = old_nodes[i]
239 | # if mode == 'train':
240 | # nodes, edges, old_nodes_new_idx = self.loader.get_neighbors(nodes.data.cpu().numpy(), mode=mode,n_hop=i)
241 | # else:
242 | # nodes, edges, old_nodes_new_idx = self.loader.get_test_cache(batch_idx,i)
243 | # # np to tensor
244 | # nodes = torch.LongTensor(nodes).cuda()
245 | # edges = torch.LongTensor(edges).cuda()
246 | # old_nodes_new_idx = torch.LongTensor(old_nodes_new_idx).cuda()
247 | # print(nodes)
248 | # print(edges)
249 | # print(old_nodes_new_idx)
250 | # print(hidden)
251 | # print(h0)
252 | hidden = self.gnn_layers[i](hidden, edges, nodes.size(0), self.kgemb, self.left_num)
253 | # print(hidden)
254 |
255 | # if self.mm:
256 | # h0 = mean_feature[nodes[:, 1]].unsqueeze(0).cuda().index_copy_(1, old_nodes_new_idx, h0[:,old_node])
257 | # else:
258 | h0 = torch.zeros(1, nodes.size(0), hidden.size(1)).cuda().index_copy_(1, old_nodes_new_idx, h0[:, old_node])
259 | hidden = self.dropout(hidden)
260 | hidden, h0 = self.gate(hidden.unsqueeze(0), h0)
261 | hidden = hidden.squeeze(0)
262 | # hidden -> (len(nodes), hidden_dim)
263 | # if self.mm:
264 | # mm_hidden = torch.cat((hidden, features['IMG'][nodes[:, 1]] - features['IMG'][q_sub[nodes[:, 0]]],
265 | # features['Text'][nodes[:, 1]] - features['Text'][q_sub[nodes[:, 0]]]), dim=-1)
266 | # scores = self.W_final(mm_hidden).squeeze(-1)
267 | # else:
268 | scores = self.W_final(hidden).squeeze(-1)
269 |
270 | scores_all = torch.zeros((len(subs), self.loader.n_ent)).cuda() # non_visited entities have 0 scores
271 | scores_all[[nodes[:, 0], nodes[:, 1]]] = scores
272 |
273 |
274 | if self.mm:
275 | source,target = torch.meshgrid(q_sub, torch.arange(self.n_ent).cuda())
276 | hidden = self.img_features[source] * self.img_features[target]
277 | b,_ = torch.meshgrid(torch.arange(n).cuda(), torch.arange(self.n_ent).cuda())
278 | img_scores = self.ImgMLP(hidden).squeeze(-1)
279 | scores_all[[b, target]] += img_scores
280 |
281 | rel_sim = torch.mm(self.att_rel_features.weight, self.att_rel_features.weight.T)
282 |
283 | for i,sub in enumerate(subs):
284 | if sub not in self.ids_att:
285 | continue
286 | if sub 0:
221 | epoch += 1
222 | mrr,t_h1, t_h3, t_h5, t_h10, out_str = model.train_batch()
223 | if args.nni:
224 | nni.report_intermediate_result({'default':mrr,'h1':t_h1,'h3':t_h3,'h5':t_h5,'h10':t_h10})
225 | with open(args.perf_file, 'a+') as f:
226 | f.write(out_str)
227 | if mrr > best_mrr:
228 | best_mrr = mrr
229 | best_h1 = t_h1
230 | best_h3 = t_h3
231 | best_h5 = t_h5
232 | best_h10 = t_h10
233 | best_str = out_str
234 | print(str(epoch) + '\t' + best_str)
235 | with open(args.perf_file,'a+') as f:
236 | f.write("best at "+ str(epoch) + '\t' + best_str)
237 | wait_patient = 10
238 | else:
239 | wait_patient -= 1
240 |
241 | if args.nni:
242 | nni.report_final_result({'default':best_mrr,'h1':best_h1,'h3':best_h3,'h5':best_h5,'h10':best_h10})
243 | print(best_str)
244 |
245 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import errno
4 | import torch
5 | import sys
6 | import logging
7 | import json
8 | from pathlib import Path
9 | import torch.optim as optim
10 | from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
11 | import torch.distributed as dist
12 | import csv
13 | import os.path as osp
14 | import time
15 | import re
16 | import pdb
17 | from torch import nn
18 | from numpy import mean
19 | import multiprocessing
20 | import math
21 | import random
22 | import numpy as np
23 | import scipy
24 | import scipy.sparse as sp
25 | from scipy.stats import rankdata
26 |
27 |
def set_optim(opt, model_list, freeze_part=(), accumulation_step=None):
    """Build an optimizer and LR scheduler over the trainable parameters.

    Parameters whose names contain any substring in `freeze_part` get
    `requires_grad = False` and are excluded from the optimizer.

    Args:
        opt: config namespace with optim / lr / weight_decay / adam_epsilon /
            scheduler / warmup_steps / total_steps / accumulation_steps.
        model_list: modules whose parameters are collected.
        freeze_part: substrings marking parameters to freeze. The default
            is a tuple to avoid the shared-mutable-default pitfall the
            previous `freeze_part=[]` had.
        accumulation_step: overrides opt.accumulation_steps when given.

    Returns:
        (optimizer, scheduler)

    Raises:
        ValueError: on an unrecognized optimizer or scheduler name
            (previously this crashed later with an unbound variable).
    """
    named_parameters = []
    param_name = []
    for model in model_list:
        model_para_train, freeze_layer = [], []
        for n, p in model.named_parameters():
            if not any(nd in n for nd in freeze_part):
                model_para_train.append((n, p))
                param_name.append(n)
            else:
                # Frozen: exclude from the optimizer and stop gradients.
                p.requires_grad = False
                freeze_layer.append((n, p))
        named_parameters.extend(model_para_train)

    parameters = [
        {'params': [p for n, p in named_parameters], "lr": opt.lr,
         'weight_decay': opt.weight_decay}
    ]

    if opt.optim == 'adamw':
        optimizer = optim.AdamW(parameters, lr=opt.lr, eps=opt.adam_epsilon)
    elif opt.optim == 'adam':
        optimizer = optim.Adam(parameters, lr=opt.lr)
    else:
        raise ValueError(f'unknown optimizer: {opt.optim}')

    if accumulation_step is None:
        accumulation_step = opt.accumulation_steps
    if opt.scheduler == 'fixed':
        scheduler = FixedScheduler(optimizer)
    elif opt.scheduler == 'linear':
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=int(opt.warmup_steps / accumulation_step),
            num_training_steps=int(opt.total_steps / accumulation_step))
    elif opt.scheduler == 'cos':
        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=int(opt.warmup_steps / accumulation_step),
            num_training_steps=int(opt.total_steps / accumulation_step))
    else:
        raise ValueError(f'unknown scheduler: {opt.scheduler}')

    return optimizer, scheduler
68 |
69 |
class FixedScheduler(torch.optim.lr_scheduler.LambdaLR):
    """LambdaLR wrapper that keeps the learning rate constant."""

    def __init__(self, optimizer, last_epoch=-1):
        super(FixedScheduler, self).__init__(optimizer, self.lr_lambda,
                                             last_epoch=last_epoch)

    def lr_lambda(self, step):
        # Constant multiplier: the LR never changes.
        return 1.0
76 |
77 |
class WarmupLinearScheduler(torch.optim.lr_scheduler.LambdaLR):
    """Linear warmup to the base LR, then linear decay toward min_ratio."""

    def __init__(self, optimizer, warmup_steps, scheduler_steps, min_ratio,
                 last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.scheduler_steps = scheduler_steps
        self.min_ratio = min_ratio
        super(WarmupLinearScheduler, self).__init__(
            optimizer, self.lr_lambda, last_epoch=last_epoch
        )

    def lr_lambda(self, step):
        # Warmup phase: ramp linearly from min_ratio up to 1.0.
        if step < self.warmup_steps:
            warm = step / float(max(1, self.warmup_steps))
            return (1 - self.min_ratio) * warm + self.min_ratio
        # Decay phase: fall linearly from 1.0 toward min_ratio, floored at 0.
        span = float(max(1.0, self.scheduler_steps - self.warmup_steps))
        decayed = 1.0 + (self.min_ratio - 1) * (step - self.warmup_steps) / span
        return max(0.0, decayed)
98 |
99 |
class Loss_log():
    """Accumulates losses and token-level accuracy statistics."""

    def __init__(self):
        # Sentinel entries so min()/mean() work before any real update.
        self.loss = [999999.]
        self.acc = [0.]
        self.flag = 0  # consecutive non-improving updates (see early_stop)
        self.token_right_num = []
        self.token_all_num = []
        self.use_top_k_acc = 0

    def acc_init(self, topn=[1]):
        """Reset counters and switch to top-k accuracy tracking."""
        self.loss = []
        self.token_right_num = []
        self.token_all_num = []
        self.topn = topn
        self.use_top_k_acc = 1
        self.top_k_word_right = {n: [] for n in topn}

    def get_token_acc(self):
        """Token accuracy; per-k percentages in top-k mode, else a ratio."""
        if not self.token_all_num:
            return 0.
        total = sum(self.token_all_num)
        if self.use_top_k_acc == 1:
            return [round(sum(self.top_k_word_right[n]) / total * 100, 3)
                    for n in self.topn]
        return [sum(self.token_right_num) / total]

    def update_token(self, token_num, token_right):
        """Record one batch: total tokens plus per-k (or scalar) correct counts."""
        self.token_all_num.append(token_num)
        if isinstance(token_right, list):
            for i, n in enumerate(self.topn):
                self.top_k_word_right[n].append(token_right[i])
        self.token_right_num.append(token_right)

    def update(self, case):
        self.loss.append(case)

    def update_acc(self, case):
        self.acc.append(case)

    def get_acc(self):
        return self.acc[-1]

    def get_min_loss(self):
        return min(self.loss)

    def get_loss(self):
        # 500 is the sentinel when nothing has been logged yet.
        return mean(self.loss) if self.loss else 500.

    def early_stop(self):
        """Count non-improving updates; signal stop after 1000 in a row."""
        if self.loss[-1] > min(self.loss):
            self.flag += 1
        else:
            self.flag = 0
        return self.flag > 1000
165 |
def torch_accuracy(output, target, topk=(1,)):
    """Top-k accuracy of `output` logits against `target` labels.

    Returns:
        percents: list of percentage tensors, one per k in `topk`.
        counts: list of raw correct-prediction counts, one per k.
    """
    kmax = max(topk)
    batch_size = output.size(0)

    # Indices of the kmax highest-scoring classes; transposed so each row
    # holds one "rank level" across the batch.
    _, pred = output.topk(kmax, 1, True, True)
    pred = pred.t()
    hits = pred.eq(target.view(1, -1).expand_as(pred))

    percents, counts = [], []
    for k in topk:
        correct_k = hits[:k].contiguous().view(-1).float().sum(0, keepdim=True)
        counts.append(int(correct_k.item()))
        percents.append(correct_k.mul_(100.0 / batch_size))

    return percents, counts
191 |
192 |
def pairwise_distances(x, y=None):
    """Squared Euclidean distance matrix.

    Args:
        x: N x d matrix.
        y: optional M x d matrix; defaults to x.
    Returns:
        N x M matrix with dist[i, j] = ||x[i] - y[j]||^2, clamped to >= 0.
    """
    x_sq = (x ** 2).sum(1).view(-1, 1)
    if y is None:
        y = x
        y_sq = x_sq.view(1, -1)
    else:
        y_sq = (y ** 2).sum(1).view(1, -1)
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b; clamp guards tiny negatives
    # from floating-point cancellation.
    dist = x_sq + y_sq - 2.0 * torch.mm(x, torch.transpose(y, 0, 1))
    return torch.clamp(dist, 0.0, np.inf)
210 |
211 |
def normalize_adj(mx):
    """Symmetrically normalize a sparse matrix: D^-1/2 * A * D^-1/2."""
    rowsum = np.array(mx.sum(1))
    inv_sqrt = np.power(rowsum, -0.5).flatten()
    # Rows summing to zero would yield inf; zero them out instead.
    inv_sqrt[np.isinf(inv_sqrt)] = 0.
    d_inv_sqrt = sp.diags(inv_sqrt)
    return mx.dot(d_inv_sqrt).transpose().dot(d_inv_sqrt)
219 |
220 |
def normalize_features(mx):
    """Row-normalize a sparse matrix so each row sums to 1 (zero rows stay 0)."""
    rowsum = np.array(mx.sum(1))
    inv = np.power(rowsum, -1).flatten()
    # Guard against division by zero on empty rows.
    inv[np.isinf(inv)] = 0.
    return sp.diags(inv).dot(mx)
229 |
230 |
def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse COO tensor.

    Uses torch.sparse_coo_tensor — consistent with get_adjr in this file —
    instead of the deprecated legacy torch.sparse.FloatTensor constructor.
    """
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.FloatTensor(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    return torch.sparse_coo_tensor(indices, values, shape)
239 |
240 |
def div_list(ls, n):
    """Split `ls` into `n` chunks; the last chunk absorbs the remainder.

    Returns [] when n is out of range (n <= 0 or n > len(ls)).
    """
    size = len(ls)
    if n <= 0 or size == 0 or n > size:
        return []
    if n == size:
        return [[item] for item in ls]
    step = size // n
    # n-1 equal chunks, then everything left over in the final chunk.
    chunks = [ls[i:i + step] for i in range(0, (n - 1) * step, step)]
    chunks.append(ls[(n - 1) * step:])
    return chunks
257 |
258 |
def multi_cal_neg(pos_triples, task, triples, r_hs_dict, r_ts_dict, ids, neg_scope):
    """Worker that corrupts one positive triple per index in `task`.

    For each positive (h, r, t) it repeatedly replaces either the head or
    the tail (fair coin flip) until the corrupted triple is absent from
    `triples`. With `neg_scope` truthy the replacement is drawn from
    entities observed with the same relation (r_hs_dict / r_ts_dict);
    after more than 10 collisions it falls back to sampling from the
    entity group in `ids` that contains the current head/tail.

    Returns the list of negative triples, one per entry in `task`.
    """
    neg_triples = list()
    for idx, tas in enumerate(task):
        (h, r, t) = pos_triples[tas]
        h2, r2, t2 = h, r, t
        # temp_scope: start with relation-typed sampling; num counts collisions.
        temp_scope, num = neg_scope, 0
        while True:
            choice = random.randint(0, 999)
            if choice < 500:
                # Corrupt the head.
                if temp_scope:
                    h2 = random.sample(r_hs_dict[r], 1)[0]
                else:
                    for id in ids:
                        if h2 in id:
                            h2 = random.sample(id, 1)[0]
                            break
            else:
                # Corrupt the tail.
                if temp_scope:
                    t2 = random.sample(r_ts_dict[r], 1)[0]
                else:
                    for id in ids:
                        if t2 in id:
                            t2 = random.sample(id, 1)[0]
                            break
            if (h2, r2, t2) not in triples:
                break
            else:
                num += 1
                if num > 10:
                    # Too many collisions: widen to group-based sampling.
                    temp_scope = False
        neg_triples.append((h2, r2, t2))
    return neg_triples
291 |
292 |
def multi_typed_sampling(pos_triples, triples, r_hs_dict, r_ts_dict, ids, neg_scope):
    """Sample one negative per positive triple, split across 10 workers."""
    triples = set(triples)  # O(1) membership tests inside the workers
    tasks = div_list(np.array(range(len(pos_triples)), dtype=np.int32), 10)
    pool = multiprocessing.Pool(processes=len(tasks))
    async_results = [
        pool.apply_async(multi_cal_neg,
                         (pos_triples, task, triples, r_hs_dict,
                          r_ts_dict, ids, neg_scope))
        for task in tasks
    ]
    pool.close()
    pool.join()
    neg_triples = []
    for res in async_results:
        neg_triples.extend(res.get())
    return neg_triples
307 |
308 |
def nearest_neighbor_sampling(emb, left, right, K):
    """For each side, pick the K nearest same-side entities as negatives.

    Negatives for the left side come from `right` entities (and vice
    versa); index 0 of each sorted row is the entity itself and is skipped.
    """
    def _k_nearest(cands):
        # Pairwise squared distances within the candidate set.
        dist = pairwise_distances(emb[cands], emb[cands])
        picked = []
        for row in range(cands.shape[0]):
            _, order = torch.sort(dist[row, :], descending=False)
            picked.append(cands[order[1:K + 1]])
        return torch.cat(tuple(picked), dim=0)

    neg_left = _k_nearest(right)
    neg_right = _k_nearest(left)
    return neg_left, neg_right
324 |
325 |
def get_adjr(ent_size, triples, norm=False):
    """Build a symmetric (entity x entity) adjacency from (h, r, t) triples.

    Edge weight counts how many relations connect the pair; every entity
    gets a weight-1 self-loop. With norm=True the result is symmetrically
    normalized; otherwise a raw torch sparse COO tensor of counts.
    """
    print('getting a sparse tensor r_adj...')
    pair_count = {}
    for tri in triples:
        h, t = tri[0], tri[2]
        if h == t:
            continue  # self-loops are added uniformly below
        pair_count[(h, t)] = pair_count.get((h, t), 0) + 1

    ind, val = [], []
    for (h, t), c in pair_count.items():
        # Insert both directions to keep the adjacency symmetric.
        ind.extend([(h, t), (t, h)])
        val.extend([c, c])

    for i in range(ent_size):
        ind.append((i, i))
        val.append(1)

    if norm:
        ind = np.array(ind, dtype=np.int32)
        val = np.array(val, dtype=np.float32)
        adj = sp.coo_matrix((val, (ind[:, 0], ind[:, 1])),
                            shape=(ent_size, ent_size), dtype=np.float32)
        return sparse_mx_to_torch_sparse_tensor(normalize_adj(adj))
    return torch.sparse_coo_tensor(torch.LongTensor(ind).t(),
                                   torch.FloatTensor(val),
                                   torch.Size([ent_size, ent_size]))
357 |
358 |
def cal_ranks(scores, labels, is_lefts, left_num):
    """Rank of each gold counterpart among candidates on the opposite KG.

    Left-KG queries are ranked against the right-KG entities (columns
    left_num:), right-KG queries against the left-KG entities (columns
    :left_num). Returns 1-based ranks, best score first.
    """
    ranks = []
    for score, label, is_left in zip(scores, labels, is_lefts):
        if is_left:
            cand_scores = score[left_num:]
            gold = label - left_num
        else:
            cand_scores = score[:left_num]
            gold = label
        # Descending by score: negate, then locate the gold entity.
        order = (-cand_scores).argsort()
        position = np.where(order == gold)[0][0]
        ranks.append(position + 1)
    return list(ranks)
372 |
373 |
def cal_performance(ranks):
    """MRR and Hits@{1,3,5,10} from an array of 1-based ranks."""
    n = len(ranks)
    mrr = (1. / ranks).sum() / n
    hits = [sum(ranks <= k) * 1.0 / n for k in (1, 3, 5, 10)]
    return (mrr, *hits)
381 |
def multi_cal_rank(task, sim, top_k, l_or_r):
    """Rank statistics for one worker's slice of the similarity matrix.

    `l_or_r` selects whether rows (0) or columns (1) of `sim` are the
    queries; `sim` holds distances, so smaller is better. Returns
    (sum of ranks, hits per cutoff in top_k, sum of reciprocal ranks).
    """
    rank_sum, mrr_sum = 0, 0
    hits = [0 for _ in top_k]
    for i, ref in enumerate(task):
        order = sim[i, :].argsort() if l_or_r == 0 else sim[:, i].argsort()
        assert ref in order
        pos = np.where(order == ref)[0][0]
        rank_sum += pos + 1
        mrr_sum += 1.0 / (pos + 1)
        for j, k in enumerate(top_k):
            if pos < k:
                hits[j] += 1
    return rank_sum, hits, mrr_sum
400 |
401 |
def multi_get_hits(Lvec, Rvec, top_k=(1, 5, 10, 50, 100), args=None):
    """Hits@k / mean rank / MRR in both directions, computed in parallel.

    For each direction (L->R, then R->L) three items are appended to the
    result: the hits@k array, the mean rank, and the MRR.
    """
    result = []
    sim = pairwise_distances(torch.FloatTensor(Lvec), torch.FloatTensor(Rvec)).numpy()
    if args.csls is True:
        # CSLS rescaling operates on similarities, hence the 1 - sim flips.
        sim = 1 - csls_sim(1 - sim, args.csls_k)
    for direction in [0, 1]:
        top_total = np.array([0] * len(top_k))
        mean_total, mrr_total = 0.0, 0.0
        s_len = Lvec.shape[0] if direction == 0 else Rvec.shape[0]
        tasks = div_list(np.array(range(s_len)), 10)
        pool = multiprocessing.Pool(processes=len(tasks))
        workers = []
        for task in tasks:
            block = sim[task, :] if direction == 0 else sim[:, task]
            workers.append(pool.apply_async(multi_cal_rank,
                                            (task, block, top_k, direction)))
        pool.close()
        pool.join()
        for res in workers:
            mean, num, mrr = res.get()
            mean_total += mean
            mrr_total += mrr
            top_total += np.array(num)
        acc_total = top_total / s_len
        for j in range(len(acc_total)):
            acc_total[j] = round(acc_total[j], 4)
        result.append(acc_total)
        result.append(mean_total / s_len)
        result.append(mrr_total / s_len)
    return result
435 |
436 |
def csls_sim(sim_mat, k):
    """Cross-domain similarity local scaling (CSLS) of a similarity matrix.

    Each entry is doubled, then penalized by the mean top-k similarity of
    its row and of its column, which reduces hubness.

    Args:
        sim_mat: pairwise similarity matrix (n1 x n2).
        k: number of nearest neighbors used in the local scaling term.
    Returns:
        The n1 x n2 CSLS-adjusted similarity matrix.
    """
    row_knn_mean = torch.mean(torch.topk(sim_mat, k)[0], 1)
    col_knn_mean = torch.mean(torch.topk(sim_mat.t(), k)[0], 1)
    adjusted = 2 * sim_mat.t() - row_knn_mean
    return adjusted.t() - col_knn_mean
456 |
457 |
def get_topk_indices(M, K=1000):
    """Return the (row, col) positions of the K largest entries of M."""
    H, W = M.shape
    flat = M.view(-1)
    vals, flat_idx = flat.topk(K)
    print("highest sim:", vals[0].item(), "lowest sim:", vals[-1].item())
    # Recover 2-D coordinates from the flattened indices.
    rows = (flat_idx // W).unsqueeze(1)
    cols = (flat_idx % W).unsqueeze(1)
    return torch.cat((rows, cols), dim=1)
465 |
466 |
def normalize_zero_one(A):
    """Scale each row of A into [0, 1] in place and return A."""
    row_min = A.min(1, keepdim=True)[0]
    A -= row_min
    row_max = A.max(1, keepdim=True)[0]
    A /= row_max
    return A
471 |
472 |
def output_device(model):
    """Print the set of distinct devices holding the model's state-dict tensors."""
    seen = []
    for tensor in model.state_dict().values():
        if tensor.device not in seen:
            seen.append(tensor.device)
    print(seen)
482 |
483 |
if __name__ == '__main__':
    # Smoke test for cal_ranks: 9 entities (5 left, 4 right) and two
    # alignment seeds, (3, 7) and (8, 2).
    scores = np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.31, 0.9, 0.8, 0.7],
                       [0.5, 0.4, 0.7, 0.2, 0.1, 0.3, 0.32, 0.23, 0.44]])
    labels = np.array([7, 2])
    is_lefts = np.array([True, False])
    print(cal_ranks(scores, labels, is_lefts, 5))
493 |
--------------------------------------------------------------------------------
/vis.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
4 | import argparse
5 | import random
6 | import torch
7 | import numpy as np
8 | from load_data import DataLoader
9 | from base_model import BaseModel
10 | import time
11 | from collections import OrderedDict
12 | import networkx as nx
13 | import matplotlib.pyplot as plt
14 |
15 |
# Command-line configuration shared with the training scripts.
# FIX: --seed is parsed as int (was type=str): a seed given on the CLI was
# passed as a string to np.random.seed/torch.manual_seed, which reject it.
parser = argparse.ArgumentParser(description="Parser for MASEA")
parser.add_argument("--data_path", default="../data/mmkg", type=str, help="Experiment path")
parser.add_argument("--data_choice", default="FBDB15K", type=str, choices=["DBP15K", "DWY", "FBYG15K", "FBDB15K"],
                    help="Experiment path")
parser.add_argument("--data_split", default="norm", type=str, help="Experiment split",
                    choices=["dbp_wd_15k_V2", "dbp_wd_15k_V1", "zh_en", "ja_en", "fr_en", "norm"])
parser.add_argument("--data_rate", type=float, default=0.8, choices=[0.2, 0.3, 0.5, 0.8], help="training set rate")
parser.add_argument('--seed', type=int, default=1234)
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--perf_file', type=str, default='perf.txt')
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--lamb', type=float, default=0.0002)
parser.add_argument('--decay_rate', type=float, default=0.991)
parser.add_argument('--hidden_dim', type=int, default=64)
parser.add_argument('--attn_dim', type=int, default=5)
parser.add_argument('--dropout', type=float, default=0.2)
parser.add_argument('--act', type=str, default='relu')
parser.add_argument('--n_layer', type=int, default=5)
parser.add_argument('--n_batch', type=int, default=2)
parser.add_argument("--lamda", type=float, default=0.5)
parser.add_argument("--exp_name", default="EA_exp", type=str, help="Experiment name")
parser.add_argument("--MLP_hidden_dim", type=int, default=64)
parser.add_argument("--MLP_num_layers", type=int, default=3)
parser.add_argument("--MLP_dropout", type=float, default=0.2)

# Entity/relation counts; filled in from the dataset at load time.
parser.add_argument("--n_ent", type=int, default=0)
parser.add_argument("--n_rel", type=int, default=0)

# Embedding dimensions for the structural/text/image/time modalities.
parser.add_argument("--stru_dim", type=int, default=16)
parser.add_argument("--text_dim", type=int, default=768)
parser.add_argument("--img_dim", type=int, default=2048)
parser.add_argument("--time_dim", type=int, default=32)
parser.add_argument("--out_dim", type=int, default=32)
parser.add_argument("--train_support", type=int, default=0)
parser.add_argument("--gnn_model", type=str, default='RS_GNN')
parser.add_argument("--mm", type=int, default=0)
parser.add_argument("--shuffle", type=int, default=1)
parser.add_argument("--meta", type=int, default=1)
parser.add_argument("--temperature", type=float, default=0.5)
parser.add_argument("--premm", type=int, default=0)
parser.add_argument("--withmm", type=int, default=1)
parser.add_argument("--update_step", type=int, default=20)
parser.add_argument("--update_step_test", type=int, default=20)
parser.add_argument("--update_lr", type=float, default=0.001)

61 |
# base — generic training-loop settings

parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--epoch', default=100, type=int)
parser.add_argument("--save_model", default=0, type=int, choices=[0, 1])
parser.add_argument("--only_test", default=0, type=int, choices=[0, 1])

# torthlight (experiment bookkeeping / tensorboard)
parser.add_argument("--no_tensorboard", default=False, action="store_true")

parser.add_argument("--dump_path", default="dump/", type=str, help="Experiment dump path")
parser.add_argument("--exp_id", default="001", type=str, help="Experiment ID")
parser.add_argument("--random_seed", default=42, type=int)


# --------- EA -----------

# parser.add_argument("--data_rate", type=float, default=0.3, help="training set rate")
#

# TODO: add some dynamic variable
parser.add_argument("--model_name", default="MEAformer", type=str, choices=["EVA", "MCLEA", "MSNEA", "MEAformer"],
                    help="model name")
parser.add_argument("--model_name_save", default="", type=str, help="model name for model load")

# Optimizer / scheduler configuration shared by all EA backbones.
parser.add_argument('--workers', type=int, default=8)
parser.add_argument('--accumulation_steps', type=int, default=1)
parser.add_argument("--scheduler", default="linear", type=str, choices=["linear", "cos", "fixed"])
parser.add_argument("--optim", default="adamw", type=str, choices=["adamw", "adam"])
parser.add_argument('--weight_decay', type=float, default=0.0001)
parser.add_argument("--adam_epsilon", default=1e-8, type=float)
parser.add_argument('--eval_epoch', default=100, type=int, help='evaluate each n epoch')
parser.add_argument("--enable_sota", action="store_true", default=False)

# Loss-related knobs (margin-based / contrastive objectives).
parser.add_argument('--margin', default=1, type=float, help='The fixed margin in loss function. ')
parser.add_argument('--emb_dim', default=1000, type=int, help='The embedding dimension in KGE model.')
parser.add_argument('--adv_temp', default=1.0, type=float,
                    help='The temperature of sampling in self-adversarial negative sampling.')
parser.add_argument("--contrastive_loss", default=0, type=int, choices=[0, 1])
parser.add_argument('--clip', type=float, default=1., help='gradient clipping')
102 |
# --------- EVA -----------

parser.add_argument("--hidden_units", type=str, default="128,128,128",
                    help="hidden units in each hidden layer(including in_dim and out_dim), splitted with comma")
parser.add_argument("--attn_dropout", type=float, default=0.0, help="dropout rate for gat layers")
parser.add_argument("--distance", type=int, default=2, help="L1 distance or L2 distance. ('1', '2')", choices=[1, 2])
parser.add_argument("--csls", action="store_true", default=False, help="use CSLS for inference")
parser.add_argument("--csls_k", type=int, default=10, help="top k for csls")
# Iterative (semi-supervised) learning schedule.
parser.add_argument("--il", action="store_true", default=False, help="Iterative learning?")
parser.add_argument("--semi_learn_step", type=int, default=10, help="If IL, what's the update step?")
parser.add_argument("--il_start", type=int, default=500, help="If Il, when to start?")
parser.add_argument("--unsup", action="store_true", default=False)
parser.add_argument("--unsup_k", type=int, default=1000, help="|visual seed|")

# --------- MCLEA -----------
parser.add_argument("--unsup_mode", type=str, default="img", help="unsup mode", choices=["img", "name", "char"])
parser.add_argument("--tau", type=float, default=0.1, help="the temperature factor of contrastive loss")
parser.add_argument("--alpha", type=float, default=0.2, help="the margin of InfoMaxNCE loss")
parser.add_argument("--with_weight", type=int, default=1, help="Whether to weight the fusion of different ")
parser.add_argument("--structure_encoder", type=str, default="gat", help="the encoder of structure view",
                    choices=["gat", "gcn"])
parser.add_argument("--ab_weight", type=float, default=0.5, help="the weight of NTXent Loss")

parser.add_argument("--projection", action="store_true", default=False, help="add projection for model")
parser.add_argument("--heads", type=str, default="2,2", help="heads in each gat layer, splitted with comma")
parser.add_argument("--instance_normalization", action="store_true", default=False,
                    help="enable instance normalization")
parser.add_argument("--attr_dim", type=int, default=100, help="the hidden size of attr and rel features")
parser.add_argument("--name_dim", type=int, default=100, help="the hidden size of name feature")
parser.add_argument("--char_dim", type=int, default=100, help="the hidden size of char feature")

# Feature toggles: store_false means they are ON by default and the flag disables them.
parser.add_argument("--w_gcn", action="store_false", default=True, help="with gcn features")
parser.add_argument("--w_rel", action="store_false", default=True, help="with rel features")
parser.add_argument("--w_attr", action="store_false", default=True, help="with attr features")
parser.add_argument("--w_name", action="store_false", default=True, help="with name features")
parser.add_argument("--w_char", action="store_false", default=True, help="with char features")
parser.add_argument("--w_img", action="store_false", default=True, help="with img features")
parser.add_argument("--use_surface", type=int, default=0, help="whether to use the surface")

parser.add_argument("--inner_view_num", type=int, default=6, help="the number of inner view")
parser.add_argument("--word_embedding", type=str, default="glove", help="the type of word embedding, [glove|fasttext]",
                    choices=["glove", "bert"])
# projection head
parser.add_argument("--use_project_head", action="store_true", default=False, help="use projection head")
parser.add_argument("--zoom", type=float, default=0.1, help="narrow the range of losses")
parser.add_argument("--reduction", type=str, default="mean", help="[sum|mean]", choices=["sum", "mean"])
149 |
# --------- MEAformer -----------
parser.add_argument("--hidden_size", type=int, default=100, help="the hidden size of MEAformer")
parser.add_argument("--intermediate_size", type=int, default=400, help="the hidden size of MEAformer")
parser.add_argument("--num_attention_heads", type=int, default=5, help="the number of attention_heads of MEAformer")
parser.add_argument("--num_hidden_layers", type=int, default=2, help="the number of hidden_layers of MEAformer")
parser.add_argument("--position_embedding_type", default="absolute", type=str)
parser.add_argument("--use_intermediate", type=int, default=1, help="whether to use_intermediate")
parser.add_argument("--replay", type=int, default=0, help="whether to use replay strategy")
parser.add_argument("--neg_cross_kg", type=int, default=0,
                    help="whether to force the negative samples in the opposite KG")

# --------- MSNEA -----------
parser.add_argument("--dim", type=int, default=100, help="the hidden size of MSNEA")
parser.add_argument("--neg_triple_num", type=int, default=1, help="neg triple num")
parser.add_argument("--use_bert", type=int, default=0)
parser.add_argument("--use_attr_value", type=int, default=0)
# parser.add_argument("--learning_rate", type=int, default=0.001)
# parser.add_argument("--optimizer", type=str, default="Adam")
# parser.add_argument("--max_epoch", type=int, default=200)

# parser.add_argument("--save_path", type=str, default="save_pkl", help="save path")

# ------------ Para ------------
# Distributed-training parameters (unused when --dist is 0).
parser.add_argument('--rank', type=int, default=0, help='rank to dist')
parser.add_argument('--dist', type=int, default=0, help='whether to dist')
parser.add_argument('--device', default='cuda', help='device id (i.e. 0 or 0,1 or cpu)')
parser.add_argument('--world-size', default=3, type=int,
                    help='number of distributed processes')
parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
parser.add_argument("--local_rank", default=-1, type=int)

parser.add_argument("--nni", default=0, type=int)
args = parser.parse_args()
# Force synchronous CUDA kernel launches so errors surface with a usable stack trace.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Bind this process's default CUDA device to the GPU selected via --gpu.
torch.cuda.set_device(args.gpu)
187 |
188 |
if __name__ == '__main__':
    # Fix all RNG seeds for reproducible sampling/layouts.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    results_dir = 'results'
    if not os.path.exists(results_dir):
        os.makedirs(results_dir)
    # Encode the hyper-parameter configuration plus a timestamp into the perf-file name.
    args_str = f'{args.data_choice}_{args.data_split}_{args.data_rate}_lr{args.lr}_bs{args.n_batch}_hidden_dim{args.hidden_dim}_lamb{args.lamb}_dropout{args.dropout}_act{args.act}_decay_rate{args.decay_rate}'
    args.perf_file = os.path.join(results_dir, args.exp_name, args_str + time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time())) + '.txt')
    if not os.path.exists(os.path.join(results_dir, args.exp_name)):
        os.makedirs(os.path.join(results_dir, args.exp_name),exist_ok=True)
    if args.nni:
        # Merge hyper-parameters proposed by an NNI tuning run into args.
        import nni
        from nni.utils import merge_parameter
        nni_params = nni.get_next_parameter()
        args = merge_parameter(args, nni_params)
    print(args)
    # NOTE(review): the file handle opened here is never explicitly closed.
    print(args, file=open(args.perf_file, 'a'))
    loader = DataLoader(args)
    id2name = loader.id2name
    id2rel = loader.id2rel
    n_rel = loader.n_rel
    # Extend the relation vocabulary: ids [n_rel, 2*n_rel) are reverse
    # relations, followed by self-loop and anchor/anchor-reverse ids.
    id2rel_reverse = {}
    for k, v in id2rel.items():
        id2rel_reverse[k+n_rel] = v+'_reverse'
    id2rel = {**id2rel , **id2rel_reverse}
    id2rel[2*n_rel] = 'self_loop'
    id2rel[2*n_rel+1] = 'anchor'
    id2rel[2*n_rel+2] = 'anchor_reverse'
    # Entities with id < left_entity belong to the left KG (used for node coloring).
    left_entity = len(loader.left_ents)

    batch_size = 1
    n_data = loader.n_test
    n_batch = n_data // batch_size + (n_data % batch_size > 0)

    # Visualize one test triple per iteration (batch_size is fixed to 1).
    for i in range(n_batch):
        start = i*batch_size
        end = min(n_data, (i+1)*batch_size)
        batch_idx = np.arange(start, end)
        triple = loader.get_batch(batch_idx, data='test')
        subs, rels, objs = triple[:,0],triple[:,1],triple[:,2]
        sub = subs[0]
        rel = rels[0]
        obj = objs[0]
        # Edge lists for a 5-hop subgraph between sub and obj; edges[i] holds hop i+1.
        edges = loader.get_vis_subgraph(sub, obj, 5)
        all_edges_size = sum([len(edge) for edge in edges])
        print(all_edges_size)
        # Skip subgraphs that are empty or too large to plot legibly.
        if all_edges_size >100 or all_edges_size == 0:
            continue
        pos = {}
        # One x column per hop layer; y is spread evenly within a column.
        x_pos = [-5,-3, -1, 1, 3, 5]
        g = {'nodes': [], 'edges': []}
        G = nx.DiGraph()
        # Node ids are "<entity_id>_<layer>" so the same entity can appear in several layers.
        for node in edges[0][:,0].unique():
            G.add_node(str(node.item()) + '_' + str(0), desc=id2name[node.item()] + '_' + str(0), layer=0)
            g['nodes'].append({'id': str(node.item()) + '_' + str(0), 'name': id2name[node.item()] + '_' + str(0),"class": 1 if node.item() < left_entity else 2 ,"imgsrc": "None","content": "None"} )
            pos[str(node.item()) + '_' + str(0)] = (x_pos[0], 0)
        for idx, edge in enumerate(edges):
            # node_1 = edge[:,0].unique()
            node_2 = edge[:,2].unique()
            size = len(node_2)

            for y, node in enumerate(node_2):
                G.add_node(str(node.item())+'_'+str(idx+1), desc=id2name[node.item()]+'_'+str(idx+1),layer=idx+1)
                g['nodes'].append({'id': str(node.item())+'_'+str(idx+1), 'name': id2name[node.item()]+'_'+str(idx+1),"class": 1 if node.item() < left_entity else 2,"imgsrc": "None","content": "None"} )
                pos[str(node.item())+'_'+str(idx+1)] = (x_pos[idx+1], 10/(size+1) * (y+1) - 5)
            for e in edge:
                g['edges'].append({'source': str(e[0].item())+'_'+str(idx), 'target': str(e[2].item())+'_'+str(idx+1), 'name': id2rel[e[1].item()]} )
                G.add_edge(str(e[0].item())+'_'+str(idx), str(e[2].item())+'_'+str(idx+1), name=id2rel[e[1].item()])


        # nodes = torch.cat([edges[:,0], edges[:,2]]).unique()
        # for node in nodes:
        #     G.add_node(node.item(), desc=id2name[node.item()])
        # for edge in edges:
        #     G.add_edge(edge[0].item(), edge[2].item(), name=id2rel[edge[1].item()])

        # draw graph with labels
        plt.figure(figsize=(16, 16), dpi=80)
        # pos = nx.kamada_kawai_layout(G)
        # NOTE(review): the hand-built layered `pos` above is discarded here
        # in favor of a spring layout — confirm this is intentional.
        pos = nx.spring_layout(G)
        nx.draw(G, pos)
        # Highlight the subject (layer 0) and object (layer 5) nodes in red.
        nx.draw_networkx_nodes(G, pos=pos, nodelist=[str(sub.item()) + '_' + str(0),str(obj.item()) + '_' + str(5)], node_color='red', node_size=1000)
        node_labels = nx.get_node_attributes(G, 'desc')
        nx.draw_networkx_labels(G, pos, labels=node_labels)
        edge_labels = nx.get_edge_attributes(G, 'name')
        nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)

        # Persist the rendered figure and the raw node/edge structure as JSON.
        plt.savefig(f'FBDB_{sub}_{rel}_{obj}.png', dpi=100)
        plt.close()
        # NOTE(review): file handle from open(...) is never explicitly closed.
        json.dump(g, open(f'FBDB_{sub}_{rel}_{obj}.json', 'w',encoding='utf-8'), indent=4)
281 |
282 |
283 |
284 |
285 |
286 |
--------------------------------------------------------------------------------