├── logs └── readme.md ├── save_pkl └── readme.md ├── src ├── __init__.py ├── loss.py ├── Load.py ├── layers.py ├── utils.py ├── models.py └── run.py ├── files └── IBMEA.png ├── requirements.txt ├── run_mmkb_main.sh ├── run_dbp_main.sh ├── README.md └── env.yml /logs/readme.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /save_pkl/readme.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /files/IBMEA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sutaoyu/IBMEA/HEAD/files/IBMEA.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.10.2 2 | scipy==1.5.2 3 | numpy==1.19.2 4 | pytorch-metric-learning==2.1.1 -------------------------------------------------------------------------------- /run_mmkb_main.sh: -------------------------------------------------------------------------------- 1 | gpu_id='0' 2 | # dataset: 'FB15K_DB15K' 'FB15K_YAGO15K' 3 | dataset='FB15K_DB15K' 4 | 5 | # ratio: 0.2 0.5 0.8 6 | ratio=0.2 7 | seed=2023 8 | warm=400 9 | bsize=1000 10 | if [[ "$dataset" == *"FB"* ]]; then 11 | dataset_dir='mmkb-datasets' 12 | tau=400 13 | else 14 | dataset_dir='DBP15K' 15 | tau=0.1 16 | ratio=0.3 17 | fi 18 | echo "Running with dataset=${dataset}, ratio=${ratio}" 19 | current_datetime=$(date +"%Y-%m-%d-%H-%M") 20 | head_name=${current_datetime}_${dataset} 21 | file_name=${head_name}_bsize${bsize}_${ratio} 22 | echo ${file_name} 23 | CUDA_VISIBLE_DEVICES=${gpu_id} python3 -u src/run.py \ 24 | --file_dir data/${dataset_dir}/${dataset} \ 25 | --pred_name ${file_name} \ 26 | --rate ${ratio} \ 27 | --lr .006 \ 28 | --epochs 500 \ 29 | --dropout 0.45 \ 30 | --hidden_units "300,300,300" \ 31 | --check_point 50 \ 32 | --bsize ${bsize} \ 33 | --il \ 34 | --il_start 20 \ 35 | --semi_learn_step 5 \ 36 | --csls \ 37 | --csls_k 3 \ 38 | --seed ${seed} \ 39 | --structure_encoder "gat" \ 40 | --img_dim 100 \ 41 | --attr_dim 100 \ 42 | --use_nce \ 43 | --tau ${tau} \ 44 | --use_sheduler_cos \ 45 | --num_warmup_steps ${warm} \ 46 | --w_name \ 47 | --w_char > logs/${file_name}.log -------------------------------------------------------------------------------- /run_dbp_main.sh: -------------------------------------------------------------------------------- 1 | gpu_id='0' 2 | # dataset: 'zh_en' 'ja_en' 'fr_en' 3 | dataset='zh_en' 4 | ratio=0.3 5 | seed=2023 6 | warm=200 7 | joint_beta=3 8 | ms_alpha=5 9 | training_step=1500 10 | bsize=7500 11 | if [[ "$dataset" == *"FB"* ]]; then 12 | dataset_dir='mmkb-datasets' 13 | tau=400 14 | else 15 | dataset_dir='DBP15K' 16 | tau=0.1 17 | ratio=0.3 18 | fi 19 | echo "Running with dataset=${dataset}" 20 | current_datetime=$(date +"%Y-%m-%d-%H-%M") 21 | head_name=${current_datetime}_${dataset} 22 | file_name=${head_name}_bsize${bsize}_${ratio} 23 | echo ${file_name} 24 | CUDA_VISIBLE_DEVICES=${gpu_id} python3 -u src/run.py \ 25 | --file_dir data/${dataset_dir}/${dataset} \ 26 | --pred_name ${file_name} \ 27 | --rate ${ratio} \ 28 | --lr 
.006 \ 29 | --epochs 1000 \ 30 | --dropout 0.45 \ 31 | --hidden_units "300,300,300" \ 32 | --check_point 50 \ 33 | --bsize ${bsize} \ 34 | --il \ 35 | --il_start 20 \ 36 | --semi_learn_step 5 \ 37 | --csls \ 38 | --csls_k 3 \ 39 | --seed ${seed} \ 40 | --structure_encoder "gat" \ 41 | --img_dim 100 \ 42 | --attr_dim 100 \ 43 | --use_nce \ 44 | --tau ${tau} \ 45 | --use_sheduler_cos \ 46 | --num_warmup_steps ${warm} \ 47 | --num_training_steps ${training_step} \ 48 | --joint_beta ${joint_beta} \ 49 | --ms_alpha ${ms_alpha} \ 50 | --w_name \ 51 | --w_char > logs/${file_name}.log 52 | -------------------------------------------------------------------------------- /src/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from pytorch_metric_learning import losses, miners 4 | 5 | try: 6 | from models import * 7 | from utils import * 8 | except: 9 | from src.models import * 10 | from src.utils import * 11 | 12 | 13 | class MsLoss(nn.Module): 14 | def __init__(self, device, thresh=0.5, scale_pos=0.1, scale_neg=40.0): 15 | super(MsLoss, self).__init__() 16 | self.device = device 17 | alpha, beta, base = scale_pos, scale_neg, thresh 18 | self.loss_func = losses.MultiSimilarityLoss(alpha=alpha, beta=beta, base=base) 19 | 20 | def sim(self, emb_left, emb_right): 21 | return emb_left.mm(emb_right.t()) 22 | 23 | def forward(self, emb, train_links): 24 | emb = F.normalize(emb) 25 | emb_train_left = emb[train_links[:, 0]] 26 | emb_train_right = emb[train_links[:, 1]] 27 | labels = torch.arange(emb_train_left.size(0)) 28 | embeddings = torch.cat([emb_train_left, emb_train_right], dim=0) 29 | labels = torch.cat([labels, labels], dim=0) 30 | loss = self.loss_func(embeddings, labels) 31 | return loss 32 | 33 | 34 | class InfoNCE_loss(nn.Module): 35 | def __init__(self, device, temperature=0.05) -> None: 36 | super().__init__() 37 | self.device = device 38 | self.t = temperature 39 | 40 | self.ce_loss = nn.CrossEntropyLoss() 41 | 42 | def sim(self, emb_left, emb_right): 43 | return emb_left.mm(emb_right.t()) 44 | 45 | def forward(self, emb, train_links): 46 | emb = F.normalize(emb) 47 | emb_train_left = emb[train_links[:, 0]] 48 | emb_train_right = emb[train_links[:, 1]] 49 | 50 | score = self.sim(emb_train_left, emb_train_right) 51 | 52 | bsize = emb_train_left.size()[0] 53 | label = torch.arange(bsize, dtype=torch.long).cuda(self.device) 54 | 55 | loss = self.ce_loss(score / self.t, label) 56 | return loss 57 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # IBMEA 2 | 3 | The code of paper _**IBMEA: Exploring Variational Information Bottleneck for Multi-modal Entity Alignment**_ [[arxiv](https://arxiv.org/abs/2407.19302)] [[ACM MM](https://dl.acm.org/doi/10.1145/3664647.3680954)] in Proceedings of ACM MM 2024. 4 | 5 | 6 |
7 | ![IBMEA](files/IBMEA.png) 8 | 
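At its core, IBMEA passes each modality (graph structure, image, relation, attribute) through a variational information bottleneck: the modality feature is projected to a Gaussian posterior, a latent embedding is sampled with the reparameterization trick, and a KL term against a standard normal prior limits how much information the embedding retains (see `IBMultiModal` in `src/models.py`). The snippet below is a minimal, illustrative sketch of that idea only; the class and variable names are invented for clarity, and the repository's actual implementation differs in details (for example, it uses a custom `_kld_gauss` with clamped scale terms).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ModalityVIB(nn.Module):
    """Simplified per-modality variational information bottleneck (illustrative sketch)."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.fc = nn.Linear(in_dim, out_dim)          # modality-specific projection
        self.fc_mu = nn.Linear(out_dim, out_dim)      # mean of the Gaussian posterior
        self.fc_logvar = nn.Linear(out_dim, out_dim)  # log-variance of the posterior

    def forward(self, x):
        h = F.relu(self.fc(x))
        mu, logvar = self.fc_mu(h), self.fc_logvar(h)
        std = torch.exp(0.5 * logvar)
        z = mu + torch.randn_like(std) * std  # reparameterization trick
        # KL(q(z|x) || N(0, I)), averaged over the batch
        kld = -0.5 * torch.mean(torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1))
        return z, kld
```

In the full model, the per-modality KL terms are combined with the alignment losses during training; the `--Beta_g`, `--Beta_i`, `--Beta_r` and `--Beta_a` arguments in `src/run.py` appear to control their relative weights.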
9 | 
10 | 
11 | 
12 | 
13 | ## Env
14 | 
15 | Run the following command to create the required conda environment:
16 | 
17 | ```
18 | conda env create -f env.yml -n IBMEA
19 | ```
20 | 
21 | ## Datasets
22 | 
23 | ### Cross-KG datasets
24 | 
25 | The original cross-KG datasets (FB15K-DB15K/YAGO15K) come from [MMKB](https://github.com/mniepert/mmkb), in which the image embeddings are extracted from a pre-trained VGG16. Following [MCLEA](https://github.com/lzxlin/MCLEA), we use the image embeddings provided by [MMKB](https://github.com/mniepert/mmkb#visual-data-for-fb15k-yago15k-and-dbpedia15k) and transform the data into a format consistent with DBP15K.
26 | The converted dataset can be downloaded from [BaiduDisk](https://pan.baidu.com/s/1xOz8ga_J1LPbBSIgSFHbNQ) (the password is `IBME`).
27 | Place the `mmkb-datasets` directory in the `data` directory.
28 | 
29 | ### Bilingual datasets
30 | 
31 | The multi-modal version of the DBP15K dataset comes from the [EVA](https://github.com/cambridgeltl/eva) repository. Download the `pkls` folder of DBP15K image features following the guidance of the EVA repository, and place the downloaded `pkls` folder in the `data` directory of this repository.
32 | 
33 | ## How to run the scripts
34 | 
35 | ```
36 | bash run_mmkb_main.sh
37 | bash run_dbp_main.sh
38 | ```
39 | 
40 | ## Citation
41 | 
42 | If you use this model or code, please cite it as follows:
43 | 
44 | ```
45 | @inproceedings{IBMEA,
46 |   author    = {Taoyu Su and
47 |                Jiawei Sheng and
48 |                Shicheng Wang and
49 |                Xinghua Zhang and
50 |                Hongbo Xu and
51 |                Tingwen Liu},
52 |   editor    = {Jianfei Cai and
53 |                Mohan S. Kankanhalli and
54 |                Balakrishnan Prabhakaran and
55 |                Susanne Boll and
56 |                Ramanathan Subramanian and
57 |                Liang Zheng and
58 |                Vivek K. Singh and
59 |                Pablo C{\'{e}}sar and
60 |                Lexing Xie and
61 |                Dong Xu},
62 |   title     = {{IBMEA:} Exploring Variational Information Bottleneck for Multi-modal
63 |                Entity Alignment},
64 |   booktitle = {Proceedings of the 32nd {ACM} International Conference on Multimedia,
65 |                {MM} 2024, Melbourne, VIC, Australia, 28 October 2024 - 1 November
66 |                2024},
67 |   pages     = {4436--4445},
68 |   publisher = {{ACM}},
69 |   year      = {2024}
70 | }
71 | ```
72 | 
73 | ## Acknowledgement
74 | 
75 | We appreciate [MCLEA](https://github.com/lzxlin/MCLEA), [EVA](https://github.com/cambridgeltl/eva), [MMEA](https://github.com/liyichen-cly/MMEA) and many other related works for their open-source contributions.
76 | 
77 | 
78 | #### Remarks
79 | - You are welcome to give this repo a star ⭐, so I know how many people are interested in this work 🤩.
80 | - It is said that people who click the star ⭐ get better acceptance rates for their papers 😃.
81 | - **If you star this repo in passing, may your next paper get into an A-tier venue!!!!**
82 | 
83 | 
84 | 
-------------------------------------------------------------------------------- /src/Load.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import Counter 3 | import json 4 | import pickle 5 | from tqdm import tqdm 6 | 7 | 8 | def loadfile(fn, num=1): 9 | print("loading a file..."
+ fn) 10 | ret = [] 11 | with open(fn, encoding="utf-8") as f: 12 | for line in f: 13 | th = line[:-1].split("\t") 14 | x = [] 15 | for i in range(num): 16 | x.append(int(th[i])) 17 | ret.append(tuple(x)) 18 | return ret 19 | 20 | 21 | def get_ids(fn): 22 | ids = [] 23 | with open(fn, encoding="utf-8") as f: 24 | for line in f: 25 | th = line[:-1].split("\t") 26 | ids.append(int(th[0])) 27 | return ids 28 | 29 | 30 | def get_ent2id(fns): 31 | ent2id = {} 32 | for fn in fns: 33 | with open(fn, "r", encoding="utf-8") as f: 34 | for line in f: 35 | th = line[:-1].split("\t") 36 | ent2id[th[1]] = int(th[0]) 37 | return ent2id 38 | 39 | 40 | def load_attr(fns, e, ent2id, topA=1000): 41 | cnt = {} 42 | for fn in fns: 43 | with open(fn, "r", encoding="utf-8") as f: 44 | for line in f: 45 | th = line[:-1].split("\t") 46 | if th[0] not in ent2id: 47 | continue 48 | for i in range(1, len(th)): 49 | if th[i] not in cnt: 50 | cnt[th[i]] = 1 51 | else: 52 | cnt[th[i]] += 1 53 | fre = [(k, cnt[k]) for k in sorted(cnt, key=cnt.get, reverse=True)] 54 | attr2id = {} 55 | for i in range(min(topA, len(fre))): 56 | attr2id[fre[i][0]] = i 57 | attr = np.zeros((e, topA), dtype=np.float32) 58 | for fn in fns: 59 | with open(fn, "r", encoding="utf-8") as f: 60 | for line in f: 61 | th = line[:-1].split("\t") 62 | if th[0] in ent2id: 63 | for i in range(1, len(th)): 64 | if th[i] in attr2id: 65 | attr[ent2id[th[0]]][attr2id[th[i]]] = 1.0 66 | return attr 67 | 68 | 69 | def load_relation(e, KG, topR=1000): 70 | rel_mat = np.zeros((e, topR), dtype=np.float32) 71 | rels = np.array(KG)[:, 1] 72 | top_rels = Counter(rels).most_common(topR) 73 | rel_index_dict = {r: i for i, (r, cnt) in enumerate(top_rels)} 74 | for tri in KG: 75 | h = tri[0] 76 | r = tri[1] 77 | o = tri[2] 78 | if r in rel_index_dict: 79 | rel_mat[h][rel_index_dict[r]] += 1.0 80 | rel_mat[o][rel_index_dict[r]] += 1.0 81 | return np.array(rel_mat) 82 | 83 | 84 | def load_json_embd(path): 85 | embd_dict = {} 86 | with open(path) as f: 87 | for line in f: 88 | example = json.loads(line.strip()) 89 | vec = np.array([float(e) for e in example["feature"].split()]) 90 | embd_dict[int(example["guid"])] = vec 91 | return embd_dict 92 | 93 | 94 | def load_img(e_num, path): 95 | img_dict = pickle.load(open(path, "rb")) 96 | imgs_np = np.array(list(img_dict.values())) 97 | mean = np.mean(imgs_np, axis=0) 98 | std = np.std(imgs_np, axis=0) 99 | img_embd = np.array( 100 | [ 101 | img_dict[i] if i in img_dict else np.random.normal(mean, std, mean.shape[0]) 102 | for i in range(e_num) 103 | ] 104 | ) 105 | print("%.2f%% entities have images" % (100 * len(img_dict) / e_num)) 106 | return img_embd 107 | 108 | 109 | def load_img_new(e_num, path, triples): 110 | from collections import defaultdict 111 | 112 | img_dict = pickle.load(open(path, "rb")) 113 | neighbor_list = defaultdict(list) 114 | for triple in triples: 115 | head = triple[0] 116 | relation = triple[1] 117 | tail = triple[2] 118 | if tail in img_dict: 119 | neighbor_list[head].append(tail) 120 | if head in img_dict: 121 | neighbor_list[tail].append(head) 122 | imgs_np = np.array(list(img_dict.values())) 123 | mean = np.mean(imgs_np, axis=0) 124 | std = np.std(imgs_np, axis=0) 125 | all_img_emb_normal = np.random.normal(mean, std, mean.shape[0]) 126 | img_embd = [] 127 | follow_neirbor_img_num = 0 128 | follow_all_img_num = 0 129 | for i in range(e_num): 130 | if i in img_dict: 131 | img_embd.append(img_dict[i]) 132 | else: 133 | if len(neighbor_list[i]) > 0: 134 | follow_neirbor_img_num += 1 135 | if i in 
img_dict: 136 | neighbor_list[i].append(i) 137 | neighbor_imgs_emb = np.array([img_dict[id] for id in neighbor_list[i]]) 138 | neighbor_imgs_emb_mean = np.mean(neighbor_imgs_emb, axis=0) 139 | img_embd.append(neighbor_imgs_emb_mean) 140 | else: 141 | follow_all_img_num += 1 142 | img_embd.append(all_img_emb_normal) 143 | print( 144 | "%.2f%% entities have images," % (100 * len(img_dict) / e_num), 145 | " follow_neirbor_img_num is {0},".format(follow_neirbor_img_num), 146 | " follow_all_img_num is {0}".format(follow_all_img_num), 147 | ) 148 | return np.array(img_embd) -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: IBMEA 2 | channels: 3 | - pytorch 4 | - https://repo.anaconda.com/pkgs/main 5 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/Paddle/ 6 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ 7 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ 8 | - defaults 9 | dependencies: 10 | - _libgcc_mutex=0.1=main 11 | - _openmp_mutex=5.1=1_gnu 12 | - blas=1.0=mkl 13 | - bzip2=1.0.8=h7b6447c_0 14 | - ca-certificates=2022.07.19=h06a4308_0 15 | - certifi=2021.5.30=py36h06a4308_0 16 | - cudatoolkit=11.3.1=h2bc3f7f_2 17 | - dataclasses=0.8=pyh4f3eec9_6 18 | - ffmpeg=4.3=hf484d3e_0 19 | - freetype=2.11.0=h70c0345_0 20 | - gmp=6.2.1=h295c915_3 21 | - gnutls=3.6.15=he1e5248_0 22 | - intel-openmp=2022.1.0=h9e868ea_3769 23 | - jpeg=9e=h7f8727e_0 24 | - lame=3.100=h7b6447c_0 25 | - lcms2=2.12=h3be6417_0 26 | - ld_impl_linux-64=2.38=h1181459_1 27 | - lerc=3.0=h295c915_0 28 | - libdeflate=1.8=h7f8727e_5 29 | - libffi=3.3=he6710b0_2 30 | - libgcc-ng=11.2.0=h1234567_1 31 | - libgomp=11.2.0=h1234567_1 32 | - libiconv=1.16=h7f8727e_2 33 | - libidn2=2.3.2=h7f8727e_0 34 | - libpng=1.6.37=hbc83047_0 35 | - libstdcxx-ng=11.2.0=h1234567_1 36 | - libtasn1=4.16.0=h27cfd23_0 37 | - libtiff=4.4.0=hecacb30_0 38 | - libunistring=0.9.10=h27cfd23_0 39 | - libuv=1.40.0=h7b6447c_0 40 | - libwebp-base=1.2.4=h5eee18b_0 41 | - lz4-c=1.9.3=h295c915_1 42 | - mkl=2020.2=256 43 | - mkl-service=2.3.0=py36he8ac12f_0 44 | - mkl_fft=1.3.0=py36h54f3939_0 45 | - mkl_random=1.1.1=py36h0573a6f_0 46 | - ncurses=6.3=h5eee18b_3 47 | - nettle=3.7.3=hbbd107a_1 48 | - numpy=1.19.2=py36h54aff64_0 49 | - numpy-base=1.19.2=py36hfa32c7d_0 50 | - olefile=0.46=py36_0 51 | - openh264=2.1.1=h4ff587b_0 52 | - openjpeg=2.4.0=h3ad879b_0 53 | - openssl=1.1.1q=h7f8727e_0 54 | - pillow=8.3.1=py36h2c7a002_0 55 | - pip=21.2.2=py36h06a4308_0 56 | - python=3.6.13=h12debd9_1 57 | - pytorch=1.10.2=py3.6_cuda11.3_cudnn8.2.0_0 58 | - pytorch-mutex=1.0=cuda 59 | - readline=8.1.2=h7f8727e_1 60 | - setuptools=58.0.4=py36h06a4308_0 61 | - six=1.16.0=pyhd3eb1b0_1 62 | - sqlite=3.39.3=h5082296_0 63 | - tk=8.6.12=h1ccaba5_0 64 | - torchaudio=0.10.2=py36_cu113 65 | - torchvision=0.11.3=py36_cu113 66 | - typing_extensions=4.1.1=pyh06a4308_0 67 | - wheel=0.37.1=pyhd3eb1b0_0 68 | - xz=5.2.6=h5eee18b_0 69 | - zlib=1.2.12=h5eee18b_3 70 | - zstd=1.5.2=ha4553b6_0 71 | - pip: 72 | - absl-py==1.4.0 73 | - argon2-cffi==21.3.0 74 | - argon2-cffi-bindings==21.2.0 75 | - async-generator==1.10 76 | - attrs==22.1.0 77 | - backcall==0.2.0 78 | - bleach==4.1.0 79 | - cachetools==4.2.4 80 | - cffi==1.15.1 81 | - charset-normalizer==2.0.12 82 | - click==8.0.4 83 | - cycler==0.11.0 84 | - debugpy==1.5.1 85 | - decorator==5.1.1 86 | - defusedxml==0.7.1 87 | - easydict==1.11 88 | - entrypoints==0.4 89 | - filelock==3.4.1 90 | - 
google-auth==2.22.0 91 | - google-auth-oauthlib==0.4.6 92 | - grpcio==1.48.2 93 | - huggingface-hub==0.4.0 94 | - idna==3.6 95 | - importlib-metadata==4.8.3 96 | - importlib-resources==5.4.0 97 | - ipykernel==5.5.6 98 | - ipython==7.16.3 99 | - ipython-genutils==0.2.0 100 | - ipywidgets==7.7.2 101 | - jedi==0.17.2 102 | - jinja2==3.0.3 103 | - joblib==1.1.1 104 | - jsonschema==3.2.0 105 | - jupyter==1.0.0 106 | - jupyter-client==7.1.2 107 | - jupyter-console==6.4.3 108 | - jupyter-core==4.9.2 109 | - jupyterlab-pygments==0.1.2 110 | - jupyterlab-widgets==1.1.1 111 | - kiwisolver==1.3.1 112 | - markdown==3.3.7 113 | - markupsafe==2.0.1 114 | - matplotlib==3.3.4 115 | - mistune==0.8.4 116 | - nbclient==0.5.9 117 | - nbconvert==6.0.7 118 | - nbformat==5.1.3 119 | - nest-asyncio==1.5.6 120 | - notebook==6.4.10 121 | - oauthlib==3.2.2 122 | - packaging==21.3 123 | - pandas==1.1.5 124 | - pandocfilters==1.5.0 125 | - parso==0.7.1 126 | - pexpect==4.8.0 127 | - pickleshare==0.7.5 128 | - prometheus-client==0.15.0 129 | - prompt-toolkit==3.0.31 130 | - protobuf==3.19.6 131 | - ptyprocess==0.7.0 132 | - pyasn1==0.5.1 133 | - pyasn1-modules==0.3.0 134 | - pycparser==2.21 135 | - pygments==2.13.0 136 | - pyparsing==3.0.9 137 | - pyrsistent==0.18.0 138 | - python-dateutil==2.8.2 139 | - pytorch-metric-learning==2.1.1 140 | - pytz==2022.6 141 | - pyyaml==6.0.1 142 | - pyzmq==24.0.1 143 | - qtconsole==5.2.2 144 | - qtpy==2.0.1 145 | - regex==2023.8.8 146 | - requests==2.27.1 147 | - requests-oauthlib==1.3.1 148 | - rsa==4.9 149 | - sacremoses==0.0.53 150 | - scikit-learn==0.24.2 151 | - scipy==1.5.2 152 | - send2trash==1.8.0 153 | - tensorboard==2.10.1 154 | - tensorboard-data-server==0.6.1 155 | - tensorboard-plugin-wit==1.8.1 156 | - terminado==0.12.1 157 | - testpath==0.6.0 158 | - threadpoolctl==3.1.0 159 | - tokenizers==0.12.1 160 | - tornado==6.1 161 | - tqdm==4.64.1 162 | - traitlets==4.3.3 163 | - transformers==4.18.0 164 | - unidecode==1.3.7 165 | - urllib3==1.26.18 166 | - wcwidth==0.2.5 167 | - webencodings==0.5.1 168 | - werkzeug==2.0.3 169 | - widgetsnbextension==3.6.1 170 | - zipp==3.6.0 171 | 172 | -------------------------------------------------------------------------------- /src/layers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from __future__ import absolute_import 5 | from __future__ import unicode_literals 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import math 10 | import torch 11 | import torch.nn as nn 12 | from torch.nn.parameter import Parameter 13 | from torch.nn.modules.module import Module 14 | import torch.nn.functional as F 15 | 16 | 17 | class SpecialSpmmFunction(torch.autograd.Function): 18 | @staticmethod 19 | def forward(ctx, indices, values, shape, b): 20 | assert indices.requires_grad == False 21 | a = torch.sparse_coo_tensor(indices, values, shape) 22 | ctx.save_for_backward(a, b) 23 | ctx.N = shape[0] 24 | return torch.matmul(a, b) 25 | 26 | @staticmethod 27 | def backward(ctx, grad_output): 28 | a, b = ctx.saved_tensors 29 | grad_values = grad_b = None 30 | if ctx.needs_input_grad[1]: 31 | grad_a_dense = grad_output.matmul(b.t()) 32 | edge_idx = a._indices()[0, :] * ctx.N + a._indices()[1, :] 33 | grad_values = grad_a_dense.view(-1)[edge_idx] 34 | if ctx.needs_input_grad[3]: 35 | grad_b = a.t().matmul(grad_output) 36 | return None, grad_values, None, grad_b 37 | 38 | 39 | class SpecialSpmm(nn.Module): 40 | def 
forward(self, indices, values, shape, b): 41 | return SpecialSpmmFunction.apply(indices, values, shape, b) 42 | 43 | 44 | class MultiHeadGraphAttention(nn.Module): 45 | def __init__( 46 | self, n_head, f_in, f_out, attn_dropout, diag=True, init=None, bias=False 47 | ): 48 | super(MultiHeadGraphAttention, self).__init__() 49 | self.n_head = n_head 50 | self.f_in = f_in 51 | self.f_out = f_out 52 | self.diag = diag 53 | if self.diag: 54 | self.w = Parameter(torch.Tensor(n_head, 1, f_out)) 55 | else: 56 | self.w = Parameter(torch.Tensor(n_head, f_in, f_out)) 57 | self.a_src_dst = Parameter(torch.Tensor(n_head, f_out * 2, 1)) 58 | self.attn_dropout = attn_dropout 59 | self.leaky_relu = nn.LeakyReLU(negative_slope=0.2) 60 | self.special_spmm = SpecialSpmm() 61 | if bias: 62 | self.bias = Parameter(torch.Tensor(f_out)) 63 | nn.init.constant_(self.bias, 0) 64 | else: 65 | self.register_parameter("bias", None) 66 | if init is not None and diag: 67 | init(self.w) 68 | stdv = 1.0 / math.sqrt(self.a_src_dst.size(1)) 69 | nn.init.uniform_(self.a_src_dst, -stdv, stdv) 70 | else: 71 | nn.init.xavier_uniform_(self.w) 72 | nn.init.xavier_uniform_(self.a_src_dst) 73 | 74 | def forward(self, input, adj): 75 | output = [] 76 | for i in range(self.n_head): 77 | N = input.size()[0] 78 | edge = adj._indices() 79 | if self.diag: 80 | h = torch.mul(input, self.w[i]) 81 | else: 82 | h = torch.mm(input, self.w[i]) 83 | 84 | edge_h = torch.cat( 85 | (h[edge[0, :], :], h[edge[1, :], :]), dim=1 86 | ) # edge: 2*D x E 87 | edge_e = torch.exp( 88 | -self.leaky_relu(edge_h.mm(self.a_src_dst[i]).squeeze()) 89 | ) # edge_e: 1 x E 90 | 91 | e_rowsum = self.special_spmm( 92 | edge, 93 | edge_e, 94 | torch.Size([N, N]), 95 | ( 96 | torch.ones(size=(N, 1)).cuda() 97 | if next(self.parameters()).is_cuda 98 | else torch.ones(size=(N, 1)) 99 | ), 100 | ) # e_rowsum: N x 1 101 | edge_e = F.dropout( 102 | edge_e, self.attn_dropout, training=self.training 103 | ) # edge_e: 1 x E 104 | 105 | h_prime = self.special_spmm(edge, edge_e, torch.Size([N, N]), h) 106 | h_prime = h_prime.div(e_rowsum) 107 | 108 | output.append(h_prime.unsqueeze(0)) 109 | 110 | output = torch.cat(output, dim=0) 111 | if self.bias is not None: 112 | return output + self.bias 113 | else: 114 | return output 115 | 116 | def __repr__(self): 117 | if self.diag: 118 | return ( 119 | self.__class__.__name__ 120 | + " (" 121 | + str(self.f_out) 122 | + " -> " 123 | + str(self.f_out) 124 | + ") * " 125 | + str(self.n_head) 126 | + " heads" 127 | ) 128 | else: 129 | return ( 130 | self.__class__.__name__ 131 | + " (" 132 | + str(self.f_in) 133 | + " -> " 134 | + str(self.f_out) 135 | + ") * " 136 | + str(self.n_head) 137 | + " heads" 138 | ) 139 | 140 | 141 | class GraphConvolution(Module): 142 | def __init__(self, in_features, out_features, bias=True): 143 | super(GraphConvolution, self).__init__() 144 | self.in_features = in_features 145 | self.out_features = out_features 146 | self.weight = Parameter(torch.FloatTensor(in_features, out_features)) 147 | if bias: 148 | self.bias = Parameter(torch.FloatTensor(out_features)) 149 | else: 150 | self.register_parameter("bias", None) 151 | self.reset_parameters() 152 | 153 | def reset_parameters(self): 154 | stdv = 1.0 / math.sqrt(self.weight.size(1)) 155 | self.weight.data.uniform_(-stdv, stdv) 156 | if self.bias is not None: 157 | self.bias.data.uniform_(-stdv, stdv) 158 | 159 | def forward(self, input, adj): 160 | support = torch.mm(input, self.weight) 161 | output = torch.spmm(adj, support) 162 | if self.bias is not None: 
163 | return output + self.bias 164 | else: 165 | return output 166 | 167 | def __repr__(self): 168 | return ( 169 | self.__class__.__name__ 170 | + " (" 171 | + str(self.in_features) 172 | + " -> " 173 | + str(self.out_features) 174 | + ")" 175 | ) 176 | 177 | 178 | class ProjectionHead(nn.Module): 179 | def __init__(self, in_dim, hidden_dim, out_dim, dropout): 180 | super(ProjectionHead, self).__init__() 181 | self.l1 = nn.Linear(in_dim, hidden_dim, bias=False) 182 | self.l2 = nn.Linear(hidden_dim, out_dim, bias=False) 183 | self.dropout = dropout 184 | 185 | def forward(self, x): 186 | x = self.l1(x) 187 | x = F.relu(x) 188 | x = F.dropout(x, self.dropout, training=self.training) 189 | x = self.l2(x) 190 | return x 191 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from __future__ import absolute_import 5 | from __future__ import unicode_literals 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import os, time, multiprocessing 10 | import math 11 | import random 12 | import numpy as np 13 | import scipy 14 | import scipy.sparse as sp 15 | import torch 16 | import gc 17 | from tqdm import tqdm 18 | import json 19 | from torch.utils.data import Dataset 20 | 21 | import torch.optim as optim 22 | 23 | 24 | def normalize_adj(mx): 25 | rowsum = np.array(mx.sum(1)) 26 | r_inv_sqrt = np.power(rowsum, -0.5).flatten() 27 | r_inv_sqrt[np.isinf(r_inv_sqrt)] = 0.0 28 | r_mat_inv_sqrt = sp.diags(r_inv_sqrt) 29 | return mx.dot(r_mat_inv_sqrt).transpose().dot(r_mat_inv_sqrt) 30 | 31 | def normalize_features(mx): 32 | rowsum = np.array(mx.sum(1)) 33 | r_inv = np.power(rowsum, -1).flatten() 34 | r_inv[np.isinf(r_inv)] = 0.0 35 | r_mat_inv = sp.diags(r_inv) 36 | mx = r_mat_inv.dot(mx) 37 | return mx 38 | 39 | 40 | def sparse_mx_to_torch_sparse_tensor(sparse_mx): 41 | sparse_mx = sparse_mx.tocoo().astype(np.float32) 42 | indices = torch.from_numpy( 43 | np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64) 44 | ) 45 | values = torch.FloatTensor(sparse_mx.data) 46 | shape = torch.Size(sparse_mx.shape) 47 | return torch.sparse.FloatTensor(indices, values, shape) 48 | 49 | 50 | def read_raw_data(file_dir, l=[1, 2], reverse=False): 51 | print("loading raw data...") 52 | def read_file(file_paths): 53 | tups = [] 54 | for file_path in file_paths: 55 | with open(file_path, "r", encoding="utf-8") as fr: 56 | for line in fr: 57 | params = line.strip("\n").split("\t") 58 | tups.append(tuple([int(x) for x in params])) 59 | return tups 60 | 61 | def read_dict(file_paths): 62 | ent2id_dict = {} 63 | ids = [] 64 | for file_path in file_paths: 65 | id = set() 66 | with open(file_path, "r", encoding="utf-8") as fr: 67 | for line in fr: 68 | params = line.strip("\n").split("\t") 69 | ent2id_dict[params[1]] = int(params[0]) 70 | id.add(int(params[0])) 71 | ids.append(id) 72 | return ent2id_dict, ids 73 | 74 | ent2id_dict, ids = read_dict([file_dir + "/ent_ids_" + str(i) for i in l]) 75 | ills = read_file([file_dir + "/ill_ent_ids"]) 76 | triples = read_file([file_dir + "/triples_" + str(i) for i in l]) 77 | rel_size = max([t[1] for t in triples]) + 1 78 | reverse_triples = [] 79 | r_hs, r_ts = {}, {} 80 | for h, r, t in triples: 81 | if r not in r_hs: 82 | r_hs[r] = set() 83 | if r not in r_ts: 84 | r_ts[r] = set() 85 | r_hs[r].add(h) 86 | r_ts[r].add(t) 87 | if reverse: 88 | reverse_r = r + 
rel_size 89 | reverse_triples.append((t, reverse_r, h)) 90 | if reverse_r not in r_hs: 91 | r_hs[reverse_r] = set() 92 | if reverse_r not in r_ts: 93 | r_ts[reverse_r] = set() 94 | r_hs[reverse_r].add(t) 95 | r_ts[reverse_r].add(h) 96 | if reverse: 97 | triples = triples + reverse_triples 98 | assert len(r_hs) == len(r_ts) 99 | return ent2id_dict, ills, triples, r_hs, r_ts, ids 100 | 101 | 102 | def div_list(ls, n): 103 | ls_len = len(ls) 104 | if n <= 0 or 0 == ls_len: 105 | return [] 106 | if n > ls_len: 107 | return [] 108 | elif n == ls_len: 109 | return [[i] for i in ls] 110 | else: 111 | j = ls_len // n 112 | k = ls_len % n 113 | ls_return = [] 114 | for i in range(0, (n - 1) * j, j): 115 | ls_return.append(ls[i : i + j]) 116 | ls_return.append(ls[(n - 1) * j :]) 117 | return ls_return 118 | 119 | 120 | def get_adjr(ent_size, triples, norm=False): 121 | print("getting a sparse tensor r_adj...") 122 | M = {} 123 | for tri in triples: 124 | if tri[0] == tri[2]: 125 | continue 126 | if (tri[0], tri[2]) not in M: 127 | M[(tri[0], tri[2])] = 0 128 | M[(tri[0], tri[2])] += 1 129 | ind, val = [], [] 130 | for fir, sec in M: 131 | ind.append((fir, sec)) 132 | ind.append((sec, fir)) # 关系逆 133 | val.append(M[(fir, sec)]) 134 | val.append(M[(fir, sec)]) 135 | for i in range(ent_size): 136 | ind.append((i, i)) 137 | val.append(1) 138 | if norm: 139 | ind = np.array(ind, dtype=np.int32) 140 | val = np.array(val, dtype=np.float32) 141 | adj = sp.coo_matrix( 142 | (val, (ind[:, 0], ind[:, 1])), shape=(ent_size, ent_size), dtype=np.float32 143 | ) 144 | return sparse_mx_to_torch_sparse_tensor(normalize_adj(adj)) 145 | else: 146 | M = torch.sparse_coo_tensor( 147 | torch.LongTensor(ind).t(), 148 | torch.FloatTensor(val), 149 | torch.Size([ent_size, ent_size]), 150 | ) 151 | return M 152 | 153 | def pairwise_distances(x, y=None): 154 | x_norm = (x**2).sum(1).view(-1, 1) 155 | if y is not None: 156 | y_norm = (y**2).sum(1).view(1, -1) 157 | else: 158 | y = x 159 | y_norm = x_norm.view(1, -1) 160 | dist = x_norm + y_norm - 2.0 * torch.mm(x, torch.transpose(y, 0, 1)) 161 | return torch.clamp(dist, 0.0, np.inf) 162 | 163 | 164 | def multi_cal_rank(task, sim, top_k, l_or_r): 165 | mean = 0 166 | mrr = 0 167 | num = [0 for k in top_k] 168 | for i in range(len(task)): 169 | ref = task[i] 170 | if l_or_r == 0: 171 | rank = (sim[i, :]).argsort() 172 | else: 173 | rank = (sim[:, i]).argsort() 174 | assert ref in rank 175 | rank_index = np.where(rank == ref)[0][0] 176 | mean += rank_index + 1 177 | mrr += 1.0 / (rank_index + 1) 178 | for j in range(len(top_k)): 179 | if rank_index < top_k[j]: 180 | num[j] += 1 181 | return mean, num, mrr 182 | 183 | 184 | def multi_get_hits(Lvec, Rvec, top_k=(1, 5, 10, 50, 100), args=None): 185 | result = [] 186 | sim = pairwise_distances(torch.FloatTensor(Lvec), torch.FloatTensor(Rvec)).numpy() 187 | if args.csls is True: 188 | sim = 1 - csls_sim(1 - sim, args.csls_k) 189 | for i in [0, 1]: 190 | top_total = np.array([0] * len(top_k)) 191 | mean_total, mrr_total = 0.0, 0.0 192 | s_len = Lvec.shape[0] if i == 0 else Rvec.shape[0] 193 | tasks = div_list(np.array(range(s_len)), 10) 194 | pool = multiprocessing.Pool(processes=len(tasks)) 195 | reses = list() 196 | for task in tasks: 197 | if i == 0: 198 | reses.append( 199 | pool.apply_async(multi_cal_rank, (task, sim[task, :], top_k, i)) 200 | ) 201 | else: 202 | reses.append( 203 | pool.apply_async(multi_cal_rank, (task, sim[:, task], top_k, i)) 204 | ) 205 | pool.close() 206 | pool.join() 207 | for res in reses: 208 | mean, num, 
mrr = res.get() 209 | mean_total += mean 210 | mrr_total += mrr 211 | top_total += np.array(num) 212 | acc_total = top_total / s_len 213 | for i in range(len(acc_total)): 214 | acc_total[i] = round(acc_total[i], 4) 215 | mean_total /= s_len 216 | mrr_total /= s_len 217 | result.append(acc_total) 218 | result.append(mean_total) 219 | result.append(mrr_total) 220 | return result 221 | 222 | 223 | def csls_sim(sim_mat, k): 224 | nearest_values1 = torch.mean(torch.topk(sim_mat, k)[0], 1) 225 | nearest_values2 = torch.mean(torch.topk(sim_mat.t(), k)[0], 1) 226 | csls_sim_mat = 2 * sim_mat.t() - nearest_values1 227 | csls_sim_mat = csls_sim_mat.t() - nearest_values2 228 | return csls_sim_mat 229 | 230 | 231 | def get_topk_indices(M, K=1000): 232 | H, W = M.shape 233 | M_view = M.view(-1) 234 | vals, indices = M_view.topk(K) 235 | print("highest sim:", vals[0].item(), "lowest sim:", vals[-1].item()) 236 | two_d_indices = torch.cat( 237 | ((indices // W).unsqueeze(1), (indices % W).unsqueeze(1)), dim=1 238 | ) 239 | return two_d_indices 240 | 241 | 242 | def normalize_zero_one(A): 243 | A -= A.min(1, keepdim=True)[0] 244 | A /= A.max(1, keepdim=True)[0] 245 | return A 246 | 247 | 248 | if __name__ == "__main__": 249 | pass 250 | -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from __future__ import absolute_import 5 | from __future__ import unicode_literals 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import random 13 | import numpy as np 14 | import debugpy 15 | 16 | try: 17 | from layers import * 18 | except: 19 | from src.layers import * 20 | 21 | 22 | class GAT(nn.Module): 23 | def __init__( 24 | self, n_units, n_heads, dropout, attn_dropout, instance_normalization, diag 25 | ): 26 | super(GAT, self).__init__() 27 | self.num_layer = len(n_units) - 1 28 | self.dropout = dropout 29 | self.inst_norm = instance_normalization 30 | if self.inst_norm: 31 | self.norm = nn.InstanceNorm1d(n_units[0], momentum=0.0, affine=True) 32 | self.layer_stack = nn.ModuleList() 33 | self.diag = diag 34 | for i in range(self.num_layer): 35 | f_in = n_units[i] * n_heads[i - 1] if i else n_units[i] 36 | self.layer_stack.append( 37 | MultiHeadGraphAttention( 38 | n_heads[i], 39 | f_in, 40 | n_units[i + 1], 41 | attn_dropout, 42 | diag, 43 | nn.init.ones_, 44 | False, 45 | ) 46 | ) 47 | 48 | def forward(self, x, adj): 49 | if self.inst_norm: 50 | x = self.norm(x) 51 | for i, gat_layer in enumerate(self.layer_stack): 52 | if i + 1 < self.num_layer: 53 | x = F.dropout(x, self.dropout, training=self.training) 54 | x = gat_layer(x, adj) 55 | if self.diag: 56 | x = x.mean(dim=0) 57 | if i + 1 < self.num_layer: 58 | if self.diag: 59 | x = F.elu(x) 60 | else: 61 | x = F.elu(x.transpose(0, 1).contiguous().view(adj.size(0), -1)) 62 | if not self.diag: 63 | x = x.mean(dim=0) 64 | 65 | return x 66 | 67 | 68 | class GCN(nn.Module): 69 | def __init__(self, nfeat, nhid, nout, dropout): 70 | super(GCN, self).__init__() 71 | 72 | self.gc1 = GraphConvolution(nfeat, nhid) 73 | self.gc2 = GraphConvolution(nhid, nout) 74 | self.dropout = dropout 75 | 76 | def forward(self, x, adj): 77 | x = F.relu(self.gc1(x, adj)) # change to leaky relu 78 | x = F.dropout(x, self.dropout, training=self.training) 79 | x = self.gc2(x, adj) 80 | # x 
= F.relu(x) 81 | return x 82 | 83 | 84 | def cosine_sim(im, s): 85 | """Cosine similarity between all the image and sentence pairs""" 86 | return im.mm(s.t()) 87 | 88 | 89 | def l2norm(X): 90 | """L2-normalize columns of X""" 91 | norm = torch.pow(X, 2).sum(dim=1, keepdim=True).sqrt() 92 | a = norm.expand_as(X) + 1e-8 93 | X = torch.div(X, a) 94 | return X 95 | 96 | 97 | class MultiModalFusion(nn.Module): 98 | def __init__(self, modal_num, with_weight=1): 99 | super().__init__() 100 | self.modal_num = modal_num 101 | self.requires_grad = True if with_weight > 0 else False 102 | self.weight = nn.Parameter( 103 | torch.ones((self.modal_num, 1)), requires_grad=self.requires_grad 104 | ) 105 | 106 | def forward(self, embs): 107 | assert len(embs) == self.modal_num 108 | weight_norm = F.softmax(self.weight, dim=0) 109 | embs = [ 110 | weight_norm[idx] * F.normalize(embs[idx]) 111 | for idx in range(self.modal_num) 112 | if embs[idx] is not None 113 | ] 114 | joint_emb = torch.cat(embs, dim=1) 115 | return joint_emb 116 | 117 | 118 | class MultiModalFusionNew_allmodal(nn.Module): 119 | def __init__(self, modal_num, with_weight=1): 120 | super().__init__() 121 | self.modal_num = modal_num 122 | self.requires_grad = True if with_weight > 0 else False 123 | self.weight = nn.Parameter( 124 | torch.ones((self.modal_num, 1)), requires_grad=self.requires_grad 125 | ) 126 | 127 | self.linear_0 = nn.Linear(100, 600) 128 | self.linear_1 = nn.Linear(100, 600) 129 | self.linear_2 = nn.Linear(100, 600) 130 | self.linear_3 = nn.Linear(300, 600) 131 | self.linear = nn.Linear(600, 600) 132 | self.v = nn.Linear(600, 1, bias=False) 133 | 134 | self.LN_pre = nn.LayerNorm(600) 135 | self.LN_pre = nn.LayerNorm(600) 136 | 137 | def forward(self, embs): 138 | assert len(embs) == self.modal_num 139 | 140 | emb_list = [] 141 | 142 | if embs[0] is not None: 143 | emb_list.append(self.linear_0(embs[0]).unsqueeze(1)) 144 | if embs[1] is not None: 145 | emb_list.append(self.linear_1(embs[1]).unsqueeze(1)) 146 | if embs[2] is not None: 147 | emb_list.append(self.linear_2(embs[2]).unsqueeze(1)) 148 | if embs[3] is not None: 149 | emb_list.append(self.linear_3(embs[3]).unsqueeze(1)) 150 | new_embs = torch.cat(emb_list, dim=1) # [n, 4, e] 151 | new_embs = self.LN_pre(new_embs) 152 | s = self.v(torch.tanh(self.linear(new_embs))) # [n, 4, 1] 153 | a = torch.softmax(s, dim=-1) 154 | joint_emb_1 = torch.matmul(a.transpose(-1, -2), new_embs).squeeze(1) # [n, e] 155 | joint_emb = joint_emb_1 156 | return joint_emb 157 | 158 | 159 | class IBMultiModal(nn.Module): 160 | def __init__( 161 | self, 162 | args, 163 | ent_num, 164 | img_feature_dim, 165 | char_feature_dim=None, 166 | use_project_head=False, 167 | ): 168 | super(IBMultiModal, self).__init__() 169 | 170 | self.args = args 171 | attr_dim = self.args.attr_dim 172 | img_dim = self.args.img_dim 173 | char_dim = self.args.char_dim 174 | dropout = self.args.dropout 175 | self.ENT_NUM = ent_num 176 | self.use_project_head = use_project_head 177 | 178 | self.n_units = [int(x) for x in self.args.hidden_units.strip().split(",")] 179 | self.n_heads = [int(x) for x in self.args.heads.strip().split(",")] 180 | self.input_dim = int(self.args.hidden_units.strip().split(",")[0]) 181 | self.entity_emb = nn.Embedding(self.ENT_NUM, self.input_dim) 182 | nn.init.normal_(self.entity_emb.weight, std=1.0 / math.sqrt(self.ENT_NUM)) 183 | self.entity_emb.requires_grad = True 184 | self.rel_fc = nn.Linear(1000, attr_dim) 185 | self.rel_fc_mu = nn.Linear(attr_dim, attr_dim) 186 | self.rel_fc_std = 
nn.Linear(attr_dim, attr_dim) 187 | self.rel_fc_d1 = nn.Linear(attr_dim, attr_dim) 188 | self.rel_fc_d2 = nn.Linear(attr_dim, attr_dim) 189 | 190 | self.att_fc = nn.Linear(1000, attr_dim) 191 | self.att_fc_mu = nn.Linear(attr_dim, attr_dim) 192 | self.att_fc_std = nn.Linear(attr_dim, attr_dim) 193 | self.att_fc_d1 = nn.Linear(attr_dim, attr_dim) 194 | self.att_fc_d2 = nn.Linear(attr_dim, attr_dim) 195 | 196 | self.img_fc = nn.Linear(img_feature_dim, img_dim) 197 | self.img_fc_mu = nn.Linear(img_dim, img_dim) 198 | self.img_fc_std = nn.Linear(img_dim, img_dim) 199 | self.img_fc_d1 = nn.Linear(img_dim, img_dim) 200 | self.img_fc_d2 = nn.Linear(img_dim, img_dim) 201 | 202 | joint_dim = 600 203 | self.joint_fc = nn.Linear(joint_dim, joint_dim) 204 | self.joint_fc_mu = nn.Linear(joint_dim, joint_dim) 205 | self.joint_fc_std = nn.Linear(joint_dim, joint_dim) 206 | self.joint_fc_d1 = nn.Linear(joint_dim, joint_dim) 207 | self.joint_fc_d2 = nn.Linear(joint_dim, joint_dim) 208 | 209 | use_graph_vib = self.args.use_graph_vib 210 | use_attr_vib = self.args.use_attr_vib 211 | use_img_vib = self.args.use_img_vib 212 | use_rel_vib = self.args.use_rel_vib 213 | 214 | no_diag = self.args.no_diag 215 | 216 | if no_diag: 217 | diag = False 218 | else: 219 | diag = True 220 | 221 | self.use_graph_vib = use_graph_vib 222 | self.use_attr_vib = use_attr_vib 223 | self.use_img_vib = use_img_vib 224 | self.use_rel_vib = use_rel_vib 225 | self.use_joint_vib = self.args.use_joint_vib 226 | 227 | self.name_fc = nn.Linear(300, char_dim) 228 | self.char_fc = nn.Linear(char_feature_dim, char_dim) 229 | 230 | self.kld_loss = 0 231 | self.gph_layer_norm_mu = nn.LayerNorm(self.input_dim, elementwise_affine=True) 232 | self.gph_layer_norm_std = nn.LayerNorm(self.input_dim, elementwise_affine=True) 233 | if self.args.structure_encoder == "gcn": 234 | if self.use_graph_vib: 235 | self.cross_graph_model_mu = GCN( 236 | self.n_units[0], 237 | self.n_units[1], 238 | self.n_units[2], 239 | dropout=self.args.dropout, 240 | ) 241 | self.cross_graph_model_std = GCN( 242 | self.n_units[0], 243 | self.n_units[1], 244 | self.n_units[2], 245 | dropout=self.args.dropout, 246 | ) 247 | else: 248 | self.cross_graph_model = GCN( 249 | self.n_units[0], 250 | self.n_units[1], 251 | self.n_units[2], 252 | dropout=self.args.dropout, 253 | ) 254 | elif self.args.structure_encoder == "gat": 255 | if self.use_graph_vib: 256 | self.cross_graph_model_mu = GAT( 257 | n_units=self.n_units, 258 | n_heads=self.n_heads, 259 | dropout=args.dropout, 260 | attn_dropout=args.attn_dropout, 261 | instance_normalization=self.args.instance_normalization, 262 | diag=diag, 263 | ) 264 | self.cross_graph_model_std = GAT( 265 | n_units=self.n_units, 266 | n_heads=self.n_heads, 267 | dropout=args.dropout, 268 | attn_dropout=args.attn_dropout, 269 | instance_normalization=self.args.instance_normalization, 270 | diag=diag, 271 | ) 272 | else: 273 | self.cross_graph_model = GAT( 274 | n_units=self.n_units, 275 | n_heads=self.n_heads, 276 | dropout=args.dropout, 277 | attn_dropout=args.attn_dropout, 278 | instance_normalization=self.args.instance_normalization, 279 | diag=True, 280 | ) 281 | 282 | if self.use_project_head: 283 | self.img_pro = ProjectionHead(img_dim, img_dim, img_dim, dropout) 284 | self.att_pro = ProjectionHead(attr_dim, attr_dim, attr_dim, dropout) 285 | self.rel_pro = ProjectionHead(attr_dim, attr_dim, attr_dim, dropout) 286 | self.gph_pro = ProjectionHead( 287 | self.n_units[2], self.n_units[2], self.n_units[2], dropout 288 | ) 289 | 290 | if 
self.args.fusion_id == 1: 291 | self.fusion = MultiModalFusion( 292 | modal_num=self.args.inner_view_num, with_weight=self.args.with_weight 293 | ) 294 | elif self.args.fusion_id == 2: 295 | self.fusion = MultiModalFusionNew_allmodal( 296 | modal_num=self.args.inner_view_num, with_weight=self.args.with_weight 297 | ) 298 | 299 | def _kld_gauss(self, mu_1, logsigma_1, mu_2, logsigma_2): 300 | from torch.distributions.kl import kl_divergence 301 | from torch.distributions import Normal 302 | 303 | sigma_1 = torch.exp(0.1 + 0.9 * F.softplus(torch.clamp_max(logsigma_1, 0.4))) 304 | sigma_2 = torch.exp(0.1 + 0.9 * F.softplus(torch.clamp_max(logsigma_2, 0.4))) 305 | mu_1_fixed = mu_1.clone() 306 | sigma_1_fixed = sigma_1.clone() 307 | mu_1_fixed[torch.isnan(mu_1_fixed)] = 0 308 | mu_1_fixed[torch.isinf(mu_1_fixed)] = torch.max( 309 | mu_1_fixed[~torch.isinf(mu_1_fixed)] 310 | ) 311 | sigma_1_fixed[torch.isnan(sigma_1_fixed)] = 1 312 | sigma_1_fixed[torch.isinf(sigma_1_fixed)] = torch.max( 313 | sigma_1_fixed[~torch.isinf(sigma_1_fixed)] 314 | ) 315 | sigma_1_fixed[sigma_1_fixed <= 0] = 1 316 | q_target = Normal(mu_1_fixed, sigma_1_fixed) 317 | 318 | mu_2_fixed = mu_2.clone() 319 | sigma_2_fixed = sigma_2.clone() 320 | mu_2_fixed[torch.isnan(mu_2_fixed)] = 0 321 | mu_2_fixed[torch.isinf(mu_2_fixed)] = torch.max( 322 | mu_2_fixed[~torch.isinf(mu_2_fixed)] 323 | ) 324 | sigma_2_fixed[torch.isnan(sigma_2_fixed)] = 1 325 | sigma_2_fixed[torch.isinf(sigma_2_fixed)] = torch.max( 326 | sigma_2_fixed[~torch.isinf(sigma_2_fixed)] 327 | ) 328 | sigma_2_fixed[sigma_2_fixed <= 0] = 1 329 | q_context = Normal(mu_2_fixed, sigma_2_fixed) 330 | kl = kl_divergence(q_target, q_context).mean(dim=0).sum() 331 | return kl 332 | 333 | def forward( 334 | self, 335 | input_idx, 336 | adj, 337 | img_features=None, 338 | rel_features=None, 339 | att_features=None, 340 | name_features=None, 341 | char_features=None, 342 | ): 343 | 344 | if self.args.w_gcn: 345 | if self.use_graph_vib: 346 | if self.args.structure_encoder == "gat": 347 | gph_emb_mu = self.cross_graph_model_mu( 348 | self.entity_emb(input_idx), adj 349 | ) 350 | mu = self.gph_layer_norm_mu(gph_emb_mu) 351 | 352 | gph_emb_std = self.cross_graph_model_std( 353 | self.entity_emb(input_idx), adj 354 | ) 355 | std = F.elu(gph_emb_std) 356 | eps = torch.randn_like(std) 357 | gph_emb = mu + eps * std 358 | gph_kld_loss = self._kld_gauss( 359 | mu, std, torch.zeros_like(mu), torch.ones_like(std) 360 | ) 361 | self.kld_loss = gph_kld_loss 362 | else: 363 | mu = self.cross_graph_model_mu(self.entity_emb(input_idx), adj) 364 | logstd = self.cross_graph_model_mu(self.entity_emb(input_idx), adj) 365 | eps = torch.randn_like(mu) 366 | gph_emb = mu + eps * torch.exp(logstd) 367 | gph_kld_loss = self._kld_gauss( 368 | mu, logstd, torch.zeros_like(mu), torch.ones_like(logstd) 369 | ) 370 | self.kld_loss = gph_kld_loss 371 | else: 372 | gph_emb = self.cross_graph_model(self.entity_emb(input_idx), adj) 373 | else: 374 | gph_emb = None 375 | if self.args.w_img: 376 | if self.use_img_vib: 377 | img_emb = self.img_fc(img_features) 378 | img_emb_h = F.relu(img_emb) 379 | mu = self.img_fc_mu(img_emb_h) 380 | logvar = self.img_fc_std(img_emb_h) 381 | std = torch.exp(0.5 * logvar) 382 | eps = torch.rand_like(std) 383 | img_emb = mu + eps * std 384 | img_kld_loss = self._kld_gauss( 385 | mu, std, torch.zeros_like(mu), torch.ones_like(std) 386 | ) 387 | self.img_kld_loss = img_kld_loss 388 | else: 389 | img_emb = self.img_fc(img_features) 390 | else: 391 | img_emb = None 392 | if 
self.args.w_rel: 393 | if self.use_rel_vib: 394 | rel_emb = self.rel_fc(rel_features) 395 | rel_emb_h = F.relu(rel_emb) 396 | mu = self.rel_fc_mu(rel_emb_h) 397 | logvar = self.rel_fc_std(rel_emb_h) 398 | std = torch.exp(0.5 * logvar) 399 | eps = torch.rand_like(std) 400 | rel_emb = mu + eps * std 401 | rel_kld_loss = self._kld_gauss( 402 | mu, std, torch.zeros_like(mu), torch.ones_like(std) 403 | ) 404 | self.rel_kld_loss = rel_kld_loss 405 | else: 406 | rel_emb = self.rel_fc(rel_features) 407 | else: 408 | rel_emb = None 409 | if self.args.w_attr: 410 | if self.use_attr_vib: 411 | att_emb = self.att_fc(att_features) 412 | att_emb_h = F.relu(att_emb) 413 | mu = self.att_fc_mu(att_emb_h) 414 | logvar = self.att_fc_std(att_emb_h) 415 | std = torch.exp(0.5 * logvar) 416 | eps = torch.rand_like(std) 417 | att_emb = mu + eps * std 418 | attr_kld_loss = self._kld_gauss( 419 | mu, std, torch.zeros_like(mu), torch.ones_like(std) 420 | ) 421 | self.attr_kld_loss = attr_kld_loss 422 | else: 423 | att_emb = self.att_fc(att_features) 424 | else: 425 | att_emb = None 426 | 427 | if self.args.w_name: 428 | name_emb = self.name_fc(name_features) 429 | else: 430 | name_emb = None 431 | if self.args.w_char: 432 | char_emb = self.char_fc(char_features) 433 | else: 434 | char_emb = None 435 | 436 | if self.use_project_head: 437 | gph_emb = self.gph_pro(gph_emb) 438 | img_emb = self.img_pro(img_emb) 439 | rel_emb = self.rel_pro(rel_emb) 440 | att_emb = self.att_pro(att_emb) 441 | pass 442 | 443 | joint_emb = self.fusion( 444 | [img_emb, att_emb, rel_emb, gph_emb, name_emb, char_emb] 445 | ) 446 | 447 | if self.use_joint_vib: 448 | joint_emb = self.joint_fc(joint_emb) 449 | joint_emb_h = F.relu(joint_emb) 450 | mu = self.joint_fc_mu(joint_emb_h) 451 | logvar = self.joint_fc_std(joint_emb_h) 452 | std = torch.exp(0.5 * logvar) 453 | eps = torch.rand_like(std) 454 | joint_emb = mu + eps * std 455 | joint_kld_loss = self._kld_gauss( 456 | mu, std, torch.zeros_like(mu), torch.ones_like(std) 457 | ) 458 | self.joint_kld_loss = joint_kld_loss 459 | 460 | return gph_emb, img_emb, rel_emb, att_emb, name_emb, char_emb, joint_emb 461 | -------------------------------------------------------------------------------- /src/run.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from __future__ import absolute_import 4 | from __future__ import unicode_literals 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import argparse 9 | from pprint import pprint 10 | from transformers import ( 11 | get_cosine_schedule_with_warmup, 12 | ) 13 | import torch.optim as optim 14 | 15 | try: 16 | from utils import * 17 | from models import * 18 | from Load import * 19 | from loss import * 20 | except: 21 | from src.utils import * 22 | from src.models import * 23 | from src.Load import * 24 | from src.loss import * 25 | 26 | 27 | def load_img_features(ent_num, file_dir, triples, use_mean_img=False): 28 | if "V1" in file_dir: 29 | split = "norm" 30 | img_vec_path = "data/pkls/dbpedia_wikidata_15k_norm_GA_id_img_feature_dict.pkl" 31 | elif "V2" in file_dir: 32 | split = "dense" 33 | img_vec_path = "data/pkls/dbpedia_wikidata_15k_dense_GA_id_img_feature_dict.pkl" 34 | elif "FB15K" in file_dir: 35 | filename = os.path.split(file_dir)[-1].upper() 36 | img_vec_path = ( 37 | "data/mmkb-datasets/" 38 | + filename 39 | + "/" 40 | + filename 41 | + "_id_img_feature_dict.pkl" 42 | ) 43 | else: 44 | split = 
file_dir.split("/")[-1] 45 | img_vec_path = "data/pkls/" + split + "_GA_id_img_feature_dict.pkl" 46 | if use_mean_img: 47 | img_features = load_img(ent_num, img_vec_path) 48 | else: 49 | img_features = load_img_new(ent_num, img_vec_path, triples) 50 | return img_features 51 | 52 | 53 | def load_img_features_dropout( 54 | ent_num, file_dir, triples, use_mean_img=False, img_dp_ratio=1.0 55 | ): 56 | if "FB15K" in file_dir: 57 | filename = os.path.split(file_dir)[-1].upper() 58 | if abs(1.0 - img_dp_ratio) > 1e-3: 59 | img_vec_path = ( 60 | "data/mmkb-datasets/" 61 | + "mmkb_dropout" 62 | + "/" 63 | + filename 64 | + "_id_img_feature_dict_with_dropout{0}.pkl".format(img_dp_ratio) 65 | ) 66 | print("dropout img_vec_path: ", img_vec_path) 67 | else: 68 | img_vec_path = ( 69 | "data/mmkb-datasets/" 70 | + filename 71 | + "/" 72 | + filename 73 | + "_id_img_feature_dict.pkl" 74 | ) 75 | else: 76 | split = file_dir.split("/")[-1] 77 | if abs(1.0 - img_dp_ratio) > 1e-3: 78 | img_vec_path = ( 79 | "data/pkls/dbp_dropout/" 80 | + split 81 | + "_GA_id_img_feature_dict_with_dropout{0}.pkl".format(img_dp_ratio) 82 | ) 83 | else: 84 | img_vec_path = "data/pkls/" + split + "_GA_id_img_feature_dict.pkl" 85 | if use_mean_img: 86 | img_features = load_img(ent_num, img_vec_path) 87 | else: 88 | img_features = load_img_new(ent_num, img_vec_path, triples) 89 | return img_features 90 | 91 | 92 | class IBMEA: 93 | 94 | def __init__(self): 95 | 96 | self.ent2id_dict = None 97 | self.ills = None 98 | self.triples = None 99 | self.r_hs = None 100 | self.r_ts = None 101 | self.ids = None 102 | self.left_ents = None 103 | self.right_ents = None 104 | 105 | self.img_features = None 106 | self.rel_features = None 107 | self.att_features = None 108 | self.char_features = None 109 | self.name_features = None 110 | self.ent_vec = None 111 | self.left_non_train = None 112 | self.right_non_train = None 113 | self.ENT_NUM = None 114 | self.REL_NUM = None 115 | self.adj = None 116 | self.train_ill = None 117 | self.test_ill_ = None 118 | self.test_ill = None 119 | self.test_left = None 120 | self.test_right = None 121 | self.multimodal_encoder = None 122 | self.weight_raw = None 123 | self.rel_fc = None 124 | self.att_fc = None 125 | self.img_fc = None 126 | self.char_fc = None 127 | self.shared_fc = None 128 | self.gcn_pro = None 129 | self.rel_pro = None 130 | self.attr_pro = None 131 | self.img_pro = None 132 | self.input_dim = None 133 | self.entity_emb = None 134 | self.input_idx = None 135 | self.n_units = None 136 | self.n_heads = None 137 | self.cross_graph_model = None 138 | self.params = None 139 | self.optimizer = None 140 | self.fusion = None 141 | self.parser = argparse.ArgumentParser() 142 | self.args = self.parse_options(self.parser) 143 | self.set_seed(self.args.seed, self.args.cuda) 144 | self.device = torch.device( 145 | "cuda" if self.args.cuda and torch.cuda.is_available() else "cpu" 146 | ) 147 | self.init_data() 148 | self.init_model() 149 | self.print_summary() 150 | self.best_hit_1 = 0.0 151 | self.best_epoch = 0 152 | self.best_data_list = [] 153 | self.best_to_write = [] 154 | 155 | @staticmethod 156 | def parse_options(parser): 157 | parser.add_argument( 158 | "--file_dir", 159 | type=str, 160 | default="data/DBP15K/zh_en", 161 | required=False, 162 | help="input dataset file directory, ('data/DBP15K/zh_en', 'data/DWY100K/dbp_wd')", 163 | ) 164 | parser.add_argument("--rate", type=float, default=0.3, help="training set rate") 165 | 166 | parser.add_argument( 167 | "--cuda", 168 | action="store_true", 
169 | default=True, 170 | help="whether to use cuda or not", 171 | ) 172 | parser.add_argument("--seed", type=int, default=2021, help="random seed") 173 | parser.add_argument( 174 | "--epochs", type=int, default=1000, help="number of epochs to train" 175 | ) 176 | parser.add_argument("--check_point", type=int, default=100, help="check point") 177 | parser.add_argument( 178 | "--hidden_units", 179 | type=str, 180 | default="128,128,128", 181 | help="hidden units in each hidden layer(including in_dim and out_dim), splitted with comma", 182 | ) 183 | parser.add_argument( 184 | "--heads", 185 | type=str, 186 | default="2,2", 187 | help="heads in each gat layer, splitted with comma", 188 | ) 189 | parser.add_argument( 190 | "--instance_normalization", 191 | action="store_true", 192 | default=False, 193 | help="enable instance normalization", 194 | ) 195 | parser.add_argument( 196 | "--lr", type=float, default=0.005, help="initial learning rate" 197 | ) 198 | parser.add_argument( 199 | "--weight_decay", 200 | type=float, 201 | default=1e-2, 202 | help="weight decay (L2 loss on parameters)", 203 | ) 204 | parser.add_argument( 205 | "--dropout", type=float, default=0.0, help="dropout rate for layers" 206 | ) 207 | parser.add_argument( 208 | "--attn_dropout", 209 | type=float, 210 | default=0.0, 211 | help="dropout rate for gat layers", 212 | ) 213 | parser.add_argument( 214 | "--dist", type=int, default=2, help="L1 distance or L2 distance. ('1', '2')" 215 | ) 216 | parser.add_argument( 217 | "--csls", action="store_true", default=False, help="use CSLS for inference" 218 | ) 219 | parser.add_argument("--csls_k", type=int, default=10, help="top k for csls") 220 | parser.add_argument( 221 | "--il", action="store_true", default=False, help="Iterative learning?" 222 | ) 223 | parser.add_argument( 224 | "--semi_learn_step", 225 | type=int, 226 | default=10, 227 | help="If IL, what's the update step?", 228 | ) 229 | parser.add_argument( 230 | "--il_start", type=int, default=500, help="If Il, when to start?" 
231 | ) 232 | parser.add_argument("--bsize", type=int, default=7500, help="batch size") 233 | parser.add_argument( 234 | "--alpha", type=float, default=0.2, help="the margin of InfoMaxNCE loss" 235 | ) 236 | parser.add_argument( 237 | "--with_weight", 238 | type=int, 239 | default=1, 240 | help="Whether to weight the fusion of different " "modal features", 241 | ) 242 | parser.add_argument( 243 | "--structure_encoder", 244 | type=str, 245 | default="gat", 246 | help="the encoder of structure view, " "[gcn|gat]", 247 | ) 248 | 249 | parser.add_argument( 250 | "--projection", 251 | action="store_true", 252 | default=False, 253 | help="add projection for model", 254 | ) 255 | 256 | parser.add_argument( 257 | "--attr_dim", 258 | type=int, 259 | default=100, 260 | help="the hidden size of attr and rel features", 261 | ) 262 | parser.add_argument( 263 | "--img_dim", type=int, default=100, help="the hidden size of img feature" 264 | ) 265 | parser.add_argument( 266 | "--name_dim", type=int, default=100, help="the hidden size of name feature" 267 | ) 268 | parser.add_argument( 269 | "--char_dim", type=int, default=100, help="the hidden size of char feature" 270 | ) 271 | 272 | parser.add_argument( 273 | "--w_gcn", action="store_false", default=True, help="with gcn features" 274 | ) 275 | parser.add_argument( 276 | "--w_rel", action="store_false", default=True, help="with rel features" 277 | ) 278 | parser.add_argument( 279 | "--w_attr", action="store_false", default=True, help="with attr features" 280 | ) 281 | parser.add_argument( 282 | "--w_name", action="store_false", default=True, help="with name features" 283 | ) 284 | parser.add_argument( 285 | "--w_char", action="store_false", default=True, help="with char features" 286 | ) 287 | parser.add_argument( 288 | "--w_img", action="store_false", default=True, help="with img features" 289 | ) 290 | 291 | parser.add_argument( 292 | "--no_diag", action="store_true", default=False, help="GAT use diag" 293 | ) 294 | parser.add_argument( 295 | "--use_joint_vib", action="store_true", default=False, help="use_joint_vib" 296 | ) 297 | 298 | parser.add_argument( 299 | "--use_graph_vib", action="store_false", default=True, help="use_graph_vib" 300 | ) 301 | parser.add_argument( 302 | "--use_attr_vib", action="store_false", default=True, help="use_attr_vib" 303 | ) 304 | parser.add_argument( 305 | "--use_img_vib", action="store_false", default=True, help="use_img_vib" 306 | ) 307 | parser.add_argument( 308 | "--use_rel_vib", action="store_false", default=True, help="use_rel_vib" 309 | ) 310 | parser.add_argument( 311 | "--modify_ms", action="store_true", default=False, help="modify_ms" 312 | ) 313 | parser.add_argument("--ms_alpha", type=float, default=0.1, help="ms scale_pos") 314 | parser.add_argument("--ms_beta", type=float, default=40.0, help="ms scale_neg") 315 | parser.add_argument("--ms_base", type=float, default=0.5, help="ms base") 316 | 317 | parser.add_argument("--Beta_g", type=float, default=0.001, help="graph beta") 318 | parser.add_argument("--Beta_i", type=float, default=0.001, help="graph beta") 319 | parser.add_argument("--Beta_r", type=float, default=0.01, help="graph beta") 320 | parser.add_argument("--Beta_a", type=float, default=0.001, help="graph beta") 321 | parser.add_argument( 322 | "--inner_view_num", type=int, default=6, help="the number of inner view" 323 | ) 324 | 325 | parser.add_argument( 326 | "--word_embedding", 327 | type=str, 328 | default="glove", 329 | help="the type of word embedding, " "[glove|fasttext]", 330 | ) 331 | 
parser.add_argument( 332 | "--use_project_head", 333 | action="store_true", 334 | default=False, 335 | help="use projection head", 336 | ) 337 | 338 | parser.add_argument( 339 | "--zoom", type=float, default=0.1, help="narrow the range of losses" 340 | ) 341 | parser.add_argument( 342 | "--save_path", type=str, default="save_pkl", help="save path" 343 | ) 344 | parser.add_argument( 345 | "--pred_name", type=str, default="pred.txt", help="pred name" 346 | ) 347 | 348 | parser.add_argument( 349 | "--hidden_size", type=int, default=300, help="the hidden size of MEAformer" 350 | ) 351 | parser.add_argument( 352 | "--intermediate_size", 353 | type=int, 354 | default=400, 355 | help="the hidden size of MEAformer", 356 | ) 357 | parser.add_argument( 358 | "--num_attention_heads", 359 | type=int, 360 | default=1, 361 | help="the number of attention_heads of MEAformer", 362 | ) 363 | parser.add_argument( 364 | "--num_hidden_layers", 365 | type=int, 366 | default=1, 367 | help="the number of hidden_layers of MEAformer", 368 | ) 369 | parser.add_argument("--position_embedding_type", default="absolute", type=str) 370 | parser.add_argument( 371 | "--use_intermediate", 372 | type=int, 373 | default=0, 374 | help="whether to use_intermediate", 375 | ) 376 | parser.add_argument( 377 | "--replay", type=int, default=0, help="whether to use replay strategy" 378 | ) 379 | parser.add_argument( 380 | "--neg_cross_kg", 381 | type=int, 382 | default=0, 383 | help="whether to force the negative samples in the opposite KG", 384 | ) 385 | 386 | parser.add_argument( 387 | "--tau", 388 | type=float, 389 | default=0.1, 390 | help="the temperature factor of contrastive loss", 391 | ) 392 | parser.add_argument( 393 | "--ab_weight", type=float, default=0.5, help="the weight of NTXent Loss" 394 | ) 395 | parser.add_argument( 396 | "--use_icl", action="store_true", default=False, help="use_icl" 397 | ) 398 | parser.add_argument( 399 | "--use_bce", action="store_true", default=False, help="use_bce" 400 | ) 401 | parser.add_argument( 402 | "--use_bce_new", action="store_true", default=False, help="use_bce_new" 403 | ) 404 | parser.add_argument( 405 | "--use_ms", action="store_true", default=False, help="use_ms" 406 | ) 407 | parser.add_argument( 408 | "--use_nt", action="store_true", default=False, help="use_nt" 409 | ) 410 | parser.add_argument( 411 | "--use_nce", action="store_true", default=False, help="use_nce" 412 | ) 413 | parser.add_argument( 414 | "--use_nce_new", action="store_true", default=False, help="use_nce_new" 415 | ) 416 | 417 | parser.add_argument( 418 | "--use_sheduler", action="store_true", default=False, help="use_sheduler" 419 | ) 420 | parser.add_argument( 421 | "--sheduler_gamma", type=float, default=0.98, help="sheduler_gamma" 422 | ) 423 | 424 | parser.add_argument("--joint_beta", type=float, default=1.0, help="joint_beta") 425 | 426 | parser.add_argument( 427 | "--use_sheduler_cos", 428 | action="store_true", 429 | default=False, 430 | help="use_sheduler_cos", 431 | ) 432 | parser.add_argument( 433 | "--num_warmup_steps", type=int, default=200, help="num_warmup_steps" 434 | ) 435 | parser.add_argument( 436 | "--num_training_steps", type=int, default=1000, help="num_training_steps" 437 | ) 438 | 439 | parser.add_argument( 440 | "--use_mean_img", action="store_true", default=False, help="use_mean_img" 441 | ) 442 | parser.add_argument("--awloss", type=int, default=0, help="whether to use awl") 443 | parser.add_argument("--mtloss", type=int, default=0, help="whether to use awl") 444 | 445 | 
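# Per-modality objective switches: the --{graph,img,attr,rel}_use_{icl,bce,ms}
# options defined below are plain ints (default 1) that train() reads into local
# switches, one per objective and view (icl: in-batch contrastive, bce: binary
# cross-entropy, ms: multi-similarity); the --joint_use_* options are store_true
# flags playing the same role for the fused (joint) embedding.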
parser.add_argument( 446 | "--graph_use_icl", type=int, default=1, help="graph_use_icl" 447 | ) 448 | parser.add_argument( 449 | "--graph_use_bce", type=int, default=1, help="graph_use_bce" 450 | ) 451 | parser.add_argument("--graph_use_ms", type=int, default=1, help="graph_use_ms") 452 | 453 | parser.add_argument("--img_use_icl", type=int, default=1, help="img_use_icl") 454 | parser.add_argument("--img_use_bce", type=int, default=1, help="img_use_bce") 455 | parser.add_argument("--img_use_ms", type=int, default=1, help="img_use_ms") 456 | 457 | parser.add_argument("--attr_use_icl", type=int, default=1, help="attr_use_icl") 458 | parser.add_argument("--attr_use_bce", type=int, default=1, help="attr_use_bce") 459 | parser.add_argument("--attr_use_ms", type=int, default=1, help="attr_use_ms") 460 | 461 | parser.add_argument("--rel_use_icl", type=int, default=1, help="rel_use_icl") 462 | parser.add_argument("--rel_use_bce", type=int, default=1, help="rel_use_bce") 463 | parser.add_argument("--rel_use_ms", type=int, default=1, help="rel_use_ms") 464 | 465 | parser.add_argument( 466 | "--joint_use_icl", action="store_true", default=False, help="joint_use_icl" 467 | ) 468 | parser.add_argument( 469 | "--joint_use_bce", action="store_true", default=False, help="joint_use_bce" 470 | ) 471 | parser.add_argument( 472 | "--joint_use_ms", action="store_true", default=False, help="joint_use_ms" 473 | ) 474 | 475 | parser.add_argument( 476 | "--img_dp_ratio", type=float, default=1.0, help="image dropout ratio" 477 | ) 478 | parser.add_argument( 479 | "--fusion_id", type=int, default=1, help="default weight fusion" 480 | ) 481 | return parser.parse_args() 482 | 483 | @staticmethod 484 | def set_seed(seed, cuda=True): 485 | random.seed(seed) 486 | np.random.seed(seed) 487 | torch.manual_seed(seed) 488 | if cuda and torch.cuda.is_available(): 489 | torch.cuda.manual_seed(seed) 490 | 491 | def print_summary(self): 492 | print("-----dataset summary-----") 493 | print("dataset:\t", self.args.file_dir) 494 | print("triple num:\t", len(self.triples)) 495 | print("entity num:\t", self.ENT_NUM) 496 | print("relation num:\t", self.REL_NUM) 497 | print( 498 | "train ill num:\t", 499 | self.train_ill.shape[0], 500 | "\ttest ill num:\t", 501 | self.test_ill.shape[0], 502 | ) 503 | print("-------------------------") 504 | 505 | def init_data(self): 506 | # Load data 507 | lang_list = [1, 2] 508 | file_dir = self.args.file_dir 509 | device = self.device 510 | 511 | self.ent2id_dict, self.ills, self.triples, self.r_hs, self.r_ts, self.ids = ( 512 | read_raw_data(file_dir, lang_list) 513 | ) 514 | e1 = os.path.join(file_dir, "ent_ids_1") 515 | e2 = os.path.join(file_dir, "ent_ids_2") 516 | self.left_ents = get_ids(e1) 517 | self.right_ents = get_ids(e2) 518 | 519 | self.ENT_NUM = len(self.ent2id_dict) 520 | self.REL_NUM = len(self.r_hs) 521 | print("total ent num: {}, rel num: {}".format(self.ENT_NUM, self.REL_NUM)) 522 | 523 | np.random.shuffle(self.ills) 524 | 525 | if abs(1.0 - self.args.img_dp_ratio) > 1e-3: 526 | self.img_features = load_img_features_dropout( 527 | self.ENT_NUM, 528 | file_dir, 529 | self.triples, 530 | self.args.use_mean_img, 531 | self.args.img_dp_ratio, 532 | ) 533 | else: 534 | self.img_features = load_img_features( 535 | self.ENT_NUM, file_dir, self.triples, self.args.use_mean_img 536 | ) 537 | self.img_features = F.normalize(torch.Tensor(self.img_features).to(device)) 538 | print("image feature shape:", self.img_features.shape) 539 | data_dir, dataname = os.path.split(file_dir) 540 | if 
self.args.word_embedding == "glove": 541 | word2vec_path = "data/embedding/glove.6B.300d.txt" 542 | elif self.args.word_embedding == "fasttext": 543 | pass 544 | else: 545 | raise Exception("error word embedding") 546 | 547 | if "DBP15K" in file_dir: 548 | if self.args.w_name or self.args.w_char: 549 | print("name feature shape:", self.name_features.shape) 550 | print("char feature shape:", self.char_features.shape) 551 | pass 552 | 553 | self.train_ill = np.array( 554 | self.ills[: int(len(self.ills) // 1 * self.args.rate)], dtype=np.int32 555 | ) 556 | self.test_ill_ = self.ills[int(len(self.ills) // 1 * self.args.rate) :] 557 | self.test_ill = np.array(self.test_ill_, dtype=np.int32) 558 | 559 | self.test_left = torch.LongTensor(self.test_ill[:, 0].squeeze()).to(device) 560 | self.test_right = torch.LongTensor(self.test_ill[:, 1].squeeze()).to(device) 561 | 562 | self.left_non_train = list( 563 | set(self.left_ents) - set(self.train_ill[:, 0].tolist()) 564 | ) 565 | self.right_non_train = list( 566 | set(self.right_ents) - set(self.train_ill[:, 1].tolist()) 567 | ) 568 | 569 | print( 570 | "#left entity : %d, #right entity: %d" 571 | % (len(self.left_ents), len(self.right_ents)) 572 | ) 573 | print( 574 | "#left entity not in train set: %d, #right entity not in train set: %d" 575 | % (len(self.left_non_train), len(self.right_non_train)) 576 | ) 577 | self.rel_features = load_relation(self.ENT_NUM, self.triples, 1000) 578 | self.rel_features = torch.Tensor(self.rel_features).to(device) 579 | print("relation feature shape:", self.rel_features.shape) 580 | a1 = os.path.join(file_dir, "training_attrs_1") 581 | a2 = os.path.join(file_dir, "training_attrs_2") 582 | self.att_features = load_attr( 583 | [a1, a2], self.ENT_NUM, self.ent2id_dict, 1000 584 | ) 585 | self.att_features = torch.Tensor(self.att_features).to(device) 586 | print("attribute feature shape:", self.att_features.shape) 587 | 588 | self.adj = get_adjr( 589 | self.ENT_NUM, self.triples, norm=True 590 | ) 591 | self.adj = self.adj.to(self.device) 592 | 593 | def init_model(self): 594 | img_dim = self.img_features.shape[1] 595 | char_dim = ( 596 | self.char_features.shape[1] if self.char_features is not None else 100 597 | ) 598 | 599 | self.multimodal_encoder = IBMultiModal( 600 | args=self.args, 601 | ent_num=self.ENT_NUM, 602 | img_feature_dim=img_dim, 603 | char_feature_dim=char_dim, 604 | use_project_head=self.args.use_project_head, 605 | ).to(self.device) 606 | 607 | 608 | self.params = [{"params": list(self.multimodal_encoder.parameters())}] 609 | total_params = sum( 610 | p.numel() 611 | for p in self.multimodal_encoder.parameters() 612 | if p.requires_grad 613 | ) 614 | 615 | self.optimizer = optim.AdamW( 616 | self.params, lr=self.args.lr, weight_decay=self.args.weight_decay 617 | ) 618 | 619 | if self.args.use_sheduler: 620 | self.scheduler = optim.lr_scheduler.ExponentialLR( 621 | optimizer=self.optimizer, gamma=self.args.sheduler_gamma 622 | ) 623 | elif self.args.use_sheduler_cos: 624 | self.scheduler = get_cosine_schedule_with_warmup( 625 | optimizer=self.optimizer, 626 | num_warmup_steps=self.args.num_warmup_steps, 627 | num_training_steps=self.args.num_training_steps, 628 | ) 629 | ms_alpha = self.args.ms_alpha 630 | ms_beta = self.args.ms_beta 631 | ms_base = self.args.ms_base 632 | self.criterion_ms = MsLoss( 633 | device=self.device, thresh=ms_base, scale_pos=ms_alpha, scale_neg=ms_beta 634 | ) 635 | self.criterion_nce = InfoNCE_loss(device=self.device, temperature=self.args.tau) 636 | 637 | def 
semi_supervised_learning(self): 638 | with torch.no_grad(): 639 | gph_emb, img_emb, rel_emb, att_emb, name_emb, char_emb, joint_emb = ( 640 | self.multimodal_encoder( 641 | self.input_idx, 642 | self.adj, 643 | self.img_features, 644 | self.rel_features, 645 | self.att_features, 646 | self.name_features, 647 | self.char_features, 648 | ) 649 | ) 650 | 651 | final_emb = F.normalize(joint_emb) 652 | distance_list = [] 653 | for i in np.arange(0, len(self.left_non_train), 1000): 654 | d = pairwise_distances( 655 | final_emb[self.left_non_train[i : i + 1000]], 656 | final_emb[self.right_non_train], 657 | ) 658 | distance_list.append(d) 659 | distance = torch.cat(distance_list, dim=0) 660 | preds_l = torch.argmin(distance, dim=1).cpu().numpy().tolist() 661 | preds_r = torch.argmin(distance.t(), dim=1).cpu().numpy().tolist() 662 | del distance_list, distance, final_emb 663 | del gph_emb, img_emb, rel_emb, att_emb, name_emb, char_emb, joint_emb 664 | return preds_l, preds_r 665 | 666 | def train(self): 667 | print("model config is:") 668 | pprint(self.args, indent=2) 669 | print("[start training...] ") 670 | t_total = time.time() 671 | new_links = [] 672 | epoch_KE, epoch_CG = 0, 0 673 | 674 | bsize = self.args.bsize 675 | device = self.device 676 | 677 | self.input_idx = torch.LongTensor(np.arange(self.ENT_NUM)).to(device) 678 | 679 | use_graph_vib = self.args.use_graph_vib 680 | use_attr_vib = self.args.use_attr_vib 681 | use_img_vib = self.args.use_img_vib 682 | use_rel_vib = self.args.use_rel_vib 683 | use_joint_vib = self.args.use_joint_vib 684 | use_bce = self.args.use_bce 685 | use_bce_new = self.args.use_bce_new 686 | use_icl = self.args.use_icl 687 | use_ms = self.args.use_ms 688 | use_nt = self.args.use_nt 689 | use_nce = self.args.use_nce 690 | use_nce_new = self.args.use_nce_new 691 | 692 | joint_use_icl = self.args.joint_use_icl 693 | joint_use_bce = self.args.joint_use_bce 694 | joint_use_ms = self.args.joint_use_ms 695 | 696 | graph_use_icl = self.args.graph_use_icl 697 | graph_use_bce = self.args.graph_use_bce 698 | graph_use_ms = self.args.graph_use_ms 699 | 700 | img_use_icl = self.args.img_use_icl 701 | img_use_bce = self.args.img_use_bce 702 | img_use_ms = self.args.img_use_ms 703 | 704 | attr_use_icl = self.args.attr_use_icl 705 | attr_use_bce = self.args.attr_use_bce 706 | attr_use_ms = self.args.attr_use_ms 707 | 708 | rel_use_icl = self.args.rel_use_icl 709 | rel_use_bce = self.args.rel_use_bce 710 | rel_use_ms = self.args.rel_use_ms 711 | 712 | Beta_g = self.args.Beta_g 713 | Beta_i = self.args.Beta_i 714 | Beta_r = self.args.Beta_r 715 | Beta_a = self.args.Beta_a 716 | 717 | for epoch in range(self.args.epochs): 718 | t_epoch = time.time() 719 | self.multimodal_encoder.train() 720 | self.optimizer.zero_grad() 721 | 722 | gph_emb, img_emb, rel_emb, att_emb, name_emb, char_emb, joint_emb = ( 723 | self.multimodal_encoder( 724 | self.input_idx, 725 | self.adj, 726 | self.img_features, 727 | self.rel_features, 728 | self.att_features, 729 | self.name_features, 730 | self.char_features, 731 | ) 732 | ) 733 | epoch_CG += 1 734 | np.random.shuffle(self.train_ill) 735 | print("train_ill length:", len(self.train_ill)) 736 | for si in np.arange(0, self.train_ill.shape[0], bsize): 737 | loss_all = 0 738 | Beta = 0.001 739 | print("[epoch {:d}] ".format(epoch), end="") 740 | loss_list = [] 741 | if self.args.w_gcn: 742 | if use_graph_vib: 743 | 744 | gph_bce_loss = self.criterion_nce( 745 | gph_emb, self.train_ill[si : si + bsize] 746 | ) 747 | gph_kld_loss = 
self.multimodal_encoder.kld_loss 748 | loss_G = gph_bce_loss + Beta_g * gph_kld_loss 749 | print( 750 | " gph_bce_loss: {:f}, gph_kld_loss: {:f}, Beta_g: {:f}".format( 751 | gph_bce_loss, gph_kld_loss, Beta_g 752 | ), 753 | end="", 754 | ) 755 | loss_list.append(loss_G) 756 | 757 | if self.args.w_img: 758 | if use_img_vib: 759 | 760 | img_bce_loss = self.criterion_nce( 761 | img_emb, self.train_ill[si : si + bsize] 762 | ) 763 | 764 | img_kld_loss = self.multimodal_encoder.img_kld_loss 765 | loss_I = img_bce_loss + Beta_i * img_kld_loss 766 | print( 767 | " img_bce_loss: {:f}, img_kld_loss: {:f}, Beta_i: {:f}".format( 768 | img_bce_loss, img_kld_loss, Beta_i 769 | ), 770 | end="", 771 | ) 772 | loss_list.append(loss_I) 773 | 774 | if self.args.w_rel: 775 | if use_rel_vib: 776 | rel_bce_loss = self.criterion_nce( 777 | rel_emb, self.train_ill[si : si + bsize] 778 | ) 779 | rel_kld_loss = self.multimodal_encoder.rel_kld_loss 780 | loss_R = rel_bce_loss + Beta_r * rel_kld_loss 781 | print( 782 | " rel_bce_loss: {:f}, rel_kld_loss: {:f}, Beta_r: {:f}".format( 783 | rel_bce_loss, rel_kld_loss, Beta_r 784 | ), 785 | end="", 786 | ) 787 | loss_list.append(loss_R) 788 | 789 | if self.args.w_attr: 790 | if use_attr_vib: 791 | attr_bce_loss = self.criterion_nce( 792 | att_emb, self.train_ill[si : si + bsize] 793 | ) 794 | attr_kld_loss = self.multimodal_encoder.attr_kld_loss 795 | loss_A = attr_bce_loss + Beta_a * attr_kld_loss 796 | print( 797 | " attr_bce_loss: {:f}, attr_kld_loss: {:f}, Beta_a: {:f}".format( 798 | attr_bce_loss, attr_kld_loss, Beta_a 799 | ), 800 | end="", 801 | ) 802 | loss_list.append(loss_A) 803 | 804 | if use_joint_vib: 805 | pass 806 | else: 807 | joint_ms_loss = self.criterion_ms( 808 | joint_emb, self.train_ill[si : si + bsize] 809 | ) 810 | loss_J = joint_ms_loss * self.args.joint_beta 811 | print( 812 | " joint_ms_loss: {:f}, joint_beta: {:f}".format( 813 | joint_ms_loss, self.args.joint_beta 814 | ), 815 | end="", 816 | ) 817 | loss_list.append(loss_J) 818 | loss_all = sum(loss_list) 819 | 820 | print(" loss_all: {:f},".format(loss_all)) 821 | loss_all.backward(retain_graph=True) 822 | torch.nn.utils.clip_grad_norm_( 823 | parameters=self.multimodal_encoder.parameters(), 824 | max_norm=0.1, 825 | norm_type=2, 826 | ) 827 | 828 | self.optimizer.step() 829 | if self.args.use_sheduler and epoch > 400: 830 | self.scheduler.step() 831 | if self.args.use_sheduler_cos: 832 | self.scheduler.step() 833 | if ( 834 | epoch >= self.args.il_start 835 | and (epoch + 1) % self.args.semi_learn_step == 0 836 | and self.args.il 837 | ): 838 | preds_l, preds_r = self.semi_supervised_learning() 839 | if (epoch + 1) % ( 840 | self.args.semi_learn_step * 10 841 | ) == self.args.semi_learn_step: 842 | new_links = [ 843 | (self.left_non_train[i], self.right_non_train[p]) 844 | for i, p in enumerate(preds_l) 845 | if preds_r[p] == i 846 | ] 847 | else: 848 | new_links = [ 849 | (self.left_non_train[i], self.right_non_train[p]) 850 | for i, p in enumerate(preds_l) 851 | if (preds_r[p] == i) 852 | and ( 853 | (self.left_non_train[i], self.right_non_train[p]) 854 | in new_links 855 | ) 856 | ] 857 | print( 858 | "[epoch %d] #links in candidate set: %d" % (epoch, len(new_links)) 859 | ) 860 | 861 | if ( 862 | epoch >= self.args.il_start 863 | and (epoch + 1) % (self.args.semi_learn_step * 10) == 0 864 | and len(new_links) != 0 865 | and self.args.il 866 | ): 867 | new_links_elect = new_links 868 | print("\n#new_links_elect:", len(new_links_elect)) 869 | 870 | self.train_ill = 
np.vstack((self.train_ill, np.array(new_links_elect))) 871 | print("train_ill.shape:", self.train_ill.shape) 872 | 873 | num_true = len([nl for nl in new_links_elect if nl in self.test_ill_]) 874 | print("#true_links: %d" % num_true) 875 | print( 876 | "true link ratio: %.1f%%" % (100 * num_true / len(new_links_elect)) 877 | ) 878 | for nl in new_links_elect: 879 | self.left_non_train.remove(nl[0]) 880 | self.right_non_train.remove(nl[1]) 881 | print( 882 | "#entity not in train set: %d (left) %d (right)" 883 | % (len(self.left_non_train), len(self.right_non_train)) 884 | ) 885 | 886 | new_links = [] 887 | if (epoch + 1) % self.args.check_point == 0: 888 | print("\n[epoch {:d}] checkpoint!".format(epoch)) 889 | self.test(epoch) 890 | del joint_emb, gph_emb, img_emb, rel_emb, att_emb, name_emb, char_emb 891 | 892 | print("[optimization finished!]") 893 | print( 894 | "best epoch is {}, hits@1 hits@10 MRR MR is: {}\n".format( 895 | self.best_epoch, self.best_data_list 896 | ) 897 | ) 898 | print("[total time elapsed: {:.4f} s]".format(time.time() - t_total)) 899 | 900 | def test(self, epoch): 901 | with torch.no_grad(): 902 | t_test = time.time() 903 | self.multimodal_encoder.eval() 904 | 905 | gph_emb, img_emb, rel_emb, att_emb, name_emb, char_emb, joint_emb = ( 906 | self.multimodal_encoder( 907 | self.input_idx, 908 | self.adj, 909 | self.img_features, 910 | self.rel_features, 911 | self.att_features, 912 | self.name_features, 913 | self.char_features, 914 | ) 915 | ) 916 | 917 | final_emb = F.normalize(joint_emb) 918 | top_k = [1, 10, 50] 919 | if "100" in self.args.file_dir: 920 | Lvec = final_emb[self.test_left].cpu().data.numpy() 921 | Rvec = final_emb[self.test_right].cpu().data.numpy() 922 | acc_l2r, mean_l2r, mrr_l2r, acc_r2l, mean_r2l, mrr_r2l = multi_get_hits( 923 | Lvec, Rvec, top_k=top_k, args=self.args 924 | ) 925 | del final_emb 926 | gc.collect() 927 | else: 928 | acc_l2r = np.zeros((len(top_k)), dtype=np.float32) 929 | acc_r2l = np.zeros((len(top_k)), dtype=np.float32) 930 | test_total, test_loss, mean_l2r, mean_r2l, mrr_l2r, mrr_r2l = ( 931 | 0, 932 | 0.0, 933 | 0.0, 934 | 0.0, 935 | 0.0, 936 | 0.0, 937 | ) 938 | if self.args.dist == 2: 939 | distance = pairwise_distances( 940 | final_emb[self.test_left], final_emb[self.test_right] 941 | ) 942 | elif self.args.dist == 1: 943 | import scipy.spatial as T 944 | 945 | distance = torch.FloatTensor( 946 | T.distance.cdist( 947 | final_emb[self.test_left].cpu().data.numpy(), 948 | final_emb[self.test_right].cpu().data.numpy(), 949 | metric="cityblock", 950 | ) 951 | ) 952 | else: 953 | raise NotImplementedError 954 | 955 | if self.args.csls is True: 956 | distance = 1 - csls_sim(1 - distance, self.args.csls_k) 957 | 958 | to_write = [] 959 | test_left_np = self.test_left.cpu().numpy() 960 | test_right_np = self.test_right.cpu().numpy() 961 | to_write.append( 962 | [ 963 | "idx", 964 | "rank", 965 | "query_id", 966 | "gt_id", 967 | "ret1", 968 | "ret2", 969 | "ret3", 970 | "v1", 971 | "v2", 972 | "v3", 973 | ] 974 | ) 975 | 976 | for idx in range(self.test_left.shape[0]): 977 | values, indices = torch.sort(distance[idx, :], descending=False) 978 | rank = (indices == idx).nonzero().squeeze().item() 979 | mean_l2r += rank + 1 980 | mrr_l2r += 1.0 / (rank + 1) 981 | for i in range(len(top_k)): 982 | if rank < top_k[i]: 983 | acc_l2r[i] += 1 984 | 985 | indices = indices.cpu().numpy() 986 | to_write.append( 987 | [ 988 | idx, 989 | rank, 990 | test_left_np[idx], 991 | test_right_np[idx], 992 | test_right_np[indices[0]], 993 | 
test_right_np[indices[1]], 994 | test_right_np[indices[2]], 995 | round(values[0].item(), 4), 996 | round(values[1].item(), 4), 997 | round(values[2].item(), 4), 998 | ] 999 | ) 1000 | 1001 | for idx in range(self.test_right.shape[0]): 1002 | _, indices = torch.sort(distance[:, idx], descending=False) 1003 | rank = (indices == idx).nonzero().squeeze().item() 1004 | mean_r2l += rank + 1 1005 | mrr_r2l += 1.0 / (rank + 1) 1006 | for i in range(len(top_k)): 1007 | if rank < top_k[i]: 1008 | acc_r2l[i] += 1 1009 | 1010 | mean_l2r /= self.test_left.size(0) 1011 | mean_r2l /= self.test_right.size(0) 1012 | mrr_l2r /= self.test_left.size(0) 1013 | mrr_r2l /= self.test_right.size(0) 1014 | for i in range(len(top_k)): 1015 | acc_l2r[i] = round(acc_l2r[i] / self.test_left.size(0), 4) 1016 | acc_r2l[i] = round(acc_r2l[i] / self.test_right.size(0), 4) 1017 | del ( 1018 | distance, 1019 | gph_emb, 1020 | img_emb, 1021 | rel_emb, 1022 | att_emb, 1023 | name_emb, 1024 | char_emb, 1025 | joint_emb, 1026 | ) 1027 | gc.collect() 1028 | print( 1029 | "l2r: acc of top {} = {}, mr = {:.3f}, mrr = {:.3f}, time = {:.4f} s ".format( 1030 | top_k, acc_l2r, mean_l2r, mrr_l2r, time.time() - t_test 1031 | ) 1032 | ) 1033 | print( 1034 | "r2l: acc of top {} = {}, mr = {:.3f}, mrr = {:.3f}, time = {:.4f} s \n".format( 1035 | top_k, acc_r2l, mean_r2l, mrr_r2l, time.time() - t_test 1036 | ) 1037 | ) 1038 | if acc_l2r[0] > self.best_hit_1: 1039 | self.best_hit_1 = acc_l2r[0] 1040 | self.best_epoch = epoch 1041 | self.best_data_list = [ 1042 | acc_l2r[0], 1043 | acc_l2r[1], 1044 | mrr_l2r, 1045 | mean_l2r, 1046 | ] 1047 | import copy 1048 | 1049 | self.best_to_write = copy.deepcopy(to_write) 1050 | 1051 | if epoch + 1 == self.args.epochs: 1052 | pred_name = self.args.pred_name 1053 | import csv 1054 | 1055 | save_path = os.path.join(self.args.save_path, pred_name) 1056 | if not os.path.exists(save_path): 1057 | os.mkdir(save_path) 1058 | with open(os.path.join(save_path, pred_name + ".txt"), "w") as f: 1059 | wr = csv.writer(f, dialect="excel") 1060 | wr.writerows(self.best_to_write) 1061 | 1062 | 1063 | if __name__ == "__main__": 1064 | model = IBMEA() 1065 | model.train() 1066 | --------------------------------------------------------------------------------