├── logs
│   └── readme.md
├── save_pkl
│   └── readme.md
├── src
│   ├── __init__.py
│   ├── loss.py
│   ├── Load.py
│   ├── layers.py
│   ├── utils.py
│   ├── models.py
│   └── run.py
├── files
│   └── IBMEA.png
├── requirements.txt
├── run_mmkb_main.sh
├── run_dbp_main.sh
├── README.md
└── env.yml
/logs/readme.md:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/save_pkl/readme.md:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/files/IBMEA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sutaoyu/IBMEA/HEAD/files/IBMEA.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.10.2
2 | scipy==1.5.2
3 | numpy==1.19.2
4 | pytorch-metric-learning==2.1.1
--------------------------------------------------------------------------------
/run_mmkb_main.sh:
--------------------------------------------------------------------------------
1 | gpu_id='0'
2 | # dataset: 'FB15K_DB15K' 'FB15K_YAGO15K'
3 | dataset='FB15K_DB15K'
4 |
5 | # ratio: 0.2 0.5 0.8
6 | ratio=0.2
7 | seed=2023
8 | warm=400
9 | bsize=1000
10 | if [[ "$dataset" == *"FB"* ]]; then
11 | dataset_dir='mmkb-datasets'
12 | tau=400
13 | else
14 | dataset_dir='DBP15K'
15 | tau=0.1
16 | ratio=0.3
17 | fi
18 | echo "Running with dataset=${dataset}, ratio=${ratio}"
19 | current_datetime=$(date +"%Y-%m-%d-%H-%M")
20 | head_name=${current_datetime}_${dataset}
21 | file_name=${head_name}_bsize${bsize}_${ratio}
22 | echo ${file_name}
23 | CUDA_VISIBLE_DEVICES=${gpu_id} python3 -u src/run.py \
24 | --file_dir data/${dataset_dir}/${dataset} \
25 | --pred_name ${file_name} \
26 | --rate ${ratio} \
27 | --lr .006 \
28 | --epochs 500 \
29 | --dropout 0.45 \
30 | --hidden_units "300,300,300" \
31 | --check_point 50 \
32 | --bsize ${bsize} \
33 | --il \
34 | --il_start 20 \
35 | --semi_learn_step 5 \
36 | --csls \
37 | --csls_k 3 \
38 | --seed ${seed} \
39 | --structure_encoder "gat" \
40 | --img_dim 100 \
41 | --attr_dim 100 \
42 | --use_nce \
43 | --tau ${tau} \
44 | --use_sheduler_cos \
45 | --num_warmup_steps ${warm} \
46 | --w_name \
47 | --w_char > logs/${file_name}.log
--------------------------------------------------------------------------------
/run_dbp_main.sh:
--------------------------------------------------------------------------------
1 | gpu_id='0'
2 | # dataset: 'zh_en' 'ja_en' 'fr_en'
3 | dataset='zh_en'
4 | ratio=0.3
5 | seed=2023
6 | warm=200
7 | joint_beta=3
8 | ms_alpha=5
9 | training_step=1500
10 | bsize=7500
11 | if [[ "$dataset" == *"FB"* ]]; then
12 | dataset_dir='mmkb-datasets'
13 | tau=400
14 | else
15 | dataset_dir='DBP15K'
16 | tau=0.1
17 | ratio=0.3
18 | fi
19 | echo "Running with dataset=${dataset}"
20 | current_datetime=$(date +"%Y-%m-%d-%H-%M")
21 | head_name=${current_datetime}_${dataset}
22 | file_name=${head_name}_bsize${bsize}_${ratio}
23 | echo ${file_name}
24 | CUDA_VISIBLE_DEVICES=${gpu_id} python3 -u src/run.py \
25 | --file_dir data/${dataset_dir}/${dataset} \
26 | --pred_name ${file_name} \
27 | --rate ${ratio} \
28 | --lr .006 \
29 | --epochs 1000 \
30 | --dropout 0.45 \
31 | --hidden_units "300,300,300" \
32 | --check_point 50 \
33 | --bsize ${bsize} \
34 | --il \
35 | --il_start 20 \
36 | --semi_learn_step 5 \
37 | --csls \
38 | --csls_k 3 \
39 | --seed ${seed} \
40 | --structure_encoder "gat" \
41 | --img_dim 100 \
42 | --attr_dim 100 \
43 | --use_nce \
44 | --tau ${tau} \
45 | --use_sheduler_cos \
46 | --num_warmup_steps ${warm} \
47 | --num_training_steps ${training_step} \
48 | --joint_beta ${joint_beta} \
49 | --ms_alpha ${ms_alpha} \
50 | --w_name \
51 | --w_char > logs/${file_name}.log
52 |
--------------------------------------------------------------------------------
/src/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from pytorch_metric_learning import losses, miners
4 |
5 | try:
6 | from models import *
7 | from utils import *
8 | except ImportError:
9 | from src.models import *
10 | from src.utils import *
11 |
12 |
13 | class MsLoss(nn.Module):
14 | def __init__(self, device, thresh=0.5, scale_pos=0.1, scale_neg=40.0):
15 | super(MsLoss, self).__init__()
16 | self.device = device
17 | alpha, beta, base = scale_pos, scale_neg, thresh
18 | self.loss_func = losses.MultiSimilarityLoss(alpha=alpha, beta=beta, base=base)
19 |
20 | def sim(self, emb_left, emb_right):
21 | return emb_left.mm(emb_right.t())
22 |
23 | def forward(self, emb, train_links):
24 | emb = F.normalize(emb)
25 | emb_train_left = emb[train_links[:, 0]]
26 | emb_train_right = emb[train_links[:, 1]]
27 | labels = torch.arange(emb_train_left.size(0))
28 | embeddings = torch.cat([emb_train_left, emb_train_right], dim=0)
29 | labels = torch.cat([labels, labels], dim=0)
30 | loss = self.loss_func(embeddings, labels)
31 | return loss
32 |
33 |
34 | class InfoNCE_loss(nn.Module):
35 | def __init__(self, device, temperature=0.05) -> None:
36 | super().__init__()
37 | self.device = device
38 | self.t = temperature
39 |
40 | self.ce_loss = nn.CrossEntropyLoss()
41 |
42 | def sim(self, emb_left, emb_right):
43 | return emb_left.mm(emb_right.t())
44 |
45 | def forward(self, emb, train_links):
46 | emb = F.normalize(emb)
47 | emb_train_left = emb[train_links[:, 0]]
48 | emb_train_right = emb[train_links[:, 1]]
49 |
50 | score = self.sim(emb_train_left, emb_train_right)
51 |
52 | bsize = emb_train_left.size()[0]
53 | label = torch.arange(bsize, dtype=torch.long).cuda(self.device)
54 |
55 | loss = self.ce_loss(score / self.t, label)
56 | return loss
57 |
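A minimal usage sketch for the losses above (illustrative, not part of the file; it assumes `train_links` is an `N x 2` LongTensor of aligned entity index pairs, as built in `run.py`):

```python
import torch
from loss import MsLoss  # or: from src.loss import MsLoss, when run from the repo root

emb = torch.randn(10, 64)                              # embeddings for 10 entities
train_links = torch.tensor([[0, 5], [1, 6], [2, 7]])   # aligned (left, right) index pairs

criterion = MsLoss(device="cpu")
loss = criterion(emb, train_links)                     # scalar multi-similarity loss
print(loss.item())

# InfoNCE_loss is called the same way, but it builds its label tensor with
# .cuda(self.device), so it assumes a CUDA device is available.
```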
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # IBMEA
2 |
3 | The code of the paper _**IBMEA: Exploring Variational Information Bottleneck for Multi-modal Entity Alignment**_ [[arXiv](https://arxiv.org/abs/2407.19302)] [[ACM MM](https://dl.acm.org/doi/10.1145/3664647.3680954)], published in Proceedings of ACM MM 2024.
4 |
5 |
6 |
7 | ![IBMEA](files/IBMEA.png)
8 |
9 |
10 |
11 |
12 |
13 | ## Env
14 |
15 | Run the following command to create the required conda environment:
16 |
17 | ```
18 | conda env create -f env.yml -n IBMEA
19 | ```
20 |
21 | ## Datasets
22 |
23 | ### Cross-KG datasets
24 |
25 | The original cross-KG datasets (FB15K-DB15K/YAGO15K) come from [MMKB](https://github.com/mniepert/mmkb), in which the image embeddings are extracted with a pre-trained VGG16. Following [MCLEA](https://github.com/lzxlin/MCLEA), we use the image embeddings provided by [MMKB](https://github.com/mniepert/mmkb#visual-data-for-fb15k-yago15k-and-dbpedia15k) and transform the data into the same format as DBP15K.
26 | The converted dataset can be downloaded from [BaiduDisk](https://pan.baidu.com/s/1xOz8ga_J1LPbBSIgSFHbNQ) (the password is `IBME`).
27 | Place the `mmkb-datasets` directory in the `data` directory.
28 |
29 | ### Bilingual datasets
30 |
31 | The multi-modal version of the DBP15K dataset comes from the [EVA](https://github.com/cambridgeltl/eva) repository. Download the `pkls` folder of DBP15K image features following the guidance in the EVA repository, and place the downloaded `pkls` folder in the `data` directory of this repository.
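After both downloads, the `data` directory is expected to look roughly as follows (directory and file names inferred from `run_mmkb_main.sh`, `run_dbp_main.sh`, and the image-feature paths in `src/run.py`):

```
data
├── mmkb-datasets
│   ├── FB15K_DB15K
│   └── FB15K_YAGO15K
├── DBP15K
│   ├── fr_en
│   ├── ja_en
│   └── zh_en
└── pkls
    ├── fr_en_GA_id_img_feature_dict.pkl
    ├── ja_en_GA_id_img_feature_dict.pkl
    └── zh_en_GA_id_img_feature_dict.pkl
```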
32 |
33 | ## How to run the scripts
34 |
35 | ```
36 | bash run_mmkb_main.sh
37 | bash run_dbp_main.sh
38 | ```
39 |
40 | ## Citation
41 |
42 | If you use this model or code, please cite it as follows:
43 |
44 | ```
45 | @inproceedings{IBMEA,
46 | author = {Taoyu Su and
47 | Jiawei Sheng and
48 | Shicheng Wang and
49 | Xinghua Zhang and
50 | Hongbo Xu and
51 | Tingwen Liu},
52 | editor = {Jianfei Cai and
53 | Mohan S. Kankanhalli and
54 | Balakrishnan Prabhakaran and
55 | Susanne Boll and
56 | Ramanathan Subramanian and
57 | Liang Zheng and
58 | Vivek K. Singh and
59 | Pablo C{\'{e}}sar and
60 | Lexing Xie and
61 | Dong Xu},
62 | title = {{IBMEA:} Exploring Variational Information Bottleneck for Multi-modal
63 | Entity Alignment},
64 | booktitle = {Proceedings of the 32nd {ACM} International Conference on Multimedia,
65 | {MM} 2024, Melbourne, VIC, Australia, 28 October 2024 - 1 November
66 | 2024},
67 | pages = {4436--4445},
68 | publisher = {{ACM}},
69 | year = {2024}
70 | }
71 | ```
72 |
73 | ## Acknowledgement
74 |
75 | We appreciate [MCLEA](https://github.com/lzxlin/MCLEA), [EVA](https://github.com/cambridgeltl/eva), [MMEA](https://github.com/liyichen-cly/MMEA) and many other related works for their open-source contributions.
76 |
77 |
78 | #### Remarks
79 | - Feel free to give this repo a star ⭐ and let me know how many people are interested in this work 🤩.
80 | - It is said that people who click the star ⭐ get better acceptance rates for their papers 😃.
81 | - **If you star this repo in passing, may your next paper be accepted at an A venue!!!!**
82 |
83 |
84 |
--------------------------------------------------------------------------------
/src/Load.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from collections import Counter
3 | import json
4 | import pickle
5 | from tqdm import tqdm
6 |
7 |
8 | def loadfile(fn, num=1):
9 | print("loading a file..." + fn)
10 | ret = []
11 | with open(fn, encoding="utf-8") as f:
12 | for line in f:
13 | th = line[:-1].split("\t")
14 | x = []
15 | for i in range(num):
16 | x.append(int(th[i]))
17 | ret.append(tuple(x))
18 | return ret
19 |
20 |
21 | def get_ids(fn):
22 | ids = []
23 | with open(fn, encoding="utf-8") as f:
24 | for line in f:
25 | th = line[:-1].split("\t")
26 | ids.append(int(th[0]))
27 | return ids
28 |
29 |
30 | def get_ent2id(fns):
31 | ent2id = {}
32 | for fn in fns:
33 | with open(fn, "r", encoding="utf-8") as f:
34 | for line in f:
35 | th = line[:-1].split("\t")
36 | ent2id[th[1]] = int(th[0])
37 | return ent2id
38 |
39 |
40 | def load_attr(fns, e, ent2id, topA=1000):
41 | cnt = {}
42 | for fn in fns:
43 | with open(fn, "r", encoding="utf-8") as f:
44 | for line in f:
45 | th = line[:-1].split("\t")
46 | if th[0] not in ent2id:
47 | continue
48 | for i in range(1, len(th)):
49 | if th[i] not in cnt:
50 | cnt[th[i]] = 1
51 | else:
52 | cnt[th[i]] += 1
53 | fre = [(k, cnt[k]) for k in sorted(cnt, key=cnt.get, reverse=True)]
54 | attr2id = {}
55 | for i in range(min(topA, len(fre))):
56 | attr2id[fre[i][0]] = i
57 | attr = np.zeros((e, topA), dtype=np.float32)
58 | for fn in fns:
59 | with open(fn, "r", encoding="utf-8") as f:
60 | for line in f:
61 | th = line[:-1].split("\t")
62 | if th[0] in ent2id:
63 | for i in range(1, len(th)):
64 | if th[i] in attr2id:
65 | attr[ent2id[th[0]]][attr2id[th[i]]] = 1.0
66 | return attr
67 |
68 |
69 | def load_relation(e, KG, topR=1000):
70 | rel_mat = np.zeros((e, topR), dtype=np.float32)
71 | rels = np.array(KG)[:, 1]
72 | top_rels = Counter(rels).most_common(topR)
73 | rel_index_dict = {r: i for i, (r, cnt) in enumerate(top_rels)}
74 | for tri in KG:
75 | h = tri[0]
76 | r = tri[1]
77 | o = tri[2]
78 | if r in rel_index_dict:
79 | rel_mat[h][rel_index_dict[r]] += 1.0
80 | rel_mat[o][rel_index_dict[r]] += 1.0
81 | return np.array(rel_mat)
82 |
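A minimal sketch of what `load_relation` produces for a toy KG (illustrative; it assumes the function above is importable, e.g. `from Load import load_relation`):

```python
from Load import load_relation  # or: from src.Load import load_relation

toy_kg = [(0, 5, 1), (1, 5, 2), (1, 7, 2)]   # (head, relation, tail) triples
mat = load_relation(e=3, KG=toy_kg, topR=2)
# Relation 5 (2 occurrences) maps to column 0, relation 7 to column 1; each entity
# row counts how often the entity appears as head or tail of that relation:
print(mat)
# [[1. 0.]
#  [2. 1.]
#  [1. 1.]]
```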
83 |
84 | def load_json_embd(path):
85 | embd_dict = {}
86 | with open(path) as f:
87 | for line in f:
88 | example = json.loads(line.strip())
89 | vec = np.array([float(e) for e in example["feature"].split()])
90 | embd_dict[int(example["guid"])] = vec
91 | return embd_dict
92 |
93 |
94 | def load_img(e_num, path):
95 | img_dict = pickle.load(open(path, "rb"))
96 | imgs_np = np.array(list(img_dict.values()))
97 | mean = np.mean(imgs_np, axis=0)
98 | std = np.std(imgs_np, axis=0)
99 | img_embd = np.array(
100 | [
101 | img_dict[i] if i in img_dict else np.random.normal(mean, std, mean.shape[0])
102 | for i in range(e_num)
103 | ]
104 | )
105 | print("%.2f%% entities have images" % (100 * len(img_dict) / e_num))
106 | return img_embd
107 |
108 |
109 | def load_img_new(e_num, path, triples):
110 | from collections import defaultdict
111 |
112 | img_dict = pickle.load(open(path, "rb"))
113 | neighbor_list = defaultdict(list)
114 | for triple in triples:
115 | head = triple[0]
116 | relation = triple[1]
117 | tail = triple[2]
118 | if tail in img_dict:
119 | neighbor_list[head].append(tail)
120 | if head in img_dict:
121 | neighbor_list[tail].append(head)
122 | imgs_np = np.array(list(img_dict.values()))
123 | mean = np.mean(imgs_np, axis=0)
124 | std = np.std(imgs_np, axis=0)
125 | all_img_emb_normal = np.random.normal(mean, std, mean.shape[0])
126 | img_embd = []
127 |     follow_neighbor_img_num = 0
128 | follow_all_img_num = 0
129 | for i in range(e_num):
130 | if i in img_dict:
131 | img_embd.append(img_dict[i])
132 | else:
133 | if len(neighbor_list[i]) > 0:
134 |                 follow_neighbor_img_num += 1
135 | if i in img_dict:
136 | neighbor_list[i].append(i)
137 | neighbor_imgs_emb = np.array([img_dict[id] for id in neighbor_list[i]])
138 | neighbor_imgs_emb_mean = np.mean(neighbor_imgs_emb, axis=0)
139 | img_embd.append(neighbor_imgs_emb_mean)
140 | else:
141 | follow_all_img_num += 1
142 | img_embd.append(all_img_emb_normal)
143 | print(
144 | "%.2f%% entities have images," % (100 * len(img_dict) / e_num),
145 |         " follow_neighbor_img_num is {0},".format(follow_neighbor_img_num),
146 | " follow_all_img_num is {0}".format(follow_all_img_num),
147 | )
148 | return np.array(img_embd)
--------------------------------------------------------------------------------
/env.yml:
--------------------------------------------------------------------------------
1 | name: IBMEA
2 | channels:
3 | - pytorch
4 | - https://repo.anaconda.com/pkgs/main
5 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/Paddle/
6 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
7 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
8 | - defaults
9 | dependencies:
10 | - _libgcc_mutex=0.1=main
11 | - _openmp_mutex=5.1=1_gnu
12 | - blas=1.0=mkl
13 | - bzip2=1.0.8=h7b6447c_0
14 | - ca-certificates=2022.07.19=h06a4308_0
15 | - certifi=2021.5.30=py36h06a4308_0
16 | - cudatoolkit=11.3.1=h2bc3f7f_2
17 | - dataclasses=0.8=pyh4f3eec9_6
18 | - ffmpeg=4.3=hf484d3e_0
19 | - freetype=2.11.0=h70c0345_0
20 | - gmp=6.2.1=h295c915_3
21 | - gnutls=3.6.15=he1e5248_0
22 | - intel-openmp=2022.1.0=h9e868ea_3769
23 | - jpeg=9e=h7f8727e_0
24 | - lame=3.100=h7b6447c_0
25 | - lcms2=2.12=h3be6417_0
26 | - ld_impl_linux-64=2.38=h1181459_1
27 | - lerc=3.0=h295c915_0
28 | - libdeflate=1.8=h7f8727e_5
29 | - libffi=3.3=he6710b0_2
30 | - libgcc-ng=11.2.0=h1234567_1
31 | - libgomp=11.2.0=h1234567_1
32 | - libiconv=1.16=h7f8727e_2
33 | - libidn2=2.3.2=h7f8727e_0
34 | - libpng=1.6.37=hbc83047_0
35 | - libstdcxx-ng=11.2.0=h1234567_1
36 | - libtasn1=4.16.0=h27cfd23_0
37 | - libtiff=4.4.0=hecacb30_0
38 | - libunistring=0.9.10=h27cfd23_0
39 | - libuv=1.40.0=h7b6447c_0
40 | - libwebp-base=1.2.4=h5eee18b_0
41 | - lz4-c=1.9.3=h295c915_1
42 | - mkl=2020.2=256
43 | - mkl-service=2.3.0=py36he8ac12f_0
44 | - mkl_fft=1.3.0=py36h54f3939_0
45 | - mkl_random=1.1.1=py36h0573a6f_0
46 | - ncurses=6.3=h5eee18b_3
47 | - nettle=3.7.3=hbbd107a_1
48 | - numpy=1.19.2=py36h54aff64_0
49 | - numpy-base=1.19.2=py36hfa32c7d_0
50 | - olefile=0.46=py36_0
51 | - openh264=2.1.1=h4ff587b_0
52 | - openjpeg=2.4.0=h3ad879b_0
53 | - openssl=1.1.1q=h7f8727e_0
54 | - pillow=8.3.1=py36h2c7a002_0
55 | - pip=21.2.2=py36h06a4308_0
56 | - python=3.6.13=h12debd9_1
57 | - pytorch=1.10.2=py3.6_cuda11.3_cudnn8.2.0_0
58 | - pytorch-mutex=1.0=cuda
59 | - readline=8.1.2=h7f8727e_1
60 | - setuptools=58.0.4=py36h06a4308_0
61 | - six=1.16.0=pyhd3eb1b0_1
62 | - sqlite=3.39.3=h5082296_0
63 | - tk=8.6.12=h1ccaba5_0
64 | - torchaudio=0.10.2=py36_cu113
65 | - torchvision=0.11.3=py36_cu113
66 | - typing_extensions=4.1.1=pyh06a4308_0
67 | - wheel=0.37.1=pyhd3eb1b0_0
68 | - xz=5.2.6=h5eee18b_0
69 | - zlib=1.2.12=h5eee18b_3
70 | - zstd=1.5.2=ha4553b6_0
71 | - pip:
72 | - absl-py==1.4.0
73 | - argon2-cffi==21.3.0
74 | - argon2-cffi-bindings==21.2.0
75 | - async-generator==1.10
76 | - attrs==22.1.0
77 | - backcall==0.2.0
78 | - bleach==4.1.0
79 | - cachetools==4.2.4
80 | - cffi==1.15.1
81 | - charset-normalizer==2.0.12
82 | - click==8.0.4
83 | - cycler==0.11.0
84 | - debugpy==1.5.1
85 | - decorator==5.1.1
86 | - defusedxml==0.7.1
87 | - easydict==1.11
88 | - entrypoints==0.4
89 | - filelock==3.4.1
90 | - google-auth==2.22.0
91 | - google-auth-oauthlib==0.4.6
92 | - grpcio==1.48.2
93 | - huggingface-hub==0.4.0
94 | - idna==3.6
95 | - importlib-metadata==4.8.3
96 | - importlib-resources==5.4.0
97 | - ipykernel==5.5.6
98 | - ipython==7.16.3
99 | - ipython-genutils==0.2.0
100 | - ipywidgets==7.7.2
101 | - jedi==0.17.2
102 | - jinja2==3.0.3
103 | - joblib==1.1.1
104 | - jsonschema==3.2.0
105 | - jupyter==1.0.0
106 | - jupyter-client==7.1.2
107 | - jupyter-console==6.4.3
108 | - jupyter-core==4.9.2
109 | - jupyterlab-pygments==0.1.2
110 | - jupyterlab-widgets==1.1.1
111 | - kiwisolver==1.3.1
112 | - markdown==3.3.7
113 | - markupsafe==2.0.1
114 | - matplotlib==3.3.4
115 | - mistune==0.8.4
116 | - nbclient==0.5.9
117 | - nbconvert==6.0.7
118 | - nbformat==5.1.3
119 | - nest-asyncio==1.5.6
120 | - notebook==6.4.10
121 | - oauthlib==3.2.2
122 | - packaging==21.3
123 | - pandas==1.1.5
124 | - pandocfilters==1.5.0
125 | - parso==0.7.1
126 | - pexpect==4.8.0
127 | - pickleshare==0.7.5
128 | - prometheus-client==0.15.0
129 | - prompt-toolkit==3.0.31
130 | - protobuf==3.19.6
131 | - ptyprocess==0.7.0
132 | - pyasn1==0.5.1
133 | - pyasn1-modules==0.3.0
134 | - pycparser==2.21
135 | - pygments==2.13.0
136 | - pyparsing==3.0.9
137 | - pyrsistent==0.18.0
138 | - python-dateutil==2.8.2
139 | - pytorch-metric-learning==2.1.1
140 | - pytz==2022.6
141 | - pyyaml==6.0.1
142 | - pyzmq==24.0.1
143 | - qtconsole==5.2.2
144 | - qtpy==2.0.1
145 | - regex==2023.8.8
146 | - requests==2.27.1
147 | - requests-oauthlib==1.3.1
148 | - rsa==4.9
149 | - sacremoses==0.0.53
150 | - scikit-learn==0.24.2
151 | - scipy==1.5.2
152 | - send2trash==1.8.0
153 | - tensorboard==2.10.1
154 | - tensorboard-data-server==0.6.1
155 | - tensorboard-plugin-wit==1.8.1
156 | - terminado==0.12.1
157 | - testpath==0.6.0
158 | - threadpoolctl==3.1.0
159 | - tokenizers==0.12.1
160 | - tornado==6.1
161 | - tqdm==4.64.1
162 | - traitlets==4.3.3
163 | - transformers==4.18.0
164 | - unidecode==1.3.7
165 | - urllib3==1.26.18
166 | - wcwidth==0.2.5
167 | - webencodings==0.5.1
168 | - werkzeug==2.0.3
169 | - widgetsnbextension==3.6.1
170 | - zipp==3.6.0
171 |
172 |
--------------------------------------------------------------------------------
/src/layers.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | from __future__ import absolute_import
5 | from __future__ import unicode_literals
6 | from __future__ import division
7 | from __future__ import print_function
8 |
9 | import math
10 | import torch
11 | import torch.nn as nn
12 | from torch.nn.parameter import Parameter
13 | from torch.nn.modules.module import Module
14 | import torch.nn.functional as F
15 |
16 |
17 | class SpecialSpmmFunction(torch.autograd.Function):
18 | @staticmethod
19 | def forward(ctx, indices, values, shape, b):
20 | assert indices.requires_grad == False
21 | a = torch.sparse_coo_tensor(indices, values, shape)
22 | ctx.save_for_backward(a, b)
23 | ctx.N = shape[0]
24 | return torch.matmul(a, b)
25 |
26 | @staticmethod
27 | def backward(ctx, grad_output):
28 | a, b = ctx.saved_tensors
29 | grad_values = grad_b = None
30 | if ctx.needs_input_grad[1]:
31 | grad_a_dense = grad_output.matmul(b.t())
32 | edge_idx = a._indices()[0, :] * ctx.N + a._indices()[1, :]
33 | grad_values = grad_a_dense.view(-1)[edge_idx]
34 | if ctx.needs_input_grad[3]:
35 | grad_b = a.t().matmul(grad_output)
36 | return None, grad_values, None, grad_b
37 |
38 |
39 | class SpecialSpmm(nn.Module):
40 | def forward(self, indices, values, shape, b):
41 | return SpecialSpmmFunction.apply(indices, values, shape, b)
42 |
43 |
44 | class MultiHeadGraphAttention(nn.Module):
45 | def __init__(
46 | self, n_head, f_in, f_out, attn_dropout, diag=True, init=None, bias=False
47 | ):
48 | super(MultiHeadGraphAttention, self).__init__()
49 | self.n_head = n_head
50 | self.f_in = f_in
51 | self.f_out = f_out
52 | self.diag = diag
53 | if self.diag:
54 | self.w = Parameter(torch.Tensor(n_head, 1, f_out))
55 | else:
56 | self.w = Parameter(torch.Tensor(n_head, f_in, f_out))
57 | self.a_src_dst = Parameter(torch.Tensor(n_head, f_out * 2, 1))
58 | self.attn_dropout = attn_dropout
59 | self.leaky_relu = nn.LeakyReLU(negative_slope=0.2)
60 | self.special_spmm = SpecialSpmm()
61 | if bias:
62 | self.bias = Parameter(torch.Tensor(f_out))
63 | nn.init.constant_(self.bias, 0)
64 | else:
65 | self.register_parameter("bias", None)
66 | if init is not None and diag:
67 | init(self.w)
68 | stdv = 1.0 / math.sqrt(self.a_src_dst.size(1))
69 | nn.init.uniform_(self.a_src_dst, -stdv, stdv)
70 | else:
71 | nn.init.xavier_uniform_(self.w)
72 | nn.init.xavier_uniform_(self.a_src_dst)
73 |
74 | def forward(self, input, adj):
75 | output = []
76 | for i in range(self.n_head):
77 | N = input.size()[0]
78 | edge = adj._indices()
79 | if self.diag:
80 | h = torch.mul(input, self.w[i])
81 | else:
82 | h = torch.mm(input, self.w[i])
83 |
84 |             edge_h = torch.cat(
85 |                 (h[edge[0, :], :], h[edge[1, :], :]), dim=1
86 |             )  # edge_h: E x 2*f_out
87 |             edge_e = torch.exp(
88 |                 -self.leaky_relu(edge_h.mm(self.a_src_dst[i]).squeeze())
89 |             )  # edge_e: E (unnormalised attention weight per edge)
90 |
91 | e_rowsum = self.special_spmm(
92 | edge,
93 | edge_e,
94 | torch.Size([N, N]),
95 | (
96 | torch.ones(size=(N, 1)).cuda()
97 | if next(self.parameters()).is_cuda
98 | else torch.ones(size=(N, 1))
99 | ),
100 | ) # e_rowsum: N x 1
101 |             edge_e = F.dropout(
102 |                 edge_e, self.attn_dropout, training=self.training
103 |             )  # edge_e: E
104 |
105 | h_prime = self.special_spmm(edge, edge_e, torch.Size([N, N]), h)
106 | h_prime = h_prime.div(e_rowsum)
107 |
108 | output.append(h_prime.unsqueeze(0))
109 |
110 | output = torch.cat(output, dim=0)
111 | if self.bias is not None:
112 | return output + self.bias
113 | else:
114 | return output
115 |
116 | def __repr__(self):
117 | if self.diag:
118 | return (
119 | self.__class__.__name__
120 | + " ("
121 | + str(self.f_out)
122 | + " -> "
123 | + str(self.f_out)
124 | + ") * "
125 | + str(self.n_head)
126 | + " heads"
127 | )
128 | else:
129 | return (
130 | self.__class__.__name__
131 | + " ("
132 | + str(self.f_in)
133 | + " -> "
134 | + str(self.f_out)
135 | + ") * "
136 | + str(self.n_head)
137 | + " heads"
138 | )
139 |
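A minimal CPU sketch of driving `MultiHeadGraphAttention` on a toy graph (illustrative; it assumes the module is importable as `layers`):

```python
import torch
from layers import MultiHeadGraphAttention  # or: from src.layers import MultiHeadGraphAttention

N, dim = 4, 8
x = torch.randn(N, dim)
# Sparse adjacency; every node appears as a source row so the row-sum
# normalisation inside the layer stays non-zero.
idx = torch.tensor([[0, 1, 2, 3, 0, 1], [0, 1, 2, 3, 1, 2]])
adj = torch.sparse_coo_tensor(idx, torch.ones(idx.size(1)), (N, N)).coalesce()

gat = MultiHeadGraphAttention(n_head=2, f_in=dim, f_out=dim, attn_dropout=0.0, diag=True)
out = gat(x, adj)   # with diag=True the input dimension must equal f_out
print(out.shape)    # torch.Size([2, 4, 8]) -- one slice per attention head
```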
140 |
141 | class GraphConvolution(Module):
142 | def __init__(self, in_features, out_features, bias=True):
143 | super(GraphConvolution, self).__init__()
144 | self.in_features = in_features
145 | self.out_features = out_features
146 | self.weight = Parameter(torch.FloatTensor(in_features, out_features))
147 | if bias:
148 | self.bias = Parameter(torch.FloatTensor(out_features))
149 | else:
150 | self.register_parameter("bias", None)
151 | self.reset_parameters()
152 |
153 | def reset_parameters(self):
154 | stdv = 1.0 / math.sqrt(self.weight.size(1))
155 | self.weight.data.uniform_(-stdv, stdv)
156 | if self.bias is not None:
157 | self.bias.data.uniform_(-stdv, stdv)
158 |
159 | def forward(self, input, adj):
160 | support = torch.mm(input, self.weight)
161 | output = torch.spmm(adj, support)
162 | if self.bias is not None:
163 | return output + self.bias
164 | else:
165 | return output
166 |
167 | def __repr__(self):
168 | return (
169 | self.__class__.__name__
170 | + " ("
171 | + str(self.in_features)
172 | + " -> "
173 | + str(self.out_features)
174 | + ")"
175 | )
176 |
177 |
178 | class ProjectionHead(nn.Module):
179 | def __init__(self, in_dim, hidden_dim, out_dim, dropout):
180 | super(ProjectionHead, self).__init__()
181 | self.l1 = nn.Linear(in_dim, hidden_dim, bias=False)
182 | self.l2 = nn.Linear(hidden_dim, out_dim, bias=False)
183 | self.dropout = dropout
184 |
185 | def forward(self, x):
186 | x = self.l1(x)
187 | x = F.relu(x)
188 | x = F.dropout(x, self.dropout, training=self.training)
189 | x = self.l2(x)
190 | return x
191 |
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | from __future__ import absolute_import
5 | from __future__ import unicode_literals
6 | from __future__ import division
7 | from __future__ import print_function
8 |
9 | import os, time, multiprocessing
10 | import math
11 | import random
12 | import numpy as np
13 | import scipy
14 | import scipy.sparse as sp
15 | import torch
16 | import gc
17 | from tqdm import tqdm
18 | import json
19 | from torch.utils.data import Dataset
20 |
21 | import torch.optim as optim
22 |
23 |
24 | def normalize_adj(mx):
25 | rowsum = np.array(mx.sum(1))
26 | r_inv_sqrt = np.power(rowsum, -0.5).flatten()
27 | r_inv_sqrt[np.isinf(r_inv_sqrt)] = 0.0
28 | r_mat_inv_sqrt = sp.diags(r_inv_sqrt)
29 | return mx.dot(r_mat_inv_sqrt).transpose().dot(r_mat_inv_sqrt)
30 |
31 | def normalize_features(mx):
32 | rowsum = np.array(mx.sum(1))
33 | r_inv = np.power(rowsum, -1).flatten()
34 | r_inv[np.isinf(r_inv)] = 0.0
35 | r_mat_inv = sp.diags(r_inv)
36 | mx = r_mat_inv.dot(mx)
37 | return mx
38 |
39 |
40 | def sparse_mx_to_torch_sparse_tensor(sparse_mx):
41 | sparse_mx = sparse_mx.tocoo().astype(np.float32)
42 | indices = torch.from_numpy(
43 | np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)
44 | )
45 | values = torch.FloatTensor(sparse_mx.data)
46 | shape = torch.Size(sparse_mx.shape)
47 | return torch.sparse.FloatTensor(indices, values, shape)
48 |
49 |
50 | def read_raw_data(file_dir, l=[1, 2], reverse=False):
51 | print("loading raw data...")
52 | def read_file(file_paths):
53 | tups = []
54 | for file_path in file_paths:
55 | with open(file_path, "r", encoding="utf-8") as fr:
56 | for line in fr:
57 | params = line.strip("\n").split("\t")
58 | tups.append(tuple([int(x) for x in params]))
59 | return tups
60 |
61 | def read_dict(file_paths):
62 | ent2id_dict = {}
63 | ids = []
64 | for file_path in file_paths:
65 | id = set()
66 | with open(file_path, "r", encoding="utf-8") as fr:
67 | for line in fr:
68 | params = line.strip("\n").split("\t")
69 | ent2id_dict[params[1]] = int(params[0])
70 | id.add(int(params[0]))
71 | ids.append(id)
72 | return ent2id_dict, ids
73 |
74 | ent2id_dict, ids = read_dict([file_dir + "/ent_ids_" + str(i) for i in l])
75 | ills = read_file([file_dir + "/ill_ent_ids"])
76 | triples = read_file([file_dir + "/triples_" + str(i) for i in l])
77 | rel_size = max([t[1] for t in triples]) + 1
78 | reverse_triples = []
79 | r_hs, r_ts = {}, {}
80 | for h, r, t in triples:
81 | if r not in r_hs:
82 | r_hs[r] = set()
83 | if r not in r_ts:
84 | r_ts[r] = set()
85 | r_hs[r].add(h)
86 | r_ts[r].add(t)
87 | if reverse:
88 | reverse_r = r + rel_size
89 | reverse_triples.append((t, reverse_r, h))
90 | if reverse_r not in r_hs:
91 | r_hs[reverse_r] = set()
92 | if reverse_r not in r_ts:
93 | r_ts[reverse_r] = set()
94 | r_hs[reverse_r].add(t)
95 | r_ts[reverse_r].add(h)
96 | if reverse:
97 | triples = triples + reverse_triples
98 | assert len(r_hs) == len(r_ts)
99 | return ent2id_dict, ills, triples, r_hs, r_ts, ids
100 |
101 |
102 | def div_list(ls, n):
103 | ls_len = len(ls)
104 | if n <= 0 or 0 == ls_len:
105 | return []
106 | if n > ls_len:
107 | return []
108 | elif n == ls_len:
109 | return [[i] for i in ls]
110 | else:
111 | j = ls_len // n
112 | k = ls_len % n
113 | ls_return = []
114 | for i in range(0, (n - 1) * j, j):
115 | ls_return.append(ls[i : i + j])
116 | ls_return.append(ls[(n - 1) * j :])
117 | return ls_return
118 |
119 |
120 | def get_adjr(ent_size, triples, norm=False):
121 | print("getting a sparse tensor r_adj...")
122 | M = {}
123 | for tri in triples:
124 | if tri[0] == tri[2]:
125 | continue
126 | if (tri[0], tri[2]) not in M:
127 | M[(tri[0], tri[2])] = 0
128 | M[(tri[0], tri[2])] += 1
129 | ind, val = [], []
130 | for fir, sec in M:
131 | ind.append((fir, sec))
132 |         ind.append((sec, fir))  # add the reverse-direction edge
133 | val.append(M[(fir, sec)])
134 | val.append(M[(fir, sec)])
135 | for i in range(ent_size):
136 | ind.append((i, i))
137 | val.append(1)
138 | if norm:
139 | ind = np.array(ind, dtype=np.int32)
140 | val = np.array(val, dtype=np.float32)
141 | adj = sp.coo_matrix(
142 | (val, (ind[:, 0], ind[:, 1])), shape=(ent_size, ent_size), dtype=np.float32
143 | )
144 | return sparse_mx_to_torch_sparse_tensor(normalize_adj(adj))
145 | else:
146 | M = torch.sparse_coo_tensor(
147 | torch.LongTensor(ind).t(),
148 | torch.FloatTensor(val),
149 | torch.Size([ent_size, ent_size]),
150 | )
151 | return M
152 |
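A quick sketch of the adjacency that `get_adjr` builds for a toy KG (illustrative; it assumes the function above is importable, e.g. `from utils import get_adjr`):

```python
from utils import get_adjr  # or: from src.utils import get_adjr

# Two triples over 3 entities: each (head, tail) pair is added in both
# directions, and every entity gets a self-loop of weight 1.
adj = get_adjr(ent_size=3, triples=[(0, 5, 1), (1, 7, 2)], norm=False)
print(adj.to_dense())
# tensor([[1., 1., 0.],
#         [1., 1., 1.],
#         [0., 1., 1.]])
```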
153 | def pairwise_distances(x, y=None):
154 | x_norm = (x**2).sum(1).view(-1, 1)
155 | if y is not None:
156 | y_norm = (y**2).sum(1).view(1, -1)
157 | else:
158 | y = x
159 | y_norm = x_norm.view(1, -1)
160 | dist = x_norm + y_norm - 2.0 * torch.mm(x, torch.transpose(y, 0, 1))
161 | return torch.clamp(dist, 0.0, np.inf)
162 |
163 |
164 | def multi_cal_rank(task, sim, top_k, l_or_r):
165 | mean = 0
166 | mrr = 0
167 | num = [0 for k in top_k]
168 | for i in range(len(task)):
169 | ref = task[i]
170 | if l_or_r == 0:
171 | rank = (sim[i, :]).argsort()
172 | else:
173 | rank = (sim[:, i]).argsort()
174 | assert ref in rank
175 | rank_index = np.where(rank == ref)[0][0]
176 | mean += rank_index + 1
177 | mrr += 1.0 / (rank_index + 1)
178 | for j in range(len(top_k)):
179 | if rank_index < top_k[j]:
180 | num[j] += 1
181 | return mean, num, mrr
182 |
183 |
184 | def multi_get_hits(Lvec, Rvec, top_k=(1, 5, 10, 50, 100), args=None):
185 | result = []
186 | sim = pairwise_distances(torch.FloatTensor(Lvec), torch.FloatTensor(Rvec)).numpy()
187 | if args.csls is True:
188 | sim = 1 - csls_sim(1 - sim, args.csls_k)
189 | for i in [0, 1]:
190 | top_total = np.array([0] * len(top_k))
191 | mean_total, mrr_total = 0.0, 0.0
192 | s_len = Lvec.shape[0] if i == 0 else Rvec.shape[0]
193 | tasks = div_list(np.array(range(s_len)), 10)
194 | pool = multiprocessing.Pool(processes=len(tasks))
195 | reses = list()
196 | for task in tasks:
197 | if i == 0:
198 | reses.append(
199 | pool.apply_async(multi_cal_rank, (task, sim[task, :], top_k, i))
200 | )
201 | else:
202 | reses.append(
203 | pool.apply_async(multi_cal_rank, (task, sim[:, task], top_k, i))
204 | )
205 | pool.close()
206 | pool.join()
207 | for res in reses:
208 | mean, num, mrr = res.get()
209 | mean_total += mean
210 | mrr_total += mrr
211 | top_total += np.array(num)
212 | acc_total = top_total / s_len
213 | for i in range(len(acc_total)):
214 | acc_total[i] = round(acc_total[i], 4)
215 | mean_total /= s_len
216 | mrr_total /= s_len
217 | result.append(acc_total)
218 | result.append(mean_total)
219 | result.append(mrr_total)
220 | return result
221 |
222 |
223 | def csls_sim(sim_mat, k):
224 | nearest_values1 = torch.mean(torch.topk(sim_mat, k)[0], 1)
225 | nearest_values2 = torch.mean(torch.topk(sim_mat.t(), k)[0], 1)
226 | csls_sim_mat = 2 * sim_mat.t() - nearest_values1
227 | csls_sim_mat = csls_sim_mat.t() - nearest_values2
228 | return csls_sim_mat
229 |
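For reference, `csls_sim` implements the standard Cross-domain Similarity Local Scaling correction: with r_i the mean similarity of source i to its k nearest targets and t_j the analogous value for target j, the returned matrix is CSLS(i, j) = 2 * sim(i, j) - r_i - t_j.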
230 |
231 | def get_topk_indices(M, K=1000):
232 | H, W = M.shape
233 | M_view = M.view(-1)
234 | vals, indices = M_view.topk(K)
235 | print("highest sim:", vals[0].item(), "lowest sim:", vals[-1].item())
236 | two_d_indices = torch.cat(
237 | ((indices // W).unsqueeze(1), (indices % W).unsqueeze(1)), dim=1
238 | )
239 | return two_d_indices
240 |
241 |
242 | def normalize_zero_one(A):
243 | A -= A.min(1, keepdim=True)[0]
244 | A /= A.max(1, keepdim=True)[0]
245 | return A
246 |
247 |
248 | if __name__ == "__main__":
249 | pass
250 |
--------------------------------------------------------------------------------
/src/models.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | from __future__ import absolute_import
5 | from __future__ import unicode_literals
6 | from __future__ import division
7 | from __future__ import print_function
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import random
13 | import numpy as np
14 |
15 |
16 | try:
17 | from layers import *
18 | except ImportError:
19 | from src.layers import *
20 |
21 |
22 | class GAT(nn.Module):
23 | def __init__(
24 | self, n_units, n_heads, dropout, attn_dropout, instance_normalization, diag
25 | ):
26 | super(GAT, self).__init__()
27 | self.num_layer = len(n_units) - 1
28 | self.dropout = dropout
29 | self.inst_norm = instance_normalization
30 | if self.inst_norm:
31 | self.norm = nn.InstanceNorm1d(n_units[0], momentum=0.0, affine=True)
32 | self.layer_stack = nn.ModuleList()
33 | self.diag = diag
34 | for i in range(self.num_layer):
35 | f_in = n_units[i] * n_heads[i - 1] if i else n_units[i]
36 | self.layer_stack.append(
37 | MultiHeadGraphAttention(
38 | n_heads[i],
39 | f_in,
40 | n_units[i + 1],
41 | attn_dropout,
42 | diag,
43 | nn.init.ones_,
44 | False,
45 | )
46 | )
47 |
48 | def forward(self, x, adj):
49 | if self.inst_norm:
50 | x = self.norm(x)
51 | for i, gat_layer in enumerate(self.layer_stack):
52 | if i + 1 < self.num_layer:
53 | x = F.dropout(x, self.dropout, training=self.training)
54 | x = gat_layer(x, adj)
55 | if self.diag:
56 | x = x.mean(dim=0)
57 | if i + 1 < self.num_layer:
58 | if self.diag:
59 | x = F.elu(x)
60 | else:
61 | x = F.elu(x.transpose(0, 1).contiguous().view(adj.size(0), -1))
62 | if not self.diag:
63 | x = x.mean(dim=0)
64 |
65 | return x
66 |
67 |
68 | class GCN(nn.Module):
69 | def __init__(self, nfeat, nhid, nout, dropout):
70 | super(GCN, self).__init__()
71 |
72 | self.gc1 = GraphConvolution(nfeat, nhid)
73 | self.gc2 = GraphConvolution(nhid, nout)
74 | self.dropout = dropout
75 |
76 | def forward(self, x, adj):
77 | x = F.relu(self.gc1(x, adj)) # change to leaky relu
78 | x = F.dropout(x, self.dropout, training=self.training)
79 | x = self.gc2(x, adj)
80 | # x = F.relu(x)
81 | return x
82 |
83 |
84 | def cosine_sim(im, s):
85 | """Cosine similarity between all the image and sentence pairs"""
86 | return im.mm(s.t())
87 |
88 |
89 | def l2norm(X):
90 | """L2-normalize columns of X"""
91 | norm = torch.pow(X, 2).sum(dim=1, keepdim=True).sqrt()
92 | a = norm.expand_as(X) + 1e-8
93 | X = torch.div(X, a)
94 | return X
95 |
96 |
97 | class MultiModalFusion(nn.Module):
98 | def __init__(self, modal_num, with_weight=1):
99 | super().__init__()
100 | self.modal_num = modal_num
101 | self.requires_grad = True if with_weight > 0 else False
102 | self.weight = nn.Parameter(
103 | torch.ones((self.modal_num, 1)), requires_grad=self.requires_grad
104 | )
105 |
106 | def forward(self, embs):
107 | assert len(embs) == self.modal_num
108 | weight_norm = F.softmax(self.weight, dim=0)
109 | embs = [
110 | weight_norm[idx] * F.normalize(embs[idx])
111 | for idx in range(self.modal_num)
112 | if embs[idx] is not None
113 | ]
114 | joint_emb = torch.cat(embs, dim=1)
115 | return joint_emb
116 |
117 |
118 | class MultiModalFusionNew_allmodal(nn.Module):
119 | def __init__(self, modal_num, with_weight=1):
120 | super().__init__()
121 | self.modal_num = modal_num
122 | self.requires_grad = True if with_weight > 0 else False
123 | self.weight = nn.Parameter(
124 | torch.ones((self.modal_num, 1)), requires_grad=self.requires_grad
125 | )
126 |
127 | self.linear_0 = nn.Linear(100, 600)
128 | self.linear_1 = nn.Linear(100, 600)
129 | self.linear_2 = nn.Linear(100, 600)
130 | self.linear_3 = nn.Linear(300, 600)
131 | self.linear = nn.Linear(600, 600)
132 | self.v = nn.Linear(600, 1, bias=False)
133 |
134 | self.LN_pre = nn.LayerNorm(600)
135 | self.LN_pre = nn.LayerNorm(600)
136 |
137 | def forward(self, embs):
138 | assert len(embs) == self.modal_num
139 |
140 | emb_list = []
141 |
142 | if embs[0] is not None:
143 | emb_list.append(self.linear_0(embs[0]).unsqueeze(1))
144 | if embs[1] is not None:
145 | emb_list.append(self.linear_1(embs[1]).unsqueeze(1))
146 | if embs[2] is not None:
147 | emb_list.append(self.linear_2(embs[2]).unsqueeze(1))
148 | if embs[3] is not None:
149 | emb_list.append(self.linear_3(embs[3]).unsqueeze(1))
150 | new_embs = torch.cat(emb_list, dim=1) # [n, 4, e]
151 | new_embs = self.LN_pre(new_embs)
152 | s = self.v(torch.tanh(self.linear(new_embs))) # [n, 4, 1]
153 | a = torch.softmax(s, dim=-1)
154 | joint_emb_1 = torch.matmul(a.transpose(-1, -2), new_embs).squeeze(1) # [n, e]
155 | joint_emb = joint_emb_1
156 | return joint_emb
157 |
158 |
159 | class IBMultiModal(nn.Module):
160 | def __init__(
161 | self,
162 | args,
163 | ent_num,
164 | img_feature_dim,
165 | char_feature_dim=None,
166 | use_project_head=False,
167 | ):
168 | super(IBMultiModal, self).__init__()
169 |
170 | self.args = args
171 | attr_dim = self.args.attr_dim
172 | img_dim = self.args.img_dim
173 | char_dim = self.args.char_dim
174 | dropout = self.args.dropout
175 | self.ENT_NUM = ent_num
176 | self.use_project_head = use_project_head
177 |
178 | self.n_units = [int(x) for x in self.args.hidden_units.strip().split(",")]
179 | self.n_heads = [int(x) for x in self.args.heads.strip().split(",")]
180 | self.input_dim = int(self.args.hidden_units.strip().split(",")[0])
181 | self.entity_emb = nn.Embedding(self.ENT_NUM, self.input_dim)
182 | nn.init.normal_(self.entity_emb.weight, std=1.0 / math.sqrt(self.ENT_NUM))
183 | self.entity_emb.requires_grad = True
184 | self.rel_fc = nn.Linear(1000, attr_dim)
185 | self.rel_fc_mu = nn.Linear(attr_dim, attr_dim)
186 | self.rel_fc_std = nn.Linear(attr_dim, attr_dim)
187 | self.rel_fc_d1 = nn.Linear(attr_dim, attr_dim)
188 | self.rel_fc_d2 = nn.Linear(attr_dim, attr_dim)
189 |
190 | self.att_fc = nn.Linear(1000, attr_dim)
191 | self.att_fc_mu = nn.Linear(attr_dim, attr_dim)
192 | self.att_fc_std = nn.Linear(attr_dim, attr_dim)
193 | self.att_fc_d1 = nn.Linear(attr_dim, attr_dim)
194 | self.att_fc_d2 = nn.Linear(attr_dim, attr_dim)
195 |
196 | self.img_fc = nn.Linear(img_feature_dim, img_dim)
197 | self.img_fc_mu = nn.Linear(img_dim, img_dim)
198 | self.img_fc_std = nn.Linear(img_dim, img_dim)
199 | self.img_fc_d1 = nn.Linear(img_dim, img_dim)
200 | self.img_fc_d2 = nn.Linear(img_dim, img_dim)
201 |
202 | joint_dim = 600
203 | self.joint_fc = nn.Linear(joint_dim, joint_dim)
204 | self.joint_fc_mu = nn.Linear(joint_dim, joint_dim)
205 | self.joint_fc_std = nn.Linear(joint_dim, joint_dim)
206 | self.joint_fc_d1 = nn.Linear(joint_dim, joint_dim)
207 | self.joint_fc_d2 = nn.Linear(joint_dim, joint_dim)
208 |
209 | use_graph_vib = self.args.use_graph_vib
210 | use_attr_vib = self.args.use_attr_vib
211 | use_img_vib = self.args.use_img_vib
212 | use_rel_vib = self.args.use_rel_vib
213 |
214 | no_diag = self.args.no_diag
215 |
216 | if no_diag:
217 | diag = False
218 | else:
219 | diag = True
220 |
221 | self.use_graph_vib = use_graph_vib
222 | self.use_attr_vib = use_attr_vib
223 | self.use_img_vib = use_img_vib
224 | self.use_rel_vib = use_rel_vib
225 | self.use_joint_vib = self.args.use_joint_vib
226 |
227 | self.name_fc = nn.Linear(300, char_dim)
228 | self.char_fc = nn.Linear(char_feature_dim, char_dim)
229 |
230 | self.kld_loss = 0
231 | self.gph_layer_norm_mu = nn.LayerNorm(self.input_dim, elementwise_affine=True)
232 | self.gph_layer_norm_std = nn.LayerNorm(self.input_dim, elementwise_affine=True)
233 | if self.args.structure_encoder == "gcn":
234 | if self.use_graph_vib:
235 | self.cross_graph_model_mu = GCN(
236 | self.n_units[0],
237 | self.n_units[1],
238 | self.n_units[2],
239 | dropout=self.args.dropout,
240 | )
241 | self.cross_graph_model_std = GCN(
242 | self.n_units[0],
243 | self.n_units[1],
244 | self.n_units[2],
245 | dropout=self.args.dropout,
246 | )
247 | else:
248 | self.cross_graph_model = GCN(
249 | self.n_units[0],
250 | self.n_units[1],
251 | self.n_units[2],
252 | dropout=self.args.dropout,
253 | )
254 | elif self.args.structure_encoder == "gat":
255 | if self.use_graph_vib:
256 | self.cross_graph_model_mu = GAT(
257 | n_units=self.n_units,
258 | n_heads=self.n_heads,
259 | dropout=args.dropout,
260 | attn_dropout=args.attn_dropout,
261 | instance_normalization=self.args.instance_normalization,
262 | diag=diag,
263 | )
264 | self.cross_graph_model_std = GAT(
265 | n_units=self.n_units,
266 | n_heads=self.n_heads,
267 | dropout=args.dropout,
268 | attn_dropout=args.attn_dropout,
269 | instance_normalization=self.args.instance_normalization,
270 | diag=diag,
271 | )
272 | else:
273 | self.cross_graph_model = GAT(
274 | n_units=self.n_units,
275 | n_heads=self.n_heads,
276 | dropout=args.dropout,
277 | attn_dropout=args.attn_dropout,
278 | instance_normalization=self.args.instance_normalization,
279 | diag=True,
280 | )
281 |
282 | if self.use_project_head:
283 | self.img_pro = ProjectionHead(img_dim, img_dim, img_dim, dropout)
284 | self.att_pro = ProjectionHead(attr_dim, attr_dim, attr_dim, dropout)
285 | self.rel_pro = ProjectionHead(attr_dim, attr_dim, attr_dim, dropout)
286 | self.gph_pro = ProjectionHead(
287 | self.n_units[2], self.n_units[2], self.n_units[2], dropout
288 | )
289 |
290 | if self.args.fusion_id == 1:
291 | self.fusion = MultiModalFusion(
292 | modal_num=self.args.inner_view_num, with_weight=self.args.with_weight
293 | )
294 | elif self.args.fusion_id == 2:
295 | self.fusion = MultiModalFusionNew_allmodal(
296 | modal_num=self.args.inner_view_num, with_weight=self.args.with_weight
297 | )
298 |
299 | def _kld_gauss(self, mu_1, logsigma_1, mu_2, logsigma_2):
300 | from torch.distributions.kl import kl_divergence
301 | from torch.distributions import Normal
302 |
303 | sigma_1 = torch.exp(0.1 + 0.9 * F.softplus(torch.clamp_max(logsigma_1, 0.4)))
304 | sigma_2 = torch.exp(0.1 + 0.9 * F.softplus(torch.clamp_max(logsigma_2, 0.4)))
305 | mu_1_fixed = mu_1.clone()
306 | sigma_1_fixed = sigma_1.clone()
307 | mu_1_fixed[torch.isnan(mu_1_fixed)] = 0
308 | mu_1_fixed[torch.isinf(mu_1_fixed)] = torch.max(
309 | mu_1_fixed[~torch.isinf(mu_1_fixed)]
310 | )
311 | sigma_1_fixed[torch.isnan(sigma_1_fixed)] = 1
312 | sigma_1_fixed[torch.isinf(sigma_1_fixed)] = torch.max(
313 | sigma_1_fixed[~torch.isinf(sigma_1_fixed)]
314 | )
315 | sigma_1_fixed[sigma_1_fixed <= 0] = 1
316 | q_target = Normal(mu_1_fixed, sigma_1_fixed)
317 |
318 | mu_2_fixed = mu_2.clone()
319 | sigma_2_fixed = sigma_2.clone()
320 | mu_2_fixed[torch.isnan(mu_2_fixed)] = 0
321 | mu_2_fixed[torch.isinf(mu_2_fixed)] = torch.max(
322 | mu_2_fixed[~torch.isinf(mu_2_fixed)]
323 | )
324 | sigma_2_fixed[torch.isnan(sigma_2_fixed)] = 1
325 | sigma_2_fixed[torch.isinf(sigma_2_fixed)] = torch.max(
326 | sigma_2_fixed[~torch.isinf(sigma_2_fixed)]
327 | )
328 | sigma_2_fixed[sigma_2_fixed <= 0] = 1
329 | q_context = Normal(mu_2_fixed, sigma_2_fixed)
330 | kl = kl_divergence(q_target, q_context).mean(dim=0).sum()
331 | return kl
332 |
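For reference, the `kl_divergence` call above evaluates the closed form for diagonal Gaussians, per dimension: KL(N(mu_1, sigma_1^2) || N(mu_2, sigma_2^2)) = log(sigma_2 / sigma_1) + (sigma_1^2 + (mu_1 - mu_2)^2) / (2 * sigma_2^2) - 1/2; the code then averages over the batch dimension and sums over the feature dimensions.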
333 | def forward(
334 | self,
335 | input_idx,
336 | adj,
337 | img_features=None,
338 | rel_features=None,
339 | att_features=None,
340 | name_features=None,
341 | char_features=None,
342 | ):
343 |
344 | if self.args.w_gcn:
345 | if self.use_graph_vib:
346 | if self.args.structure_encoder == "gat":
347 | gph_emb_mu = self.cross_graph_model_mu(
348 | self.entity_emb(input_idx), adj
349 | )
350 | mu = self.gph_layer_norm_mu(gph_emb_mu)
351 |
352 | gph_emb_std = self.cross_graph_model_std(
353 | self.entity_emb(input_idx), adj
354 | )
355 | std = F.elu(gph_emb_std)
356 | eps = torch.randn_like(std)
357 | gph_emb = mu + eps * std
358 | gph_kld_loss = self._kld_gauss(
359 | mu, std, torch.zeros_like(mu), torch.ones_like(std)
360 | )
361 | self.kld_loss = gph_kld_loss
362 | else:
363 | mu = self.cross_graph_model_mu(self.entity_emb(input_idx), adj)
364 |                     logstd = self.cross_graph_model_std(self.entity_emb(input_idx), adj)
365 | eps = torch.randn_like(mu)
366 | gph_emb = mu + eps * torch.exp(logstd)
367 | gph_kld_loss = self._kld_gauss(
368 | mu, logstd, torch.zeros_like(mu), torch.ones_like(logstd)
369 | )
370 | self.kld_loss = gph_kld_loss
371 | else:
372 | gph_emb = self.cross_graph_model(self.entity_emb(input_idx), adj)
373 | else:
374 | gph_emb = None
375 | if self.args.w_img:
376 | if self.use_img_vib:
377 | img_emb = self.img_fc(img_features)
378 | img_emb_h = F.relu(img_emb)
379 | mu = self.img_fc_mu(img_emb_h)
380 | logvar = self.img_fc_std(img_emb_h)
381 | std = torch.exp(0.5 * logvar)
382 | eps = torch.rand_like(std)
383 | img_emb = mu + eps * std
384 | img_kld_loss = self._kld_gauss(
385 | mu, std, torch.zeros_like(mu), torch.ones_like(std)
386 | )
387 | self.img_kld_loss = img_kld_loss
388 | else:
389 | img_emb = self.img_fc(img_features)
390 | else:
391 | img_emb = None
392 | if self.args.w_rel:
393 | if self.use_rel_vib:
394 | rel_emb = self.rel_fc(rel_features)
395 | rel_emb_h = F.relu(rel_emb)
396 | mu = self.rel_fc_mu(rel_emb_h)
397 | logvar = self.rel_fc_std(rel_emb_h)
398 | std = torch.exp(0.5 * logvar)
399 | eps = torch.rand_like(std)
400 | rel_emb = mu + eps * std
401 | rel_kld_loss = self._kld_gauss(
402 | mu, std, torch.zeros_like(mu), torch.ones_like(std)
403 | )
404 | self.rel_kld_loss = rel_kld_loss
405 | else:
406 | rel_emb = self.rel_fc(rel_features)
407 | else:
408 | rel_emb = None
409 | if self.args.w_attr:
410 | if self.use_attr_vib:
411 | att_emb = self.att_fc(att_features)
412 | att_emb_h = F.relu(att_emb)
413 | mu = self.att_fc_mu(att_emb_h)
414 | logvar = self.att_fc_std(att_emb_h)
415 | std = torch.exp(0.5 * logvar)
416 | eps = torch.rand_like(std)
417 | att_emb = mu + eps * std
418 | attr_kld_loss = self._kld_gauss(
419 | mu, std, torch.zeros_like(mu), torch.ones_like(std)
420 | )
421 | self.attr_kld_loss = attr_kld_loss
422 | else:
423 | att_emb = self.att_fc(att_features)
424 | else:
425 | att_emb = None
426 |
427 | if self.args.w_name:
428 | name_emb = self.name_fc(name_features)
429 | else:
430 | name_emb = None
431 | if self.args.w_char:
432 | char_emb = self.char_fc(char_features)
433 | else:
434 | char_emb = None
435 |
436 | if self.use_project_head:
437 | gph_emb = self.gph_pro(gph_emb)
438 | img_emb = self.img_pro(img_emb)
439 | rel_emb = self.rel_pro(rel_emb)
440 | att_emb = self.att_pro(att_emb)
441 | pass
442 |
443 | joint_emb = self.fusion(
444 | [img_emb, att_emb, rel_emb, gph_emb, name_emb, char_emb]
445 | )
446 |
447 | if self.use_joint_vib:
448 | joint_emb = self.joint_fc(joint_emb)
449 | joint_emb_h = F.relu(joint_emb)
450 | mu = self.joint_fc_mu(joint_emb_h)
451 | logvar = self.joint_fc_std(joint_emb_h)
452 | std = torch.exp(0.5 * logvar)
453 | eps = torch.rand_like(std)
454 | joint_emb = mu + eps * std
455 | joint_kld_loss = self._kld_gauss(
456 | mu, std, torch.zeros_like(mu), torch.ones_like(std)
457 | )
458 | self.joint_kld_loss = joint_kld_loss
459 |
460 | return gph_emb, img_emb, rel_emb, att_emb, name_emb, char_emb, joint_emb
461 |
--------------------------------------------------------------------------------
/src/run.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | from __future__ import absolute_import
4 | from __future__ import unicode_literals
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import argparse
9 | from pprint import pprint
10 | from transformers import (
11 | get_cosine_schedule_with_warmup,
12 | )
13 | import torch.optim as optim
14 |
15 | try:
16 | from utils import *
17 | from models import *
18 | from Load import *
19 | from loss import *
20 | except ImportError:
21 | from src.utils import *
22 | from src.models import *
23 | from src.Load import *
24 | from src.loss import *
25 |
26 |
27 | def load_img_features(ent_num, file_dir, triples, use_mean_img=False):
28 | if "V1" in file_dir:
29 | split = "norm"
30 | img_vec_path = "data/pkls/dbpedia_wikidata_15k_norm_GA_id_img_feature_dict.pkl"
31 | elif "V2" in file_dir:
32 | split = "dense"
33 | img_vec_path = "data/pkls/dbpedia_wikidata_15k_dense_GA_id_img_feature_dict.pkl"
34 | elif "FB15K" in file_dir:
35 | filename = os.path.split(file_dir)[-1].upper()
36 | img_vec_path = (
37 | "data/mmkb-datasets/"
38 | + filename
39 | + "/"
40 | + filename
41 | + "_id_img_feature_dict.pkl"
42 | )
43 | else:
44 | split = file_dir.split("/")[-1]
45 | img_vec_path = "data/pkls/" + split + "_GA_id_img_feature_dict.pkl"
46 | if use_mean_img:
47 | img_features = load_img(ent_num, img_vec_path)
48 | else:
49 | img_features = load_img_new(ent_num, img_vec_path, triples)
50 | return img_features
51 |
52 |
53 | def load_img_features_dropout(
54 | ent_num, file_dir, triples, use_mean_img=False, img_dp_ratio=1.0
55 | ):
56 | if "FB15K" in file_dir:
57 | filename = os.path.split(file_dir)[-1].upper()
58 | if abs(1.0 - img_dp_ratio) > 1e-3:
59 | img_vec_path = (
60 | "data/mmkb-datasets/"
61 | + "mmkb_dropout"
62 | + "/"
63 | + filename
64 | + "_id_img_feature_dict_with_dropout{0}.pkl".format(img_dp_ratio)
65 | )
66 | print("dropout img_vec_path: ", img_vec_path)
67 | else:
68 | img_vec_path = (
69 | "data/mmkb-datasets/"
70 | + filename
71 | + "/"
72 | + filename
73 | + "_id_img_feature_dict.pkl"
74 | )
75 | else:
76 | split = file_dir.split("/")[-1]
77 | if abs(1.0 - img_dp_ratio) > 1e-3:
78 | img_vec_path = (
79 | "data/pkls/dbp_dropout/"
80 | + split
81 | + "_GA_id_img_feature_dict_with_dropout{0}.pkl".format(img_dp_ratio)
82 | )
83 | else:
84 | img_vec_path = "data/pkls/" + split + "_GA_id_img_feature_dict.pkl"
85 | if use_mean_img:
86 | img_features = load_img(ent_num, img_vec_path)
87 | else:
88 | img_features = load_img_new(ent_num, img_vec_path, triples)
89 | return img_features
90 |
91 |
92 | class IBMEA:
93 |
94 | def __init__(self):
95 |
96 | self.ent2id_dict = None
97 | self.ills = None
98 | self.triples = None
99 | self.r_hs = None
100 | self.r_ts = None
101 | self.ids = None
102 | self.left_ents = None
103 | self.right_ents = None
104 |
105 | self.img_features = None
106 | self.rel_features = None
107 | self.att_features = None
108 | self.char_features = None
109 | self.name_features = None
110 | self.ent_vec = None
111 | self.left_non_train = None
112 | self.right_non_train = None
113 | self.ENT_NUM = None
114 | self.REL_NUM = None
115 | self.adj = None
116 | self.train_ill = None
117 | self.test_ill_ = None
118 | self.test_ill = None
119 | self.test_left = None
120 | self.test_right = None
121 | self.multimodal_encoder = None
122 | self.weight_raw = None
123 | self.rel_fc = None
124 | self.att_fc = None
125 | self.img_fc = None
126 | self.char_fc = None
127 | self.shared_fc = None
128 | self.gcn_pro = None
129 | self.rel_pro = None
130 | self.attr_pro = None
131 | self.img_pro = None
132 | self.input_dim = None
133 | self.entity_emb = None
134 | self.input_idx = None
135 | self.n_units = None
136 | self.n_heads = None
137 | self.cross_graph_model = None
138 | self.params = None
139 | self.optimizer = None
140 | self.fusion = None
141 | self.parser = argparse.ArgumentParser()
142 | self.args = self.parse_options(self.parser)
143 | self.set_seed(self.args.seed, self.args.cuda)
144 | self.device = torch.device(
145 | "cuda" if self.args.cuda and torch.cuda.is_available() else "cpu"
146 | )
147 | self.init_data()
148 | self.init_model()
149 | self.print_summary()
150 | self.best_hit_1 = 0.0
151 | self.best_epoch = 0
152 | self.best_data_list = []
153 | self.best_to_write = []
154 |
155 | @staticmethod
156 | def parse_options(parser):
157 | parser.add_argument(
158 | "--file_dir",
159 | type=str,
160 | default="data/DBP15K/zh_en",
161 | required=False,
162 | help="input dataset file directory, ('data/DBP15K/zh_en', 'data/DWY100K/dbp_wd')",
163 | )
164 | parser.add_argument("--rate", type=float, default=0.3, help="training set rate")
165 |
166 | parser.add_argument(
167 | "--cuda",
168 | action="store_true",
169 | default=True,
170 | help="whether to use cuda or not",
171 | )
172 | parser.add_argument("--seed", type=int, default=2021, help="random seed")
173 | parser.add_argument(
174 | "--epochs", type=int, default=1000, help="number of epochs to train"
175 | )
176 | parser.add_argument("--check_point", type=int, default=100, help="check point")
177 | parser.add_argument(
178 | "--hidden_units",
179 | type=str,
180 | default="128,128,128",
181 | help="hidden units in each hidden layer(including in_dim and out_dim), splitted with comma",
182 | )
183 | parser.add_argument(
184 | "--heads",
185 | type=str,
186 | default="2,2",
187 | help="heads in each gat layer, splitted with comma",
188 | )
189 | parser.add_argument(
190 | "--instance_normalization",
191 | action="store_true",
192 | default=False,
193 | help="enable instance normalization",
194 | )
195 | parser.add_argument(
196 | "--lr", type=float, default=0.005, help="initial learning rate"
197 | )
198 | parser.add_argument(
199 | "--weight_decay",
200 | type=float,
201 | default=1e-2,
202 | help="weight decay (L2 loss on parameters)",
203 | )
204 | parser.add_argument(
205 | "--dropout", type=float, default=0.0, help="dropout rate for layers"
206 | )
207 | parser.add_argument(
208 | "--attn_dropout",
209 | type=float,
210 | default=0.0,
211 | help="dropout rate for gat layers",
212 | )
213 | parser.add_argument(
214 | "--dist", type=int, default=2, help="L1 distance or L2 distance. ('1', '2')"
215 | )
216 | parser.add_argument(
217 | "--csls", action="store_true", default=False, help="use CSLS for inference"
218 | )
219 | parser.add_argument("--csls_k", type=int, default=10, help="top k for csls")
220 | parser.add_argument(
221 | "--il", action="store_true", default=False, help="Iterative learning?"
222 | )
223 | parser.add_argument(
224 | "--semi_learn_step",
225 | type=int,
226 | default=10,
227 | help="If IL, what's the update step?",
228 | )
229 | parser.add_argument(
230 | "--il_start", type=int, default=500, help="If Il, when to start?"
231 | )
232 | parser.add_argument("--bsize", type=int, default=7500, help="batch size")
233 | parser.add_argument(
234 | "--alpha", type=float, default=0.2, help="the margin of InfoMaxNCE loss"
235 | )
236 | parser.add_argument(
237 | "--with_weight",
238 | type=int,
239 | default=1,
240 | help="Whether to weight the fusion of different " "modal features",
241 | )
242 | parser.add_argument(
243 | "--structure_encoder",
244 | type=str,
245 | default="gat",
246 | help="the encoder of structure view, " "[gcn|gat]",
247 | )
248 |
249 | parser.add_argument(
250 | "--projection",
251 | action="store_true",
252 | default=False,
253 | help="add projection for model",
254 | )
255 |
256 | parser.add_argument(
257 | "--attr_dim",
258 | type=int,
259 | default=100,
260 | help="the hidden size of attr and rel features",
261 | )
262 | parser.add_argument(
263 | "--img_dim", type=int, default=100, help="the hidden size of img feature"
264 | )
265 | parser.add_argument(
266 | "--name_dim", type=int, default=100, help="the hidden size of name feature"
267 | )
268 | parser.add_argument(
269 | "--char_dim", type=int, default=100, help="the hidden size of char feature"
270 | )
271 |
272 | parser.add_argument(
273 | "--w_gcn", action="store_false", default=True, help="with gcn features"
274 | )
275 | parser.add_argument(
276 | "--w_rel", action="store_false", default=True, help="with rel features"
277 | )
278 | parser.add_argument(
279 | "--w_attr", action="store_false", default=True, help="with attr features"
280 | )
281 | parser.add_argument(
282 | "--w_name", action="store_false", default=True, help="with name features"
283 | )
284 | parser.add_argument(
285 | "--w_char", action="store_false", default=True, help="with char features"
286 | )
287 | parser.add_argument(
288 | "--w_img", action="store_false", default=True, help="with img features"
289 | )
290 |
291 | parser.add_argument(
292 | "--no_diag", action="store_true", default=False, help="GAT use diag"
293 | )
294 | parser.add_argument(
295 | "--use_joint_vib", action="store_true", default=False, help="use_joint_vib"
296 | )
297 |
298 | parser.add_argument(
299 | "--use_graph_vib", action="store_false", default=True, help="use_graph_vib"
300 | )
301 | parser.add_argument(
302 | "--use_attr_vib", action="store_false", default=True, help="use_attr_vib"
303 | )
304 | parser.add_argument(
305 | "--use_img_vib", action="store_false", default=True, help="use_img_vib"
306 | )
307 | parser.add_argument(
308 | "--use_rel_vib", action="store_false", default=True, help="use_rel_vib"
309 | )
310 | parser.add_argument(
311 | "--modify_ms", action="store_true", default=False, help="modify_ms"
312 | )
313 | parser.add_argument("--ms_alpha", type=float, default=0.1, help="ms scale_pos")
314 | parser.add_argument("--ms_beta", type=float, default=40.0, help="ms scale_neg")
315 | parser.add_argument("--ms_base", type=float, default=0.5, help="ms base")
316 |
317 |     parser.add_argument("--Beta_g", type=float, default=0.001, help="graph KLD beta")
318 |     parser.add_argument("--Beta_i", type=float, default=0.001, help="image KLD beta")
319 |     parser.add_argument("--Beta_r", type=float, default=0.01, help="relation KLD beta")
320 |     parser.add_argument("--Beta_a", type=float, default=0.001, help="attribute KLD beta")
321 | parser.add_argument(
322 | "--inner_view_num", type=int, default=6, help="the number of inner view"
323 | )
324 |
325 | parser.add_argument(
326 | "--word_embedding",
327 | type=str,
328 | default="glove",
329 |         help="the type of word embedding, [glove|fasttext]",
330 | )
331 | parser.add_argument(
332 | "--use_project_head",
333 | action="store_true",
334 | default=False,
335 | help="use projection head",
336 | )
337 |
338 | parser.add_argument(
339 | "--zoom", type=float, default=0.1, help="narrow the range of losses"
340 | )
341 | parser.add_argument(
342 | "--save_path", type=str, default="save_pkl", help="save path"
343 | )
344 | parser.add_argument(
345 | "--pred_name", type=str, default="pred.txt", help="pred name"
346 | )
347 |
348 | parser.add_argument(
349 | "--hidden_size", type=int, default=300, help="the hidden size of MEAformer"
350 | )
351 | parser.add_argument(
352 | "--intermediate_size",
353 | type=int,
354 | default=400,
355 |         help="the intermediate (feed-forward) size of MEAformer",
356 | )
357 | parser.add_argument(
358 | "--num_attention_heads",
359 | type=int,
360 | default=1,
361 | help="the number of attention_heads of MEAformer",
362 | )
363 | parser.add_argument(
364 | "--num_hidden_layers",
365 | type=int,
366 | default=1,
367 | help="the number of hidden_layers of MEAformer",
368 | )
369 | parser.add_argument("--position_embedding_type", default="absolute", type=str)
370 | parser.add_argument(
371 | "--use_intermediate",
372 | type=int,
373 | default=0,
374 | help="whether to use_intermediate",
375 | )
376 | parser.add_argument(
377 | "--replay", type=int, default=0, help="whether to use replay strategy"
378 | )
379 | parser.add_argument(
380 | "--neg_cross_kg",
381 | type=int,
382 | default=0,
383 | help="whether to force the negative samples in the opposite KG",
384 | )
385 |
386 | parser.add_argument(
387 | "--tau",
388 | type=float,
389 | default=0.1,
390 | help="the temperature factor of contrastive loss",
391 | )
392 | parser.add_argument(
393 | "--ab_weight", type=float, default=0.5, help="the weight of NTXent Loss"
394 | )
395 | parser.add_argument(
396 | "--use_icl", action="store_true", default=False, help="use_icl"
397 | )
398 | parser.add_argument(
399 | "--use_bce", action="store_true", default=False, help="use_bce"
400 | )
401 | parser.add_argument(
402 | "--use_bce_new", action="store_true", default=False, help="use_bce_new"
403 | )
404 | parser.add_argument(
405 |         "--use_ms", action="store_true", default=False, help="use multi-similarity (MS) loss"
406 | )
407 | parser.add_argument(
408 |         "--use_nt", action="store_true", default=False, help="use NT-Xent loss"
409 | )
410 | parser.add_argument(
411 |         "--use_nce", action="store_true", default=False, help="use InfoNCE loss"
412 | )
413 | parser.add_argument(
414 | "--use_nce_new", action="store_true", default=False, help="use_nce_new"
415 | )
416 |
417 | parser.add_argument(
418 |         "--use_sheduler", action="store_true", default=False, help="use an exponential LR scheduler"
419 | )
420 | parser.add_argument(
421 |         "--sheduler_gamma", type=float, default=0.98, help="gamma of the exponential LR scheduler"
422 | )
423 |
424 |     parser.add_argument("--joint_beta", type=float, default=1.0, help="weight of the joint MS loss")
425 |
426 | parser.add_argument(
427 | "--use_sheduler_cos",
428 | action="store_true",
429 | default=False,
430 |         help="use a cosine LR schedule with warmup",
431 | )
432 | parser.add_argument(
433 |         "--num_warmup_steps", type=int, default=200, help="warmup steps for the cosine schedule"
434 | )
435 | parser.add_argument(
436 |         "--num_training_steps", type=int, default=1000, help="total steps for the cosine schedule"
437 | )
438 |
439 | parser.add_argument(
440 | "--use_mean_img", action="store_true", default=False, help="use_mean_img"
441 | )
442 | parser.add_argument("--awloss", type=int, default=0, help="whether to use awl")
443 |     parser.add_argument("--mtloss", type=int, default=0, help="whether to use mtl")
444 |
445 | parser.add_argument(
446 | "--graph_use_icl", type=int, default=1, help="graph_use_icl"
447 | )
448 | parser.add_argument(
449 | "--graph_use_bce", type=int, default=1, help="graph_use_bce"
450 | )
451 | parser.add_argument("--graph_use_ms", type=int, default=1, help="graph_use_ms")
452 |
453 | parser.add_argument("--img_use_icl", type=int, default=1, help="img_use_icl")
454 | parser.add_argument("--img_use_bce", type=int, default=1, help="img_use_bce")
455 | parser.add_argument("--img_use_ms", type=int, default=1, help="img_use_ms")
456 |
457 | parser.add_argument("--attr_use_icl", type=int, default=1, help="attr_use_icl")
458 | parser.add_argument("--attr_use_bce", type=int, default=1, help="attr_use_bce")
459 | parser.add_argument("--attr_use_ms", type=int, default=1, help="attr_use_ms")
460 |
461 | parser.add_argument("--rel_use_icl", type=int, default=1, help="rel_use_icl")
462 | parser.add_argument("--rel_use_bce", type=int, default=1, help="rel_use_bce")
463 | parser.add_argument("--rel_use_ms", type=int, default=1, help="rel_use_ms")
464 |
465 | parser.add_argument(
466 | "--joint_use_icl", action="store_true", default=False, help="joint_use_icl"
467 | )
468 | parser.add_argument(
469 | "--joint_use_bce", action="store_true", default=False, help="joint_use_bce"
470 | )
471 | parser.add_argument(
472 | "--joint_use_ms", action="store_true", default=False, help="joint_use_ms"
473 | )
474 |
475 | parser.add_argument(
476 | "--img_dp_ratio", type=float, default=1.0, help="image dropout ratio"
477 | )
478 | parser.add_argument(
479 | "--fusion_id", type=int, default=1, help="default weight fusion"
480 | )
481 | return parser.parse_args()
482 |
483 | @staticmethod
484 | def set_seed(seed, cuda=True):
485 | random.seed(seed)
486 | np.random.seed(seed)
487 | torch.manual_seed(seed)
488 | if cuda and torch.cuda.is_available():
489 | torch.cuda.manual_seed(seed)
490 |
491 | def print_summary(self):
492 | print("-----dataset summary-----")
493 | print("dataset:\t", self.args.file_dir)
494 | print("triple num:\t", len(self.triples))
495 | print("entity num:\t", self.ENT_NUM)
496 | print("relation num:\t", self.REL_NUM)
497 | print(
498 | "train ill num:\t",
499 | self.train_ill.shape[0],
500 | "\ttest ill num:\t",
501 | self.test_ill.shape[0],
502 | )
503 | print("-------------------------")
504 |
505 | def init_data(self):
506 | # Load data
507 | lang_list = [1, 2]
508 | file_dir = self.args.file_dir
509 | device = self.device
510 |
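    |         # read_raw_data (from Load.py) is assumed to return the entity->id map, the
    |         # seed alignment pairs (ills), all triples, per-relation head/tail sets, and
    |         # the entity id lists of the two KGs selected by lang_list.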
511 | self.ent2id_dict, self.ills, self.triples, self.r_hs, self.r_ts, self.ids = (
512 | read_raw_data(file_dir, lang_list)
513 | )
514 | e1 = os.path.join(file_dir, "ent_ids_1")
515 | e2 = os.path.join(file_dir, "ent_ids_2")
516 | self.left_ents = get_ids(e1)
517 | self.right_ents = get_ids(e2)
518 |
519 | self.ENT_NUM = len(self.ent2id_dict)
520 | self.REL_NUM = len(self.r_hs)
521 | print("total ent num: {}, rel num: {}".format(self.ENT_NUM, self.REL_NUM))
522 |
523 | np.random.shuffle(self.ills)
524 |
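    |         # Visual features: when --img_dp_ratio differs from 1.0, a dropout variant of
    |         # the loader is used (assumed to drop/replace image features for a fraction of
    |         # entities); otherwise all pre-extracted image features are loaded as-is.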
525 | if abs(1.0 - self.args.img_dp_ratio) > 1e-3:
526 | self.img_features = load_img_features_dropout(
527 | self.ENT_NUM,
528 | file_dir,
529 | self.triples,
530 | self.args.use_mean_img,
531 | self.args.img_dp_ratio,
532 | )
533 | else:
534 | self.img_features = load_img_features(
535 | self.ENT_NUM, file_dir, self.triples, self.args.use_mean_img
536 | )
537 | self.img_features = F.normalize(torch.Tensor(self.img_features).to(device))
538 | print("image feature shape:", self.img_features.shape)
539 | data_dir, dataname = os.path.split(file_dir)
540 | if self.args.word_embedding == "glove":
541 | word2vec_path = "data/embedding/glove.6B.300d.txt"
542 | elif self.args.word_embedding == "fasttext":
543 | pass
544 | else:
545 | raise Exception("error word embedding")
546 |
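    |         # For DBP15K, name/char features are expected to be available whenever
    |         # --w_name / --w_char are set; these prints assume they were loaded earlier.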
547 | if "DBP15K" in file_dir:
548 | if self.args.w_name or self.args.w_char:
549 | print("name feature shape:", self.name_features.shape)
550 | print("char feature shape:", self.char_features.shape)
551 | pass
552 |
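    |         # Split the shuffled seed alignments into train/test according to --rate
    |         # (e.g. rate=0.2 keeps 20% of the pairs as supervision).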
553 |         self.train_ill = np.array(
554 |             self.ills[: int(len(self.ills) * self.args.rate)], dtype=np.int32
555 |         )
556 |         self.test_ill_ = self.ills[int(len(self.ills) * self.args.rate) :]
557 | self.test_ill = np.array(self.test_ill_, dtype=np.int32)
558 |
559 | self.test_left = torch.LongTensor(self.test_ill[:, 0].squeeze()).to(device)
560 | self.test_right = torch.LongTensor(self.test_ill[:, 1].squeeze()).to(device)
561 |
562 | self.left_non_train = list(
563 | set(self.left_ents) - set(self.train_ill[:, 0].tolist())
564 | )
565 | self.right_non_train = list(
566 | set(self.right_ents) - set(self.train_ill[:, 1].tolist())
567 | )
568 |
569 | print(
570 | "#left entity : %d, #right entity: %d"
571 | % (len(self.left_ents), len(self.right_ents))
572 | )
573 | print(
574 | "#left entity not in train set: %d, #right entity not in train set: %d"
575 | % (len(self.left_non_train), len(self.right_non_train))
576 | )
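    |         # Relation / attribute features: 1000-dimensional sparse count vectors built
    |         # from the triples and the training_attrs_* files (load_relation / load_attr
    |         # are assumed to cap the vocabulary at the value passed, here 1000).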
577 | self.rel_features = load_relation(self.ENT_NUM, self.triples, 1000)
578 | self.rel_features = torch.Tensor(self.rel_features).to(device)
579 | print("relation feature shape:", self.rel_features.shape)
580 | a1 = os.path.join(file_dir, "training_attrs_1")
581 | a2 = os.path.join(file_dir, "training_attrs_2")
582 | self.att_features = load_attr(
583 | [a1, a2], self.ENT_NUM, self.ent2id_dict, 1000
584 | )
585 | self.att_features = torch.Tensor(self.att_features).to(device)
586 | print("attribute feature shape:", self.att_features.shape)
587 |
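    |         # get_adjr is assumed to build a normalized sparse adjacency matrix over the
    |         # merged graph of both KGs, consumed by the GCN/GAT structure encoder.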
588 | self.adj = get_adjr(
589 | self.ENT_NUM, self.triples, norm=True
590 | )
591 | self.adj = self.adj.to(self.device)
592 |
593 | def init_model(self):
594 | img_dim = self.img_features.shape[1]
595 | char_dim = (
596 | self.char_features.shape[1] if self.char_features is not None else 100
597 | )
598 |
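    |         # IBMultiModal (models.py) is the multi-modal encoder; train() below reads its
    |         # per-modality KLD terms (kld_loss, img_kld_loss, rel_kld_loss, attr_kld_loss),
    |         # so the model is assumed to expose them after each forward pass.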
599 | self.multimodal_encoder = IBMultiModal(
600 | args=self.args,
601 | ent_num=self.ENT_NUM,
602 | img_feature_dim=img_dim,
603 | char_feature_dim=char_dim,
604 | use_project_head=self.args.use_project_head,
605 | ).to(self.device)
606 |
607 |
608 | self.params = [{"params": list(self.multimodal_encoder.parameters())}]
609 | total_params = sum(
610 | p.numel()
611 | for p in self.multimodal_encoder.parameters()
612 | if p.requires_grad
613 | )
614 |
615 | self.optimizer = optim.AdamW(
616 | self.params, lr=self.args.lr, weight_decay=self.args.weight_decay
617 | )
618 |
619 | if self.args.use_sheduler:
620 | self.scheduler = optim.lr_scheduler.ExponentialLR(
621 | optimizer=self.optimizer, gamma=self.args.sheduler_gamma
622 | )
623 | elif self.args.use_sheduler_cos:
624 | self.scheduler = get_cosine_schedule_with_warmup(
625 | optimizer=self.optimizer,
626 | num_warmup_steps=self.args.num_warmup_steps,
627 | num_training_steps=self.args.num_training_steps,
628 | )
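    |         # Loss functions: MsLoss (multi-similarity) is applied to the fused joint
    |         # embedding, InfoNCE_loss to the per-modality embeddings (see train()).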
629 | ms_alpha = self.args.ms_alpha
630 | ms_beta = self.args.ms_beta
631 | ms_base = self.args.ms_base
632 | self.criterion_ms = MsLoss(
633 | device=self.device, thresh=ms_base, scale_pos=ms_alpha, scale_neg=ms_beta
634 | )
635 | self.criterion_nce = InfoNCE_loss(device=self.device, temperature=self.args.tau)
636 |
637 | def semi_supervised_learning(self):
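    |         # Pseudo-labelling for iterative learning: embed all entities, compute chunked
    |         # pairwise distances between the still-unaligned left and right entities, and
    |         # return each side's nearest-neighbour indices; train() keeps only mutually
    |         # consistent predictions as new candidate links.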
638 | with torch.no_grad():
639 | gph_emb, img_emb, rel_emb, att_emb, name_emb, char_emb, joint_emb = (
640 | self.multimodal_encoder(
641 | self.input_idx,
642 | self.adj,
643 | self.img_features,
644 | self.rel_features,
645 | self.att_features,
646 | self.name_features,
647 | self.char_features,
648 | )
649 | )
650 |
651 | final_emb = F.normalize(joint_emb)
652 | distance_list = []
653 | for i in np.arange(0, len(self.left_non_train), 1000):
654 | d = pairwise_distances(
655 | final_emb[self.left_non_train[i : i + 1000]],
656 | final_emb[self.right_non_train],
657 | )
658 | distance_list.append(d)
659 | distance = torch.cat(distance_list, dim=0)
660 | preds_l = torch.argmin(distance, dim=1).cpu().numpy().tolist()
661 | preds_r = torch.argmin(distance.t(), dim=1).cpu().numpy().tolist()
662 | del distance_list, distance, final_emb
663 | del gph_emb, img_emb, rel_emb, att_emb, name_emb, char_emb, joint_emb
664 | return preds_l, preds_r
665 |
666 | def train(self):
667 | print("model config is:")
668 | pprint(self.args, indent=2)
669 | print("[start training...] ")
670 | t_total = time.time()
671 | new_links = []
672 | epoch_KE, epoch_CG = 0, 0
673 |
674 | bsize = self.args.bsize
675 | device = self.device
676 |
677 | self.input_idx = torch.LongTensor(np.arange(self.ENT_NUM)).to(device)
678 |
679 | use_graph_vib = self.args.use_graph_vib
680 | use_attr_vib = self.args.use_attr_vib
681 | use_img_vib = self.args.use_img_vib
682 | use_rel_vib = self.args.use_rel_vib
683 | use_joint_vib = self.args.use_joint_vib
684 | use_bce = self.args.use_bce
685 | use_bce_new = self.args.use_bce_new
686 | use_icl = self.args.use_icl
687 | use_ms = self.args.use_ms
688 | use_nt = self.args.use_nt
689 | use_nce = self.args.use_nce
690 | use_nce_new = self.args.use_nce_new
691 |
692 | joint_use_icl = self.args.joint_use_icl
693 | joint_use_bce = self.args.joint_use_bce
694 | joint_use_ms = self.args.joint_use_ms
695 |
696 | graph_use_icl = self.args.graph_use_icl
697 | graph_use_bce = self.args.graph_use_bce
698 | graph_use_ms = self.args.graph_use_ms
699 |
700 | img_use_icl = self.args.img_use_icl
701 | img_use_bce = self.args.img_use_bce
702 | img_use_ms = self.args.img_use_ms
703 |
704 | attr_use_icl = self.args.attr_use_icl
705 | attr_use_bce = self.args.attr_use_bce
706 | attr_use_ms = self.args.attr_use_ms
707 |
708 | rel_use_icl = self.args.rel_use_icl
709 | rel_use_bce = self.args.rel_use_bce
710 | rel_use_ms = self.args.rel_use_ms
711 |
712 | Beta_g = self.args.Beta_g
713 | Beta_i = self.args.Beta_i
714 | Beta_r = self.args.Beta_r
715 | Beta_a = self.args.Beta_a
716 |
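    |         # Each epoch: one forward pass over all entities, then mini-batches over the
    |         # training alignments. Per modality the loss is InfoNCE + Beta * KLD (the
    |         # information-bottleneck trade-off), plus a multi-similarity loss on the joint
    |         # embedding weighted by --joint_beta.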
717 | for epoch in range(self.args.epochs):
718 | t_epoch = time.time()
719 | self.multimodal_encoder.train()
720 | self.optimizer.zero_grad()
721 |
722 | gph_emb, img_emb, rel_emb, att_emb, name_emb, char_emb, joint_emb = (
723 | self.multimodal_encoder(
724 | self.input_idx,
725 | self.adj,
726 | self.img_features,
727 | self.rel_features,
728 | self.att_features,
729 | self.name_features,
730 | self.char_features,
731 | )
732 | )
733 | epoch_CG += 1
734 | np.random.shuffle(self.train_ill)
735 | print("train_ill length:", len(self.train_ill))
736 | for si in np.arange(0, self.train_ill.shape[0], bsize):
737 | loss_all = 0
738 | Beta = 0.001
739 | print("[epoch {:d}] ".format(epoch), end="")
740 | loss_list = []
741 | if self.args.w_gcn:
742 | if use_graph_vib:
743 |
744 | gph_bce_loss = self.criterion_nce(
745 | gph_emb, self.train_ill[si : si + bsize]
746 | )
747 | gph_kld_loss = self.multimodal_encoder.kld_loss
748 | loss_G = gph_bce_loss + Beta_g * gph_kld_loss
749 | print(
750 | " gph_bce_loss: {:f}, gph_kld_loss: {:f}, Beta_g: {:f}".format(
751 | gph_bce_loss, gph_kld_loss, Beta_g
752 | ),
753 | end="",
754 | )
755 | loss_list.append(loss_G)
756 |
757 | if self.args.w_img:
758 | if use_img_vib:
759 |
760 | img_bce_loss = self.criterion_nce(
761 | img_emb, self.train_ill[si : si + bsize]
762 | )
763 |
764 | img_kld_loss = self.multimodal_encoder.img_kld_loss
765 | loss_I = img_bce_loss + Beta_i * img_kld_loss
766 | print(
767 | " img_bce_loss: {:f}, img_kld_loss: {:f}, Beta_i: {:f}".format(
768 | img_bce_loss, img_kld_loss, Beta_i
769 | ),
770 | end="",
771 | )
772 | loss_list.append(loss_I)
773 |
774 | if self.args.w_rel:
775 | if use_rel_vib:
776 | rel_bce_loss = self.criterion_nce(
777 | rel_emb, self.train_ill[si : si + bsize]
778 | )
779 | rel_kld_loss = self.multimodal_encoder.rel_kld_loss
780 | loss_R = rel_bce_loss + Beta_r * rel_kld_loss
781 | print(
782 | " rel_bce_loss: {:f}, rel_kld_loss: {:f}, Beta_r: {:f}".format(
783 | rel_bce_loss, rel_kld_loss, Beta_r
784 | ),
785 | end="",
786 | )
787 | loss_list.append(loss_R)
788 |
789 | if self.args.w_attr:
790 | if use_attr_vib:
791 | attr_bce_loss = self.criterion_nce(
792 | att_emb, self.train_ill[si : si + bsize]
793 | )
794 | attr_kld_loss = self.multimodal_encoder.attr_kld_loss
795 | loss_A = attr_bce_loss + Beta_a * attr_kld_loss
796 | print(
797 | " attr_bce_loss: {:f}, attr_kld_loss: {:f}, Beta_a: {:f}".format(
798 | attr_bce_loss, attr_kld_loss, Beta_a
799 | ),
800 | end="",
801 | )
802 | loss_list.append(loss_A)
803 |
804 | if use_joint_vib:
805 | pass
806 | else:
807 | joint_ms_loss = self.criterion_ms(
808 | joint_emb, self.train_ill[si : si + bsize]
809 | )
810 | loss_J = joint_ms_loss * self.args.joint_beta
811 | print(
812 | " joint_ms_loss: {:f}, joint_beta: {:f}".format(
813 | joint_ms_loss, self.args.joint_beta
814 | ),
815 | end="",
816 | )
817 | loss_list.append(loss_J)
818 | loss_all = sum(loss_list)
819 |
820 | print(" loss_all: {:f},".format(loss_all))
821 | loss_all.backward(retain_graph=True)
822 | torch.nn.utils.clip_grad_norm_(
823 | parameters=self.multimodal_encoder.parameters(),
824 | max_norm=0.1,
825 | norm_type=2,
826 | )
827 |
828 | self.optimizer.step()
829 | if self.args.use_sheduler and epoch > 400:
830 | self.scheduler.step()
831 | if self.args.use_sheduler_cos:
832 | self.scheduler.step()
833 | if (
834 | epoch >= self.args.il_start
835 | and (epoch + 1) % self.args.semi_learn_step == 0
836 | and self.args.il
837 | ):
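    |                 # Iterative learning: every --semi_learn_step epochs after --il_start,
    |                 # propose pseudo links; a pair survives only while it remains a mutual
    |                 # nearest neighbour across consecutive checks, and every
    |                 # semi_learn_step * 10 epochs the surviving links are added to train_ill.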
838 | preds_l, preds_r = self.semi_supervised_learning()
839 | if (epoch + 1) % (
840 | self.args.semi_learn_step * 10
841 | ) == self.args.semi_learn_step:
842 | new_links = [
843 | (self.left_non_train[i], self.right_non_train[p])
844 | for i, p in enumerate(preds_l)
845 | if preds_r[p] == i
846 | ]
847 | else:
848 | new_links = [
849 | (self.left_non_train[i], self.right_non_train[p])
850 | for i, p in enumerate(preds_l)
851 | if (preds_r[p] == i)
852 | and (
853 | (self.left_non_train[i], self.right_non_train[p])
854 | in new_links
855 | )
856 | ]
857 | print(
858 | "[epoch %d] #links in candidate set: %d" % (epoch, len(new_links))
859 | )
860 |
861 | if (
862 | epoch >= self.args.il_start
863 | and (epoch + 1) % (self.args.semi_learn_step * 10) == 0
864 | and len(new_links) != 0
865 | and self.args.il
866 | ):
867 | new_links_elect = new_links
868 | print("\n#new_links_elect:", len(new_links_elect))
869 |
870 | self.train_ill = np.vstack((self.train_ill, np.array(new_links_elect)))
871 | print("train_ill.shape:", self.train_ill.shape)
872 |
873 | num_true = len([nl for nl in new_links_elect if nl in self.test_ill_])
874 | print("#true_links: %d" % num_true)
875 | print(
876 | "true link ratio: %.1f%%" % (100 * num_true / len(new_links_elect))
877 | )
878 | for nl in new_links_elect:
879 | self.left_non_train.remove(nl[0])
880 | self.right_non_train.remove(nl[1])
881 | print(
882 | "#entity not in train set: %d (left) %d (right)"
883 | % (len(self.left_non_train), len(self.right_non_train))
884 | )
885 |
886 | new_links = []
887 | if (epoch + 1) % self.args.check_point == 0:
888 | print("\n[epoch {:d}] checkpoint!".format(epoch))
889 | self.test(epoch)
890 | del joint_emb, gph_emb, img_emb, rel_emb, att_emb, name_emb, char_emb
891 |
892 | print("[optimization finished!]")
893 | print(
894 | "best epoch is {}, hits@1 hits@10 MRR MR is: {}\n".format(
895 | self.best_epoch, self.best_data_list
896 | )
897 | )
898 | print("[total time elapsed: {:.4f} s]".format(time.time() - t_total))
899 |
900 | def test(self, epoch):
901 | with torch.no_grad():
902 | t_test = time.time()
903 | self.multimodal_encoder.eval()
904 |
905 | gph_emb, img_emb, rel_emb, att_emb, name_emb, char_emb, joint_emb = (
906 | self.multimodal_encoder(
907 | self.input_idx,
908 | self.adj,
909 | self.img_features,
910 | self.rel_features,
911 | self.att_features,
912 | self.name_features,
913 | self.char_features,
914 | )
915 | )
916 |
917 | final_emb = F.normalize(joint_emb)
918 |             top_k = [1, 10, 50]
    |             to_write = []  # guard: only filled in the full-ranking branch below
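    |             # Evaluation: rank the true counterpart among all test entities in both
    |             # directions (l2r / r2l) and report Hits@{1,10,50}, mean rank and MRR.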
919 | if "100" in self.args.file_dir:
920 | Lvec = final_emb[self.test_left].cpu().data.numpy()
921 | Rvec = final_emb[self.test_right].cpu().data.numpy()
922 | acc_l2r, mean_l2r, mrr_l2r, acc_r2l, mean_r2l, mrr_r2l = multi_get_hits(
923 | Lvec, Rvec, top_k=top_k, args=self.args
924 | )
925 | del final_emb
926 | gc.collect()
927 | else:
928 | acc_l2r = np.zeros((len(top_k)), dtype=np.float32)
929 | acc_r2l = np.zeros((len(top_k)), dtype=np.float32)
930 | test_total, test_loss, mean_l2r, mean_r2l, mrr_l2r, mrr_r2l = (
931 | 0,
932 | 0.0,
933 | 0.0,
934 | 0.0,
935 | 0.0,
936 | 0.0,
937 | )
938 | if self.args.dist == 2:
939 | distance = pairwise_distances(
940 | final_emb[self.test_left], final_emb[self.test_right]
941 | )
942 | elif self.args.dist == 1:
943 | import scipy.spatial as T
944 |
945 | distance = torch.FloatTensor(
946 | T.distance.cdist(
947 | final_emb[self.test_left].cpu().data.numpy(),
948 | final_emb[self.test_right].cpu().data.numpy(),
949 | metric="cityblock",
950 | )
951 | )
952 | else:
953 | raise NotImplementedError
954 |
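    |                 # CSLS: csls_sim is assumed to re-weight each similarity by the average
    |                 # similarity of the point's k nearest neighbours, reducing hubness.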
955 | if self.args.csls is True:
956 | distance = 1 - csls_sim(1 - distance, self.args.csls_k)
957 |
958 | to_write = []
959 | test_left_np = self.test_left.cpu().numpy()
960 | test_right_np = self.test_right.cpu().numpy()
961 | to_write.append(
962 | [
963 | "idx",
964 | "rank",
965 | "query_id",
966 | "gt_id",
967 | "ret1",
968 | "ret2",
969 | "ret3",
970 | "v1",
971 | "v2",
972 | "v3",
973 | ]
974 | )
975 |
976 | for idx in range(self.test_left.shape[0]):
977 | values, indices = torch.sort(distance[idx, :], descending=False)
978 | rank = (indices == idx).nonzero().squeeze().item()
979 | mean_l2r += rank + 1
980 | mrr_l2r += 1.0 / (rank + 1)
981 | for i in range(len(top_k)):
982 | if rank < top_k[i]:
983 | acc_l2r[i] += 1
984 |
985 | indices = indices.cpu().numpy()
986 | to_write.append(
987 | [
988 | idx,
989 | rank,
990 | test_left_np[idx],
991 | test_right_np[idx],
992 | test_right_np[indices[0]],
993 | test_right_np[indices[1]],
994 | test_right_np[indices[2]],
995 | round(values[0].item(), 4),
996 | round(values[1].item(), 4),
997 | round(values[2].item(), 4),
998 | ]
999 | )
1000 |
1001 | for idx in range(self.test_right.shape[0]):
1002 | _, indices = torch.sort(distance[:, idx], descending=False)
1003 | rank = (indices == idx).nonzero().squeeze().item()
1004 | mean_r2l += rank + 1
1005 | mrr_r2l += 1.0 / (rank + 1)
1006 | for i in range(len(top_k)):
1007 | if rank < top_k[i]:
1008 | acc_r2l[i] += 1
1009 |
1010 | mean_l2r /= self.test_left.size(0)
1011 | mean_r2l /= self.test_right.size(0)
1012 | mrr_l2r /= self.test_left.size(0)
1013 | mrr_r2l /= self.test_right.size(0)
1014 | for i in range(len(top_k)):
1015 | acc_l2r[i] = round(acc_l2r[i] / self.test_left.size(0), 4)
1016 | acc_r2l[i] = round(acc_r2l[i] / self.test_right.size(0), 4)
1017 | del (
1018 | distance,
1019 | gph_emb,
1020 | img_emb,
1021 | rel_emb,
1022 | att_emb,
1023 | name_emb,
1024 | char_emb,
1025 | joint_emb,
1026 | )
1027 | gc.collect()
1028 | print(
1029 | "l2r: acc of top {} = {}, mr = {:.3f}, mrr = {:.3f}, time = {:.4f} s ".format(
1030 | top_k, acc_l2r, mean_l2r, mrr_l2r, time.time() - t_test
1031 | )
1032 | )
1033 | print(
1034 | "r2l: acc of top {} = {}, mr = {:.3f}, mrr = {:.3f}, time = {:.4f} s \n".format(
1035 | top_k, acc_r2l, mean_r2l, mrr_r2l, time.time() - t_test
1036 | )
1037 | )
1038 | if acc_l2r[0] > self.best_hit_1:
1039 | self.best_hit_1 = acc_l2r[0]
1040 | self.best_epoch = epoch
1041 | self.best_data_list = [
1042 | acc_l2r[0],
1043 | acc_l2r[1],
1044 | mrr_l2r,
1045 | mean_l2r,
1046 | ]
1047 | import copy
1048 |
1049 | self.best_to_write = copy.deepcopy(to_write)
1050 |
1051 | if epoch + 1 == self.args.epochs:
1052 | pred_name = self.args.pred_name
1053 | import csv
1054 |
1055 | save_path = os.path.join(self.args.save_path, pred_name)
1056 | if not os.path.exists(save_path):
1057 | os.mkdir(save_path)
1058 | with open(os.path.join(save_path, pred_name + ".txt"), "w") as f:
1059 | wr = csv.writer(f, dialect="excel")
1060 | wr.writerows(self.best_to_write)
1061 |
1062 |
1063 | if __name__ == "__main__":
1064 | model = IBMEA()
1065 | model.train()
1066 |
--------------------------------------------------------------------------------