├── codes
│   ├── utils
│   │   ├── functions.py
│   │   ├── utils.py
│   │   ├── contrastive_loss.py
│   │   └── dataset.py
│   └── model
│       ├── lightning_m3el.py
│       └── modeling_m3el.py
├── run.sh
├── config
│   ├── wikimel.yaml
│   ├── wikidiverse.yaml
│   └── richpediamel.yaml
├── main.py
└── README.md
/codes/utils/functions.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from omegaconf import OmegaConf
3 | 
4 | 
5 | def setup_parser():
6 |     parser = argparse.ArgumentParser(add_help=False)
7 |     parser.add_argument('--config', type=str, default='config/wikidiverse.yaml')
8 |     _args = parser.parse_args()
9 |     args = OmegaConf.load(_args.config)
10 |     return args
11 | 
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | export GPU=2
2 | 
3 | #export CONFIG=./config/wikimel.yaml
4 | #export LOG=./logs/wikimel_baseline.logs
5 | 
6 | #export CONFIG=./config/richpediamel.yaml
7 | #export LOG=./logs/richpediamel_baseline.logs
8 | 
9 | export CONFIG=./config/wikidiverse.yaml
10 | export LOG=./logs/wikidiverse_baseline.logs
11 | 
12 | CUDA_VISIBLE_DEVICES=$GPU nohup python -u ./main.py --config $CONFIG \
13 | > $LOG 2>&1 &
--------------------------------------------------------------------------------
/config/wikimel.yaml:
--------------------------------------------------------------------------------
1 | run_name: WikiMEL
2 | seed: 43
3 | pretrained_model: '/checkpoint/clip-vit-base-patch32'
4 | lr: 1e-5
5 | 
6 | 
7 | data:
8 |   num_entity: 109976
9 |   kb_img_folder: /data/WikiMEL/kb_image
10 |   mention_img_folder: /data/WikiMEL/mention_image
11 |   qid2id: /data/WikiMEL/qid2id.json
12 |   entity: /data/WikiMEL/kb_entity.json
13 |   train_file: /data/WikiMEL/WIKIMEL_train.json
14 |   dev_file: /data/WikiMEL/WIKIMEL_dev.json
15 |   test_file: /data/WikiMEL/WIKIMEL_test.json
16 | 
17 |   batch_size: 128
18 |   num_workers: 8
19 |   text_max_length: 40
20 | 
21 |   eval_chunk_size: 6000
22 |   eval_batch_size: 20
23 |   embed_update_batch_size: 512
24 | 
25 | 
26 | model:
27 |   input_hidden_dim: 512
28 |   input_image_hidden_dim: 768
29 |   hidden_dim: 96
30 |   dv: 96
31 |   dt: 512
32 |   TIMM_hidden_dim: 96
33 |   IIMM_hidden_dim: 96
34 |   CMM_hidden_dim: 96
35 |   head_num: 5
36 |   weight: 1.0
37 |   loss_type: 0
38 |   loss_temperature: 0.03
39 |   inter_weight: 1.0
40 |   intra_weight: 0.8
41 |   with_cl_loss: 1
42 | 
43 | 
44 | trainer:
45 |   accelerator: 'gpu'
46 |   devices: 1
47 |   max_epochs: 20
48 |   num_sanity_val_steps: 0
49 |   check_val_every_n_epoch: 2
50 |   log_every_n_steps: 30
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytorch_lightning as pl
3 | from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
4 | from codes.utils.functions import setup_parser
5 | from codes.model.lightning_m3el import LightningForM3EL
6 | from codes.utils.dataset import DataModuleForM3EL
7 | 
8 | if __name__ == '__main__':
9 |     args = setup_parser()
10 |     pl.seed_everything(args.seed, workers=True)
11 |     torch.set_num_threads(1)
12 | 
13 |     data_module = DataModuleForM3EL(args)
14 |     lightning_model = LightningForM3EL(args)
15 | 
16 |     logger = pl.loggers.CSVLogger("./runs", name=args.run_name, flush_logs_every_n_steps=30)
17 | 
18 |     ckpt_callbacks = ModelCheckpoint(monitor='Val/mrr', save_weights_only=True, mode='max')
19 |     early_stop_callback = EarlyStopping(monitor="Val/mrr",
min_delta=0.00, patience=3, verbose=True, mode="max") 20 | 21 | trainer = pl.Trainer(**args.trainer, 22 | deterministic=True, logger=logger, default_root_dir="./runs", 23 | callbacks=[ckpt_callbacks, early_stop_callback]) 24 | 25 | trainer.fit(lightning_model, datamodule=data_module) 26 | trainer.test(lightning_model, datamodule=data_module, ckpt_path='best') 27 | -------------------------------------------------------------------------------- /config/wikidiverse.yaml: -------------------------------------------------------------------------------- 1 | run_name: WikiDiverse 2 | seed: 43 3 | pretrained_model: '/checkpoint/clip-vit-base-patch32' 4 | lr: 1e-5 5 | 6 | 7 | data: 8 | num_entity: 132460 9 | kb_img_folder: /data/WikiDiverse/kb_image 10 | mention_img_folder: /data/WikiDiverse/mention_image 11 | qid2id: /data/WikiDiverse/qid2id.json 12 | entity: /data/WikiDiverse/kb_entity.json 13 | train_file: /data/WikiDiverse/WikiDiverse_train.json 14 | dev_file: /data/WikiDiverse/WikiDiverse_dev.json 15 | test_file: /data/WikiDiverse/WikiDiverse_test.json 16 | 17 | batch_size: 128 18 | num_workers: 8 19 | text_max_length: 40 20 | 21 | eval_chunk_size: 6000 22 | eval_batch_size: 20 23 | embed_update_batch_size: 512 24 | 25 | 26 | model: 27 | input_hidden_dim: 512 28 | input_image_hidden_dim: 768 29 | hidden_dim: 96 30 | dv: 96 31 | dt: 512 32 | TIMM_hidden_dim: 96 33 | IIMM_hidden_dim: 96 34 | CMM_hidden_dim: 96 35 | head_num: 5 36 | weight: 1.0 37 | loss_type: 0 38 | loss_temperature: 0.03 39 | inter_weight: 1.0 40 | intra_weight: 0.8 41 | with_cl_loss: 1 42 | 43 | 44 | trainer: 45 | accelerator: 'gpu' 46 | devices: 1 47 | max_epochs: 20 48 | num_sanity_val_steps: 0 49 | check_val_every_n_epoch: 2 50 | log_every_n_steps: 30 -------------------------------------------------------------------------------- /config/richpediamel.yaml: -------------------------------------------------------------------------------- 1 | run_name: RichpediaMEL 2 | seed: 43 3 | pretrained_model: '/checkpoint/clip-vit-base-patch32' 4 | lr: 1e-5 5 | 6 | 7 | data: 8 | num_entity: 160933 9 | kb_img_folder: /data/RichpediaMEL/kb_image 10 | mention_img_folder: /data/RichpediaMEL/mention_image 11 | qid2id: /data/RichpediaMEL/qid2id.json 12 | entity: /data/RichpediaMEL/kb_entity.json 13 | train_file: /data/RichpediaMEL/RichpediaMEL_train.json 14 | dev_file: /data/RichpediaMEL/RichpediaMEL_dev.json 15 | test_file: /data/RichpediaMEL/RichpediaMEL_test.json 16 | 17 | batch_size: 128 18 | num_workers: 8 19 | text_max_length: 40 20 | 21 | eval_chunk_size: 6000 22 | eval_batch_size: 20 23 | embed_update_batch_size: 512 24 | 25 | 26 | model: 27 | input_hidden_dim: 512 28 | input_image_hidden_dim: 768 29 | hidden_dim: 96 30 | dv: 96 31 | dt: 512 32 | TIMM_hidden_dim: 96 33 | IIMM_hidden_dim: 96 34 | CMM_hidden_dim: 96 35 | head_num: 5 36 | weight: 1.0 37 | loss_type: 0 38 | loss_temperature: 0.03 39 | inter_weight: 1.0 40 | intra_weight: 0.8 41 | with_cl_loss: 1 42 | 43 | 44 | trainer: 45 | accelerator: 'gpu' 46 | devices: 1 47 | max_epochs: 20 48 | num_sanity_val_steps: 0 49 | check_val_every_n_epoch: 2 50 | log_every_n_steps: 30 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multi-level Matching Network for Multimodal Entity Linking 2 | #### This repo provides the source code & data of our paper: [Multi-level Matching Network for Multimodal Entity Linking(KDD2025)](https://arxiv.org/pdf/2412.10440). 
3 | 
4 | ## Dependencies
5 | * python 3.7 (conda create -n m3el python=3.7 -y)
6 | * torch==1.11.0+cu113
7 | * transformers==4.27.1
8 | * torchmetrics==0.11.0
9 | * tokenizers==0.12.1
10 | * pytorch-lightning==1.7.7
11 | * omegaconf==2.2.3
12 | * pillow==9.3.0
13 | 
14 | **Note:** We found that the transformers version affects performance, so please use the pinned version listed above (4.27.1).
15 | 
16 | ## Running the code
17 | ### Dataset
18 | 1. Download the datasets from the [MIMIC paper](https://github.com/pengfei-luo/MIMIC).
19 | 2. Create the root directory ./data and put the datasets in it.
20 | 3. Download the [M3EL](https://drive.google.com/drive/folders/1mZoE28f6FSxRyogjZuKSjA7-7YaC1nKe?usp=sharing) datasets and replace the MIMIC files that have the same names.
21 | 4. Download the pretrained weights from [clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32).
22 | 5. Create the root directory ./checkpoint and put the pretrained weights in it.
23 | 
24 | ### Training model
25 | ```bash
26 | sh run.sh
27 | ```
28 | **Note:** run.sh contains the commands for all three datasets. Switch between them by uncommenting the corresponding CONFIG and LOG lines.
29 | 
30 | ### Training logs
31 | **Note:** We provide the logs of our training runs in the logs directory.
32 | 
33 | ## Citation
34 | If you find this code useful, please consider citing the following paper.
35 | ```
36 | @inproceedings{hu2025multilevel,
37 |   author={Zhiwei Hu and Víctor Gutiérrez-Basulto and Ru Li and Jeff Z. Pan},
38 |   title={Multi-level Matching Network for Multimodal Entity Linking},
39 |   booktitle={Proceedings of the ACM SIGKDD Conference on Knowledge Discovery and Data Mining},
40 |   year={2025}
41 | }
42 | ```
43 | ## Acknowledgement
44 | Our code builds on [MIMIC](https://github.com/pengfei-luo/MIMIC). Thanks to the authors for their contributions.
45 | 
--------------------------------------------------------------------------------
/codes/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | class MatchModule(nn.Module):
5 |     def __init__(self, hidden_size):
6 |         super(MatchModule, self).__init__()
7 |         self.trans_linear = nn.Linear(hidden_size, hidden_size)
8 | 
9 |     def forward(self, inputs):
10 |         proj_p, proj_q = inputs
11 |         trans_q = self.trans_linear(proj_q)
12 |         att_weights = proj_p.bmm(torch.transpose(trans_q, 1, 2))
13 |         att_norm = nn.Softmax(dim=-1)(att_weights)
14 |         att_vec = att_norm.bmm(proj_q)
15 |         output = nn.ReLU()(self.trans_linear(att_vec))
16 | 
17 |         return output
18 | 
19 | class CSRA(nn.Module):
20 |     def __init__(self, T, lam):
21 |         super(CSRA, self).__init__()
22 |         self.T = T
23 |         self.lam = lam
24 |         self.softmax = nn.Softmax(dim=1)
25 | 
26 |     def forward(self, score):
27 | 
28 |         base_logit = torch.mean(score, dim=1)
29 | 
30 |         if self.T == 99:
31 |             att_logit = torch.max(score, dim=1)[0]
32 |         else:
33 |             score_soft = self.softmax(score * self.T)
34 |             att_logit = torch.sum(score * score_soft, dim=1)
35 | 
36 |         return base_logit + self.lam * att_logit
37 | 
38 | class MultiHeadModule(nn.Module):
39 |     temp_settings = {
40 |         1: [3],
41 |         2: [3, 99],
42 |         3: [2, 4, 99],
43 |         4: [2, 3, 4, 99],
44 |         5: [2, 2.5, 3.5, 4.5, 99],
45 |         6: [2, 3, 4, 5, 6, 99],
46 |         7: [0.5, 2.5, 3.5, 4.5, 5.5, 6.5, 99],
47 |         8: [0.5, 2, 3, 4, 5, 6, 7, 99]
48 |     }
49 | 
50 |     def __init__(self, num_heads, lam, weight=False):
51 |         super(MultiHeadModule, self).__init__()
52 |         self.num_heads = num_heads
53 |         self.temp_list = self.temp_settings[num_heads]
54 |         self.multi_head = nn.ModuleList([
55 |             CSRA(self.temp_list[i], lam)
56 |             for i in range(num_heads)
57 |         ])
58 |         self.weight =
nn.Parameter(torch.ones(num_heads, 1)) 59 | if weight: 60 | self.weight.requires_grad = True 61 | else: 62 | self.weight.requires_grad = False 63 | 64 | def forward(self, x): 65 | logit = 0. 66 | for head, weight in zip(self.multi_head, self.weight): 67 | logit += head(x) * weight 68 | return logit / self.num_heads 69 | 70 | class FusionModule(nn.Module): 71 | def __init__(self, hidden_size): 72 | super(FusionModule, self).__init__() 73 | self.linear = nn.Linear(hidden_size, hidden_size) 74 | 75 | def forward(self, inputs): 76 | p, q = inputs 77 | lq = self.linear(q) 78 | lp = self.linear(p) 79 | mid = nn.Sigmoid()(lq+lp) 80 | output = p * mid + q * (1-mid) 81 | return output -------------------------------------------------------------------------------- /codes/utils/contrastive_loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.distributed as dist 6 | 7 | class ContiguousGrad(torch.autograd.Function): 8 | @staticmethod 9 | def forward(ctx, x): 10 | return x 11 | 12 | @staticmethod 13 | def backward(ctx, grad_out): 14 | return grad_out.contiguous() 15 | 16 | def normalize(t, dim, eps=1e-6): 17 | return F.normalize(t, dim=dim, eps=eps) 18 | 19 | def gather_cat(x: torch.Tensor, grad=False, contiguous_grad=False) -> torch.Tensor: 20 | if not grad: 21 | gathers = [torch.empty_like(x) for _ in range(dist.get_world_size())] 22 | dist.all_gather(gathers, x) 23 | else: 24 | gathers = torch.distributed.nn.all_gather(x) 25 | 26 | if x.ndim == 0: 27 | gathers = torch.stack(gathers) 28 | else: 29 | gathers = torch.cat(gathers) 30 | 31 | if contiguous_grad: 32 | gathers = ContiguousGrad.apply(gathers) 33 | 34 | return gathers 35 | 36 | class InfoNCELoss(nn.Module): 37 | def __init__(self, T_init=0.07, **kwargs): 38 | super().__init__() 39 | self.tau = T_init 40 | self.xe = nn.CrossEntropyLoss(reduction='none') 41 | 42 | def forward(self, image_emb, text_emb): 43 | n = image_emb.shape[0] 44 | logits = image_emb @ text_emb.T 45 | labels = torch.arange(n).cuda() 46 | loss_t = self.xe(logits, labels) 47 | loss_i = self.xe(logits.T, labels) 48 | loss = (loss_i + loss_t) / 2 49 | loss = loss.mean() 50 | return loss 51 | 52 | class MCLETLoss(nn.Module): 53 | def __init__(self, embedding_dim, cl_temperature=0.05): 54 | super(MCLETLoss, self).__init__() 55 | self.cl_temperature = cl_temperature 56 | self.embedding_dim = embedding_dim 57 | self.cl_fc = nn.Sequential( 58 | nn.Linear(self.embedding_dim, self.embedding_dim, bias=True), 59 | nn.ELU(), 60 | nn.Linear(self.embedding_dim, self.embedding_dim, bias=True), 61 | ) 62 | 63 | def sim(self, z1: torch.Tensor, z2: torch.Tensor): 64 | z1 = F.normalize(z1) 65 | z2 = F.normalize(z2) 66 | return torch.mm(z1, z2.t()) 67 | 68 | def forward(self, A_embedding, B_embedding): 69 | tau = self.cl_temperature 70 | f = lambda x: torch.exp(x / tau) 71 | A_embedding = self.cl_fc(A_embedding) 72 | B_embedding = self.cl_fc(B_embedding) 73 | 74 | refl_sim_1 = f(self.sim(A_embedding, A_embedding)) 75 | between_sim_1 = f(self.sim(A_embedding, B_embedding)) 76 | loss_1 = -torch.log(between_sim_1.diag() / (refl_sim_1.sum(1) + between_sim_1.sum(1) - refl_sim_1.diag())) 77 | 78 | refl_sim_2 = f(self.sim(B_embedding, B_embedding)) 79 | between_sim_2 = f(self.sim(B_embedding, A_embedding)) 80 | loss_2 = -torch.log(between_sim_2.diag() / (refl_sim_2.sum(1) + between_sim_2.sum(1) - refl_sim_2.diag())) 81 | 82 | loss = (loss_1 + 
loss_2) * 0.5 83 | loss = loss.mean() 84 | 85 | return loss 86 | 87 | class WeightedContrastiveLoss(nn.Module): 88 | def __init__(self, temperature=0.03, inter_weight=1.0, intra_weight=0.8, logger=None): 89 | super().__init__() 90 | self.logit_scale = nn.Parameter(torch.ones([])) 91 | self.criterion = torch.nn.CrossEntropyLoss(reduction='none') 92 | self.temperature = temperature 93 | self.logger = logger 94 | self.inter_weight = inter_weight 95 | self.intra_weight = intra_weight 96 | 97 | def compute_loss(self, logits, mask): 98 | return - torch.log((F.softmax(logits, dim=1) * mask).sum(1)) 99 | 100 | def _get_positive_mask(self, batch_size): 101 | diag = np.eye(batch_size) 102 | mask = torch.from_numpy((diag)) 103 | mask = (1 - mask) 104 | return mask.cuda(non_blocking=True) 105 | 106 | def forward(self, entity_features, mention_features): 107 | batch_size = entity_features.shape[0] 108 | 109 | # Normalize features 110 | entity_features = nn.functional.normalize(entity_features, dim=1) 111 | mention_features = nn.functional.normalize(mention_features, dim=1) 112 | 113 | # Inter-modality alignment 114 | inter_entity = entity_features @ mention_features.t() 115 | inter_mention = mention_features @ entity_features.t() 116 | 117 | # Intra-modality alignment 118 | inner_entity = entity_features @ entity_features.t() 119 | inner_mention = mention_features @ mention_features.t() 120 | 121 | inter_entity /= self.temperature 122 | inter_mention /= self.temperature 123 | inner_entity /= self.temperature 124 | inner_mention /= self.temperature 125 | 126 | positive_mask = self._get_positive_mask(entity_features.shape[0]) 127 | inner_entity = inner_entity * positive_mask 128 | inner_mention = inner_mention * positive_mask 129 | 130 | entity_logits = torch.cat([self.inter_weight * inter_entity, self.intra_weight * inner_entity], dim=1) 131 | mention_logits = torch.cat([self.inter_weight * inter_mention, self.intra_weight * inner_mention], dim=1) 132 | 133 | diag = np.eye(batch_size) 134 | mask_entity = torch.from_numpy((diag)).cuda() 135 | mask_mention = torch.from_numpy((diag)).cuda() 136 | 137 | mask_neg_entity = torch.zeros_like(inner_entity) 138 | mask_neg_mention = torch.zeros_like(inner_mention) 139 | mask_entity = torch.cat([mask_entity, mask_neg_entity], dim=1) 140 | mask_mention = torch.cat([mask_mention, mask_neg_mention], dim=1) 141 | 142 | loss_entity = self.compute_loss(entity_logits, mask_entity) 143 | loss_mention = self.compute_loss(mention_logits, mask_mention) 144 | 145 | return ((loss_entity.mean() + loss_mention.mean())) / 2 -------------------------------------------------------------------------------- /codes/utils/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import json 4 | import os.path 5 | import random 6 | import pickle 7 | 8 | import torch 9 | import pytorch_lightning as pl 10 | from PIL import Image 11 | from tqdm import tqdm 12 | from torch.utils.data import DataLoader 13 | from transformers import CLIPProcessor 14 | from urllib.parse import unquote 15 | 16 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 17 | 18 | 19 | def _load_json_file(filepath): 20 | data = [] 21 | if isinstance(filepath, str): 22 | with open(filepath, 'r', encoding='utf-8') as f: 23 | d = json.load(f) 24 | data.extend(d) 25 | elif isinstance(filepath, list): 26 | for path in filepath: 27 | with open(path, 'r', encoding='utf-8') as f: 28 | d = json.load(f) 29 | data.extend(d) 30 | return data 31 | 32 | 33 | class 
DataModuleForM3EL(pl.LightningDataModule): 34 | def __init__(self, args): 35 | super(DataModuleForM3EL, self).__init__() 36 | self.args = args 37 | current_directory = os.path.dirname(os.path.abspath(__file__)) 38 | base_path = current_directory[0:current_directory.rfind('/')] 39 | self.base_path = base_path[0:base_path.rfind('/')] 40 | self.tokenizer = CLIPProcessor.from_pretrained(self.base_path + self.args.pretrained_model).tokenizer 41 | self.image_processor = CLIPProcessor.from_pretrained(self.base_path + self.args.pretrained_model).feature_extractor 42 | with open(self.base_path + self.args.data.qid2id, 'r', encoding='utf-8') as f: 43 | self.qid2id = json.loads(f.readline()) 44 | self.raw_kb_entity = sorted(_load_json_file(self.base_path + self.args.data.entity), key=lambda x: x['id']) 45 | self.kb_entity = self.setup_dataset_for_entity(self.base_path + self.args.data.entity, self.raw_kb_entity) 46 | self.kb_id2entity = {raw_ent['id']: ent for raw_ent, ent in zip(self.raw_kb_entity, self.kb_entity)} 47 | 48 | self.train_data = self.setup_dataset_for_mention(self.base_path + self.args.data.train_file, _load_json_file(self.base_path + self.args.data.train_file)) 49 | self.val_data = self.setup_dataset_for_mention(self.base_path + self.args.data.dev_file, _load_json_file(self.base_path + self.args.data.dev_file)) 50 | self.test_data = self.setup_dataset_for_mention(self.base_path + self.args.data.test_file, _load_json_file(self.base_path + self.args.data.test_file)) 51 | 52 | def setup_dataset_for_entity(self, path, data): 53 | # prepare entity information 54 | pkl_path = path[0:path.rfind('.')] + '.pkl' 55 | # if os.path.exists(pkl_path): 56 | # with open(pkl_path, 'rb') as file: 57 | # input_data = pickle.load(file) 58 | # return input_data 59 | 60 | input_data = [] 61 | for sample_dict in tqdm(data, desc='PreProcessing'): 62 | sample_type = sample_dict['type'] 63 | if sample_type == 'entity': 64 | entity, desc = unquote(sample_dict.pop('entity_name')), sample_dict.pop('desc') 65 | input_text = entity + ' [SEP] ' + desc # concat entity and sentence 66 | input_dict = self.tokenizer(input_text, padding='max_length', max_length=self.args.data.text_max_length, truncation=True) 67 | input_dict['img_list'] = sample_dict['image_list'] 68 | input_dict['sample_type'] = 0 if sample_type == 'entity' else 1 69 | if 'answer' in sample_dict.keys(): 70 | input_dict['answer'] = self.qid2id[sample_dict['answer']] 71 | input_data.append(input_dict) 72 | 73 | with open(pkl_path, 'wb') as file: 74 | pickle.dump(input_data, file) 75 | 76 | return input_data 77 | 78 | def setup_dataset_for_mention(self, path, data): 79 | # prepare mention information 80 | pkl_path = path[0:path.rfind('.')] + '.pkl' 81 | # if os.path.exists(pkl_path): 82 | # with open(pkl_path, 'rb') as file: 83 | # input_data = pickle.load(file) 84 | # return input_data 85 | 86 | input_data = [] 87 | for sample_dict in tqdm(data, desc='PreProcessing'): 88 | sample_type = 1 89 | entity, mention, text = unquote(sample_dict.pop('entities')), unquote(sample_dict.pop('mentions')), sample_dict.pop('sentence') 90 | input_text = mention + ' [SEP] ' + text # concat entity and text 91 | input_dict = self.tokenizer(input_text, padding='max_length', max_length=self.args.data.text_max_length, truncation=True) 92 | 93 | input_dict['img_list'] = [sample_dict['imgPath']] if sample_dict['imgPath'] != '' else [] 94 | input_dict['sample_type'] = sample_type 95 | if 'answer' in sample_dict.keys(): 96 | input_dict['answer'] = self.qid2id[sample_dict['answer']] 
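            # Note: the qid2id lookup above runs before the 'nil' check below, so this code
            # assumes that 'nil' answers are either absent from these files or mapped in qid2id.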
97 | if sample_dict['answer'] == 'nil': # ignore the sample without ground truth 98 | continue 99 | input_data.append(input_dict) 100 | 101 | with open(pkl_path, 'wb') as file: 102 | pickle.dump(input_data, file) 103 | 104 | return input_data 105 | 106 | def choose_image(self, sample_type, img_list, is_eval=False): 107 | if len(img_list): 108 | img_name = random.choice(img_list) 109 | # when evaluation, we choose the first image 110 | if is_eval: 111 | img_name = img_list[0] 112 | if sample_type == 1: 113 | img_name = img_name.split('/')[-1].split('.')[0] + '.jpg' # we already convert all image to jpg format 114 | try: 115 | img_path = os.path.join( 116 | self.base_path + self.args.data.kb_img_folder if sample_type == 0 else self.base_path + self.args.data.mention_img_folder, 117 | img_name) 118 | img = Image.open(img_path).resize((224, 224), Image.Resampling.LANCZOS) 119 | pixel_values = self.image_processor(img, return_tensors='pt')['pixel_values'].squeeze() 120 | except: 121 | pixel_values = torch.zeros((3, 224, 224)) 122 | else: 123 | pixel_values = torch.zeros((3, 224, 224)) 124 | return pixel_values 125 | 126 | def train_collator(self, samples): 127 | cls_idx, img_list, sample_type, input_dict_list = [], [], [], [] 128 | pixel_values, gt_ent_id = [], [] 129 | 130 | # collect the metadata that need to further process 131 | for sample_idx, sample in enumerate(samples): 132 | img_list.append(sample.pop('img_list')) # mention image list 133 | sample_type.append(sample.pop('sample_type')) # input type: 0 for mention and 1 for entity 134 | input_dict_list.append(sample) # mention input dict (input_tokens, token_type_ids, attention_mask) 135 | gt_ent_id.append(sample.pop('answer')) # ground truth entity id of mentions 136 | ### 137 | # Now we process mention information 138 | # choose an image 139 | for idx, _ in enumerate(input_dict_list): 140 | pixel_values.append(self.choose_image(sample_type[idx], img_list[idx])) 141 | # pad textual input 142 | input_dict = self.tokenizer.pad(input_dict_list, 143 | padding='max_length', 144 | max_length=self.args.data.text_max_length, 145 | return_tensors='pt') 146 | # concat all images 147 | pixel_values = torch.stack(pixel_values) 148 | input_dict['pixel_values'] = pixel_values 149 | 150 | ### 151 | # now we process entity information 152 | # fetch the entities' metadata 153 | ent_info_list = [copy.deepcopy(self.kb_id2entity[idx]) for idx in gt_ent_id] 154 | ent_img_list, ent_type, ent_input_dict_list, ent_pixel_values = [], [], [], [] 155 | for ent_dict in ent_info_list: 156 | ent_img_list.append(ent_dict.pop('img_list')) # entity image list 157 | ent_type.append(ent_dict.pop('sample_type')) # input type: 0 for mention and 1 for entity 158 | ent_input_dict_list.append(ent_dict) # entity input dict (input_tokens, token_type_ids, attention_mask) 159 | # choose an image 160 | for idx, _ in enumerate(ent_input_dict_list): 161 | ent_pixel_values.append(self.choose_image(ent_type[idx], ent_img_list[idx])) 162 | # some of the entities do not have image, so we use bool flags to tag them 163 | ent_empty_img_flag = torch.tensor([True if not len(_) else False for _ in ent_img_list], dtype=torch.bool) 164 | # pad textual input 165 | ent_input_dict = self.tokenizer.pad(ent_input_dict_list, 166 | padding='max_length', 167 | max_length=self.args.data.text_max_length, 168 | return_tensors='pt') 169 | # concat all image 170 | ent_pixel_values = torch.stack(ent_pixel_values) 171 | ent_input_dict['pixel_values'] = ent_pixel_values 172 | ent_input_dict['empty_img_flag'] 
= ent_empty_img_flag 173 | 174 | # for the entity information, we use prefix 'ent_' to tag them 175 | for k, v in ent_input_dict.items(): 176 | input_dict[f'ent_{k}'] = v 177 | return input_dict 178 | 179 | def eval_collator(self, samples): 180 | # eval collator is similar to train collator, but only include mention information 181 | cls_idx, img_list, sample_type, input_dict_list = [], [], [], [] 182 | pixel_values, gt_ent_id = [], [] 183 | 184 | for sample_idx, sample in enumerate(samples): 185 | img_list.append(sample.pop('img_list')) 186 | sample_type.append(sample.pop('sample_type')) 187 | gt_ent_id.append(sample.pop('answer')) 188 | input_dict_list.append(sample) 189 | 190 | for idx, _ in enumerate(input_dict_list): 191 | pixel_values.append(self.choose_image(sample_type[idx], img_list[idx], is_eval=True)) 192 | 193 | input_dict = self.tokenizer.pad(input_dict_list, 194 | padding='max_length', 195 | max_length=self.args.data.text_max_length, 196 | return_tensors='pt') 197 | input_dict['pixel_values'] = torch.stack(pixel_values) 198 | input_dict['answer'] = torch.tensor(gt_ent_id, dtype=torch.long) 199 | return input_dict 200 | 201 | def entity_collator(self, samples): 202 | # entity collator is similar to train collator, but only include entity information 203 | pixel_values, img_list, sample_type, input_dict_list = [], [], [], [] 204 | for sample_idx, sample in enumerate(samples): 205 | img_list.append(sample.pop('img_list')) 206 | sample_type.append(sample.pop('sample_type')) 207 | input_dict_list.append(sample) 208 | for idx, input_dict in enumerate(input_dict_list): 209 | pixel_values.append(self.choose_image(sample_type[idx], img_list[idx], is_eval=True)) 210 | 211 | input_dict = self.tokenizer.pad(input_dict_list, 212 | padding='max_length', 213 | max_length=self.args.data.text_max_length, 214 | return_tensors='pt') 215 | input_dict['pixel_values'] = torch.stack(pixel_values) 216 | 217 | return input_dict 218 | 219 | def entity_dataloader(self): 220 | return DataLoader(self.kb_entity, 221 | batch_size=self.args.data.embed_update_batch_size, 222 | num_workers=self.args.data.num_workers, 223 | shuffle=False, 224 | collate_fn=self.entity_collator) 225 | 226 | def train_dataloader(self): 227 | return DataLoader(self.train_data, 228 | batch_size=self.args.data.batch_size, 229 | num_workers=self.args.data.num_workers, 230 | shuffle=True, 231 | collate_fn=self.train_collator) 232 | 233 | def val_dataloader(self): 234 | return DataLoader(self.val_data, 235 | batch_size=self.args.data.eval_batch_size, 236 | num_workers=self.args.data.num_workers, 237 | shuffle=False, 238 | collate_fn=self.eval_collator) 239 | 240 | def test_dataloader(self): 241 | return DataLoader(self.test_data, 242 | batch_size=self.args.data.eval_batch_size, 243 | num_workers=self.args.data.num_workers, 244 | shuffle=False, 245 | collate_fn=self.eval_collator) 246 | -------------------------------------------------------------------------------- /codes/model/lightning_m3el.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import pytorch_lightning as pl 5 | from tqdm import tqdm 6 | 7 | from codes.model.modeling_m3el import M3ELEncoder, M3ELMatcher 8 | 9 | class LightningForM3EL(pl.LightningModule): 10 | def __init__(self, args): 11 | super(LightningForM3EL, self).__init__() 12 | self.args = args 13 | self.save_hyperparameters(args) 14 | 15 | self.encoder = M3ELEncoder(args) 16 | self.matcher = M3ELMatcher(args) 17 | 
self.loss_fct = torch.nn.CrossEntropyLoss() 18 | 19 | def training_step(self, batch): 20 | ent_batch = {} 21 | mention_batch = {} 22 | for k, v in batch.items(): 23 | if k.startswith('ent_'): 24 | ent_batch[k.replace('ent_', '')] = v 25 | else: 26 | mention_batch[k] = v 27 | entity_empty_image_flag = ent_batch.pop('empty_img_flag') 28 | 29 | mention_text_embeds, mention_image_embeds, mention_text_seq_tokens, mention_image_patch_tokens = self.encoder(**mention_batch) 30 | entity_text_embeds, entity_image_embeds, entity_text_seq_tokens, entity_image_patch_tokens = self.encoder(**ent_batch) 31 | 32 | if 'WikiDiverse' in self.args.data.kb_img_folder: 33 | logits, (text_logits, image_logits, text_image_logits, image_text_logits), (cl_loss) = self.matcher(entity_text_embeds, 34 | entity_text_seq_tokens, 35 | mention_text_embeds, 36 | mention_text_seq_tokens, 37 | entity_image_embeds, 38 | entity_image_patch_tokens, 39 | mention_image_embeds, 40 | mention_image_patch_tokens, 41 | train_flag=True, bidirection=True) 42 | labels = torch.arange(len(mention_text_embeds)).long().to(mention_text_embeds.device) 43 | text_loss = self.loss_fct(text_logits, labels) 44 | image_loss = self.loss_fct(image_logits, labels) 45 | text_image_loss = self.loss_fct(text_image_logits, labels) 46 | image_text_loss = self.loss_fct(image_text_logits, labels) 47 | overall_loss = self.loss_fct(logits, labels) 48 | if cl_loss is not None: 49 | loss = overall_loss + text_loss + image_loss + text_image_loss + image_text_loss + cl_loss 50 | else: 51 | loss = overall_loss + text_loss + image_loss + text_image_loss + image_text_loss 52 | else: 53 | logits, (text_logits, image_logits, image_text_logits), (cl_loss) = self.matcher(entity_text_embeds, 54 | entity_text_seq_tokens, 55 | mention_text_embeds, 56 | mention_text_seq_tokens, 57 | entity_image_embeds, 58 | entity_image_patch_tokens, 59 | mention_image_embeds, 60 | mention_image_patch_tokens, 61 | train_flag=True, bidirection=False) 62 | labels = torch.arange(len(mention_text_embeds)).long().to(mention_text_embeds.device) 63 | text_loss = self.loss_fct(text_logits, labels) 64 | image_loss = self.loss_fct(image_logits, labels) 65 | image_text_loss = self.loss_fct(image_text_logits, labels) 66 | overall_loss = self.loss_fct(logits, labels) 67 | if cl_loss is not None: 68 | loss = overall_loss + text_loss + image_loss + image_text_loss + cl_loss 69 | else: 70 | loss = overall_loss + text_loss + image_loss + image_text_loss 71 | 72 | self.log('Train/loss', loss.detach().cpu().item(), on_epoch=True, prog_bar=True) 73 | return loss 74 | 75 | def validation_step(self, batch, batch_idx): 76 | answer = batch.pop('answer') 77 | batch_size = len(answer) 78 | mention_text_embeds, mention_image_embeds, mention_text_seq_tokens, mention_image_patch_tokens = \ 79 | self.encoder(**batch) 80 | 81 | scores = [] 82 | chunk_size = self.args.data.eval_chunk_size 83 | for idx in range(math.ceil(self.args.data.num_entity / chunk_size)): 84 | start_pos = idx * chunk_size 85 | end_pos = (idx + 1) * chunk_size 86 | 87 | chunk_entity_text_embeds = self.entity_text_embeds[start_pos:end_pos].to(mention_text_embeds.device) 88 | chunk_entity_image_embeds = self.entity_image_embeds[start_pos:end_pos].to(mention_text_embeds.device) 89 | chunk_entity_text_seq_tokens = self.entity_text_seq_tokens[start_pos:end_pos].to(mention_text_embeds.device) 90 | chunk_entity_image_patch_tokens = self.entity_image_patch_tokens[start_pos:end_pos].to( 91 | mention_text_embeds.device) 92 | 93 | chunk_score, _, _ = 
self.matcher(chunk_entity_text_embeds, chunk_entity_text_seq_tokens, 94 | mention_text_embeds, mention_text_seq_tokens, 95 | chunk_entity_image_embeds, chunk_entity_image_patch_tokens, 96 | mention_image_embeds, mention_image_patch_tokens, train_flag=False) 97 | scores.append(chunk_score) 98 | 99 | scores = torch.concat(scores, dim=-1) 100 | rank = torch.argsort(torch.argsort(scores, dim=-1, descending=True), dim=-1, descending=False) + 1 101 | tgt_rank = rank[torch.arange(batch_size), answer].detach().cpu() 102 | return dict(rank=tgt_rank, all_rank=rank.detach().cpu().numpy()) 103 | 104 | def on_validation_start(self): 105 | entity_dataloader = self.trainer.datamodule.entity_dataloader() 106 | outputs_text_embed = [] 107 | outputs_image_embed = [] 108 | outputs_text_seq_tokens = [] 109 | outputs_image_patch_tokens = [] 110 | 111 | with torch.no_grad(): 112 | for batch in tqdm(entity_dataloader, desc='UpdateEmbed', total=len(entity_dataloader), disable=True): 113 | batch = pl.utilities.move_data_to_device(batch, self.device) 114 | entity_text_embeds, entity_image_embeds, entity_text_seq_tokens, entity_image_patch_tokens = \ 115 | self.encoder(**batch) 116 | outputs_text_embed.append(entity_text_embeds.cpu()) 117 | outputs_image_embed.append(entity_image_embeds.cpu()) 118 | outputs_text_seq_tokens.append(entity_text_seq_tokens.cpu()) 119 | outputs_image_patch_tokens.append(entity_image_patch_tokens.cpu()) 120 | 121 | self.entity_text_embeds = torch.concat(outputs_text_embed, dim=0) 122 | self.entity_image_embeds = torch.concat(outputs_image_embed, dim=0) 123 | self.entity_text_seq_tokens = torch.concat(outputs_text_seq_tokens, dim=0) 124 | self.entity_image_patch_tokens = torch.concat(outputs_image_patch_tokens, dim=0) 125 | 126 | def validation_epoch_end(self, outputs): 127 | self.entity_text_embeds = None 128 | self.entity_image_embeds = None 129 | self.entity_text_seq_tokens = None 130 | self.entity_image_patch_tokens = None 131 | 132 | ranks = np.concatenate([_['rank'] for _ in outputs]) 133 | hits1 = (ranks <= 1).mean() 134 | hits3 = (ranks <= 3).mean() 135 | hits5 = (ranks <= 5).mean() 136 | hits10 = (ranks <= 10).mean() 137 | hits20 = (ranks <= 20).mean() 138 | 139 | self.log("Val/mr", ranks.mean()) 140 | self.log("Val/mrr", (1. 
/ ranks).mean()) 141 | self.log("Val/hits1", hits1) 142 | self.log("Val/hits3", hits3) 143 | self.log("Val/hits5", hits5) 144 | self.log("Val/hits10", hits10) 145 | self.log("Val/hits20", hits20) 146 | 147 | 148 | def test_step(self, batch, batch_idx, dataloader_idx=None): 149 | answer = batch.pop('answer') 150 | batch_size = len(answer) 151 | mention_text_embeds, mention_image_embeds, mention_text_seq_tokens, mention_image_patch_tokens = \ 152 | self.encoder(**batch) 153 | 154 | scores = [] 155 | chunk_size = self.args.data.eval_chunk_size 156 | for idx in range(math.ceil(self.args.data.num_entity / chunk_size)): 157 | start_pos = idx * chunk_size 158 | end_pos = (idx + 1) * chunk_size 159 | 160 | chunk_entity_text_embeds = self.entity_text_embeds[start_pos:end_pos].to(mention_text_embeds.device) 161 | chunk_entity_image_embeds = self.entity_image_embeds[start_pos:end_pos].to(mention_text_embeds.device) 162 | chunk_entity_text_seq_tokens = self.entity_text_seq_tokens[start_pos:end_pos].to(mention_text_embeds.device) 163 | chunk_entity_image_patch_tokens = self.entity_image_patch_tokens[start_pos:end_pos].to( 164 | mention_text_embeds.device) 165 | 166 | chunk_score, _, _ = self.matcher(chunk_entity_text_embeds, chunk_entity_text_seq_tokens, 167 | mention_text_embeds, mention_text_seq_tokens, 168 | chunk_entity_image_embeds, chunk_entity_image_patch_tokens, 169 | mention_image_embeds, mention_image_patch_tokens, train_flag=False) 170 | scores.append(chunk_score) 171 | 172 | scores = torch.concat(scores, dim=-1) 173 | rank = torch.argsort(torch.argsort(scores, dim=-1, descending=True), dim=-1, descending=False) + 1 174 | tgt_rank = rank[torch.arange(batch_size), answer].detach().cpu() 175 | return dict(rank=tgt_rank, all_rank=rank.detach().cpu().numpy(), scores=scores.detach().cpu().numpy()) 176 | 177 | def on_test_start(self): 178 | entity_dataloader = self.trainer.datamodule.entity_dataloader() 179 | outputs_text_embed = [] 180 | outputs_image_embed = [] 181 | outputs_text_seq_tokens = [] 182 | outputs_image_patch_tokens = [] 183 | 184 | with torch.no_grad(): 185 | for batch in tqdm(entity_dataloader, desc='UpdateEmbed', total=len(entity_dataloader), disable=True): 186 | batch = pl.utilities.move_data_to_device(batch, self.device) 187 | entity_text_embeds, entity_image_embeds, entity_text_seq_tokens, entity_image_patch_tokens = \ 188 | self.encoder(**batch) 189 | outputs_text_embed.append(entity_text_embeds.cpu()) 190 | outputs_image_embed.append(entity_image_embeds.cpu()) 191 | outputs_text_seq_tokens.append(entity_text_seq_tokens.cpu()) 192 | outputs_image_patch_tokens.append(entity_image_patch_tokens.cpu()) 193 | 194 | self.entity_text_embeds = torch.concat(outputs_text_embed, dim=0) 195 | self.entity_image_embeds = torch.concat(outputs_image_embed, dim=0) 196 | self.entity_text_seq_tokens = torch.concat(outputs_text_seq_tokens, dim=0) 197 | self.entity_image_patch_tokens = torch.concat(outputs_image_patch_tokens, dim=0) 198 | 199 | def test_epoch_end(self, outputs): 200 | self.entity_text_embeds = None 201 | self.entity_image_embeds = None 202 | self.entity_text_seq_tokens = None 203 | self.entity_image_patch_tokens = None 204 | 205 | ranks = np.concatenate([_['rank'] for _ in outputs]) 206 | hits1 = (ranks <= 1).mean() 207 | hits3 = (ranks <= 3).mean() 208 | hits5 = (ranks <= 5).mean() 209 | hits10 = (ranks <= 10).mean() 210 | hits20 = (ranks <= 20).mean() 211 | 212 | self.log("Test/mr", ranks.mean()) 213 | self.log("Test/mrr", (1. 
/ ranks).mean()) 214 | self.log("Test/hits1", hits1) 215 | self.log("Test/hits3", hits3) 216 | self.log("Test/hits5", hits5) 217 | self.log("Test/hits10", hits10) 218 | self.log("Test/hits20", hits20) 219 | 220 | def configure_optimizers(self): 221 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 222 | optimizer_grouped_params = [ 223 | {'params': [p for n, p in self.named_parameters() if not any(nd in n for nd in no_decay)], 224 | 'weight_decay': 0.0001}, 225 | {'params': [p for n, p in self.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 226 | ] 227 | optimizer = torch.optim.AdamW(optimizer_grouped_params, lr=self.args.lr, betas=(0.9, 0.999), eps=1e-4) 228 | return [optimizer] -------------------------------------------------------------------------------- /codes/model/modeling_m3el.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from transformers import CLIPModel 4 | from codes.utils.contrastive_loss import * 5 | from codes.utils.utils import * 6 | 7 | class M3ELEncoder(nn.Module): 8 | def __init__(self, args): 9 | super(M3ELEncoder, self).__init__() 10 | self.args = args 11 | current_directory = os.path.dirname(os.path.abspath(__file__)) 12 | base_path = current_directory[0:current_directory.rfind('/')] 13 | self.base_path = base_path[0:base_path.rfind('/')] 14 | self.clip = CLIPModel.from_pretrained(self.base_path + self.args.pretrained_model) 15 | self.image_cls_fc = nn.Linear(self.args.model.input_hidden_dim, self.args.model.dv) 16 | self.image_tokens_fc = nn.Linear(self.args.model.input_image_hidden_dim, self.args.model.dv) 17 | 18 | def forward(self, 19 | input_ids=None, 20 | attention_mask=None, 21 | token_type_ids=None, 22 | pixel_values=None): 23 | clip_output = self.clip(input_ids=input_ids, 24 | attention_mask=attention_mask, 25 | pixel_values=pixel_values) 26 | 27 | text_embeds = clip_output.text_embeds 28 | image_embeds = clip_output.image_embeds 29 | 30 | text_seq_tokens = clip_output.text_model_output[0] 31 | image_patch_tokens = clip_output.vision_model_output[0] 32 | 33 | image_embeds = self.image_cls_fc(image_embeds) 34 | image_patch_tokens = self.image_tokens_fc(image_patch_tokens) 35 | return text_embeds, image_embeds, text_seq_tokens, image_patch_tokens 36 | 37 | class TextIntraModalMatch(nn.Module): 38 | def __init__(self, args): 39 | super(TextIntraModalMatch, self).__init__() 40 | self.args = args 41 | self.fc_query = nn.Linear(self.args.model.input_hidden_dim, self.args.model.TIMM_hidden_dim) 42 | self.fc_key = nn.Linear(self.args.model.input_hidden_dim, self.args.model.TIMM_hidden_dim) 43 | self.fc_value = nn.Linear(self.args.model.input_hidden_dim, self.args.model.TIMM_hidden_dim) 44 | self.fc_cls = nn.Linear(self.args.model.input_hidden_dim, self.args.model.TIMM_hidden_dim) 45 | self.layer_norm = nn.LayerNorm(self.args.model.TIMM_hidden_dim) 46 | 47 | def forward(self, 48 | entity_text_cls, 49 | entity_text_tokens, 50 | mention_text_cls, 51 | mention_text_tokens): 52 | """ 53 | 54 | :param entity_text_cls: [num_entity, dim] 55 | :param entity_text_tokens: [num_entity, max_seq_len, dim] 56 | :param mention_text_cls: [batch_size, dim] 57 | :param mention_text_tokens: [batch_size, max_sqe_len, dim] 58 | :return: 59 | """ 60 | entity_cls_fc = self.fc_cls(entity_text_cls) 61 | entity_cls_fc = entity_cls_fc.unsqueeze(dim=1) 62 | 63 | query = self.fc_query(entity_text_tokens) 64 | key = self.fc_key(mention_text_tokens) 65 | value = 
self.fc_value(mention_text_tokens) 66 | 67 | query = query.unsqueeze(dim=1) 68 | key = key.unsqueeze(dim=0) 69 | value = value.unsqueeze(dim=0) 70 | 71 | attention_scores = torch.matmul(query, key.transpose(-1, -2)) 72 | attention_scores = attention_scores / math.sqrt(self.args.model.TIMM_hidden_dim) 73 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 74 | 75 | context = torch.matmul(attention_probs, value) 76 | context = torch.mean(context, dim=-2) 77 | context = self.layer_norm(context) 78 | 79 | g2l_matching_score = torch.sum(entity_cls_fc * context, dim=-1) 80 | g2l_matching_score = g2l_matching_score.transpose(0, 1) 81 | g2g_matching_score = torch.matmul(mention_text_cls, entity_text_cls.transpose(-1, -2)) 82 | 83 | matching_score = (g2l_matching_score + g2g_matching_score) / 2 84 | 85 | return matching_score 86 | 87 | class ImageIntraModalMatch(nn.Module): 88 | def __init__(self, args): 89 | super(ImageIntraModalMatch, self).__init__() 90 | self.args = args 91 | self.fc_query = nn.Linear(self.args.model.dv, self.args.model.IIMM_hidden_dim) 92 | self.fc_key = nn.Linear(self.args.model.dv, self.args.model.IIMM_hidden_dim) 93 | self.fc_value = nn.Linear(self.args.model.dv, self.args.model.IIMM_hidden_dim) 94 | self.fc_cls = nn.Linear(self.args.model.dv, self.args.model.IIMM_hidden_dim) 95 | self.layer_norm = nn.LayerNorm(self.args.model.IIMM_hidden_dim) 96 | 97 | def forward(self, 98 | entity_image_cls, 99 | entity_image_tokens, 100 | mention_image_cls, 101 | mention_image_tokens): 102 | """ 103 | :param entity_image_cls: [num_entity, dim] 104 | :param entity_image_tokens: [num_entity, num_patch, dim] 105 | :param mention_image_cls: [batch_size, dim] 106 | :param mention_image_tokens: [batch_size, num_patch, dim] 107 | :return: 108 | """ 109 | entity_cls_fc = self.fc_cls(entity_image_cls) 110 | entity_cls_fc = entity_cls_fc.unsqueeze(dim=1) 111 | 112 | query = self.fc_query(entity_image_tokens) 113 | key = self.fc_key(mention_image_tokens) 114 | value = self.fc_value(mention_image_tokens) 115 | 116 | query = query.unsqueeze(dim=1) 117 | key = key.unsqueeze(dim=0) 118 | value = value.unsqueeze(dim=0) 119 | 120 | attention_scores = torch.matmul(query, key.transpose(-1, -2)) 121 | attention_scores = attention_scores / math.sqrt(self.args.model.IIMM_hidden_dim) 122 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 123 | 124 | context = torch.matmul(attention_probs, value) 125 | context = torch.mean(context, dim=-2) 126 | context = self.layer_norm(context) 127 | 128 | g2l_matching_score = torch.sum(entity_cls_fc * context, dim=-1) 129 | g2l_matching_score = g2l_matching_score.transpose(0, 1) 130 | g2g_matching_score = torch.matmul(mention_image_cls, entity_image_cls.transpose(-1, -2)) 131 | 132 | matching_score = (g2l_matching_score + g2g_matching_score) / 2 133 | 134 | return matching_score 135 | 136 | class CrossModalMatch(nn.Module): 137 | def __init__(self, args): 138 | super(CrossModalMatch, self).__init__() 139 | self.args = args 140 | self.text_fc = nn.Linear(self.args.model.input_hidden_dim, self.args.model.CMM_hidden_dim) 141 | self.image_fc = nn.Linear(self.args.model.dv, self.args.model.CMM_hidden_dim) 142 | self.gate_fc = nn.Linear(self.args.model.CMM_hidden_dim, 1) 143 | self.gate_act = nn.Tanh() 144 | self.gate_layer_norm = nn.LayerNorm(self.args.model.CMM_hidden_dim) 145 | 146 | self.match_module = MatchModule(self.args.model.CMM_hidden_dim) 147 | self.multi_head_module= MultiHeadModule(self.args.model.head_num, self.args.model.weight) 148 | self.fusion_module = 
FusionModule(self.args.model.CMM_hidden_dim) 149 | 150 | self.mclet_text_loss = MCLETLoss(embedding_dim=self.args.model.CMM_hidden_dim, cl_temperature=0.6) 151 | 152 | def forward(self, entity_text_cls, entity_image_tokens, 153 | mention_text_cls, mention_image_tokens): 154 | """ 155 | :param entity_text_cls: [num_entity, dim] 156 | :param entity_image_tokens: [num_entity, num_patch, dim] 157 | :param mention_text_cls: [batch_size, dim] 158 | :param mention_image_tokens: [batch_size, num_patch, dim] 159 | :return: 160 | """ 161 | entity_text_cls = self.text_fc(entity_text_cls) 162 | entity_text_cls_ori = entity_text_cls 163 | mention_text_cls = self.text_fc(mention_text_cls) 164 | mention_text_cls_ori = mention_text_cls 165 | 166 | entity_image_tokens = self.image_fc(entity_image_tokens) 167 | mention_image_tokens = self.image_fc(mention_image_tokens) 168 | 169 | entity_text_cls = self.match_module([entity_text_cls_ori.unsqueeze(dim=1), entity_image_tokens]).squeeze() 170 | entity_image_tokens = self.match_module([entity_image_tokens, entity_text_cls_ori.unsqueeze(dim=1)]) 171 | entity_image_tokens = self.multi_head_module(entity_image_tokens) 172 | entity_context = self.fusion_module([entity_text_cls, entity_image_tokens]) 173 | entity_gate_score = self.gate_act(self.gate_fc(entity_text_cls_ori)) 174 | entity_context = self.gate_layer_norm((entity_text_cls_ori * entity_gate_score) + entity_context) 175 | 176 | mention_text_cls = self.match_module([mention_text_cls_ori.unsqueeze(dim=1), mention_image_tokens]).squeeze() 177 | mention_image_tokens = self.match_module([mention_image_tokens, mention_text_cls_ori.unsqueeze(dim=1)]) 178 | mention_image_tokens = self.multi_head_module(mention_image_tokens) 179 | mention_context = self.fusion_module([mention_text_cls, mention_image_tokens]) 180 | mention_gate_score = self.gate_act(self.gate_fc(mention_text_cls_ori)) 181 | mention_context = self.gate_layer_norm((mention_text_cls_ori * mention_gate_score) + mention_context) 182 | 183 | score = torch.matmul(mention_context, entity_context.transpose(-1, -2)) 184 | 185 | return score 186 | 187 | class CrossModalMatchBidirection(nn.Module): 188 | def __init__(self, args): 189 | super(CrossModalMatchBidirection, self).__init__() 190 | self.args = args 191 | self.text_fc = nn.Linear(self.args.model.input_hidden_dim, self.args.model.CMM_hidden_dim) 192 | self.image_fc = nn.Linear(self.args.model.dv, self.args.model.CMM_hidden_dim) 193 | 194 | self.gate_fc = nn.Linear(self.args.model.CMM_hidden_dim, 1) 195 | self.gate_act = nn.Tanh() 196 | self.gate_layer_norm = nn.LayerNorm(self.args.model.CMM_hidden_dim) 197 | 198 | self.match_module = MatchModule(self.args.model.CMM_hidden_dim) 199 | self.multi_head_module = MultiHeadModule(self.args.model.head_num, self.args.model.weight) 200 | 201 | self.mclet_text_loss = MCLETLoss(embedding_dim=self.args.model.CMM_hidden_dim, cl_temperature=0.6) 202 | self.mclet_image_loss = MCLETLoss(embedding_dim=self.args.model.CMM_hidden_dim, cl_temperature=0.6) 203 | 204 | def forward(self, 205 | entity_text_cls, entity_text_tokens, 206 | entity_image_cls, entity_image_tokens, 207 | mention_text_cls, mention_text_tokens, 208 | mention_image_cls, mention_image_tokens): 209 | entity_text_cls = self.text_fc(entity_text_cls) 210 | entity_text_cls_ori = entity_text_cls 211 | entity_text_tokens = self.text_fc(entity_text_tokens) 212 | entity_image_cls = self.image_fc(entity_image_cls) 213 | entity_image_cls_ori = entity_image_cls 214 | entity_image_tokens = 
self.image_fc(entity_image_tokens) 215 | 216 | mention_text_cls = self.text_fc(mention_text_cls) 217 | mention_text_cls_ori = mention_text_cls 218 | mention_text_tokens = self.text_fc(mention_text_tokens) 219 | mention_image_cls = self.image_fc(mention_image_cls) 220 | mention_image_cls_ori = mention_image_cls 221 | mention_image_tokens = self.image_fc(mention_image_tokens) 222 | 223 | entity_text_cls = self.match_module([entity_text_cls_ori.unsqueeze(dim=1), entity_image_tokens]).squeeze() 224 | entity_image_tokens = self.match_module([entity_image_tokens, entity_text_cls_ori.unsqueeze(dim=1)]) 225 | entity_text_image_context = self.multi_head_module(torch.cat([entity_text_cls.unsqueeze(dim=1), entity_image_tokens], dim=1)) 226 | entity_text_image_gate_score = self.gate_act(self.gate_fc(entity_text_cls_ori)) 227 | entity_text_image_context = self.gate_layer_norm((entity_text_cls_ori * entity_text_image_gate_score) + entity_text_image_context) 228 | 229 | entity_image_cls = self.match_module([entity_image_cls_ori.unsqueeze(dim=1), entity_text_tokens]).squeeze() 230 | entity_text_tokens = self.match_module([entity_text_tokens, entity_image_cls_ori.unsqueeze(dim=1)]) 231 | entity_image_text_context = self.multi_head_module(torch.cat([entity_image_cls.unsqueeze(dim=1), entity_text_tokens], dim=1)) 232 | entity_image_text_gate_score = self.gate_act(self.gate_fc(entity_image_cls_ori)) 233 | entity_image_text_context = self.gate_layer_norm((entity_image_cls_ori * entity_image_text_gate_score) + entity_image_text_context) 234 | 235 | mention_text_cls = self.match_module([mention_text_cls_ori.unsqueeze(dim=1), mention_image_tokens]).squeeze() 236 | mention_image_tokens = self.match_module([mention_image_tokens, mention_text_cls_ori.unsqueeze(dim=1)]) 237 | mention_text_image_context = self.multi_head_module(torch.cat([mention_text_cls.unsqueeze(dim=1), mention_image_tokens], dim=1)) 238 | mention_text_image_gate_score = self.gate_act(self.gate_fc(mention_text_cls_ori)) 239 | mention_text_image_context = self.gate_layer_norm((mention_text_cls_ori * mention_text_image_gate_score) + mention_text_image_context) 240 | 241 | mention_image_cls = self.match_module([mention_image_cls_ori.unsqueeze(dim=1), mention_text_tokens]).squeeze() 242 | mention_text_tokens = self.match_module([mention_text_tokens, mention_image_cls_ori.unsqueeze(dim=1)]) 243 | mention_image_text_context = self.multi_head_module(torch.cat([mention_image_cls.unsqueeze(dim=1), mention_text_tokens], dim=1)) 244 | mention_image_text_gate_score = self.gate_act(self.gate_fc(mention_image_cls_ori)) 245 | mention_image_text_context = self.gate_layer_norm((mention_image_cls_ori * mention_image_text_gate_score) + mention_image_text_context) 246 | 247 | score_text_image_context = torch.matmul(mention_text_image_context, entity_text_image_context.transpose(-1, -2)) 248 | score_image_text_context = torch.matmul(mention_image_text_context, entity_image_text_context.transpose(-1, -2)) 249 | 250 | return score_text_image_context, score_image_text_context 251 | 252 | class M3ELMatcher(nn.Module): 253 | def __init__(self, args): 254 | super(M3ELMatcher, self).__init__() 255 | self.args = args 256 | self.timm = TextIntraModalMatch(self.args) 257 | self.iimm = ImageIntraModalMatch(self.args) 258 | self.cmm = CrossModalMatch(self.args) 259 | self.cmmb = CrossModalMatchBidirection(self.args) 260 | 261 | self.text_cls_layernorm = nn.LayerNorm(self.args.model.dt) 262 | self.text_tokens_layernorm = nn.LayerNorm(self.args.model.dt) 263 | 
self.image_cls_layernorm = nn.LayerNorm(self.args.model.dv) 264 | self.image_tokens_layernorm = nn.LayerNorm(self.args.model.dv) 265 | 266 | self.text_fc = nn.Linear(self.args.model.input_hidden_dim, self.args.model.TIMM_hidden_dim) 267 | self.image_fc = nn.Linear(self.args.model.dv, self.args.model.IIMM_hidden_dim) 268 | self.scale_text_cls_layernorm = nn.LayerNorm(self.args.model.TIMM_hidden_dim) 269 | self.scale_text_tokens_layernorm = nn.LayerNorm(self.args.model.TIMM_hidden_dim) 270 | self.scale_image_cls_layernorm = nn.LayerNorm(self.args.model.IIMM_hidden_dim) 271 | self.scale_image_tokens_layernorm = nn.LayerNorm(self.args.model.IIMM_hidden_dim) 272 | 273 | self.mclet_text_loss = MCLETLoss(embedding_dim=self.args.model.IIMM_hidden_dim, cl_temperature=0.6) 274 | self.mclet_image_loss = MCLETLoss(embedding_dim=self.args.model.IIMM_hidden_dim, cl_temperature=0.6) 275 | self.weight_cl_loss = WeightedContrastiveLoss(temperature=self.args.model.loss_temperature, inter_weight=self.args.model.inter_weight, intra_weight=self.args.model.intra_weight) 276 | 277 | def forward(self, 278 | entity_text_cls, entity_text_tokens, 279 | mention_text_cls, mention_text_tokens, 280 | entity_image_cls, entity_image_tokens, 281 | mention_image_cls, mention_image_tokens, 282 | train_flag=False, bidirection=False): 283 | """ 284 | 285 | :param entity_text_cls: [num_entity, dim] 286 | :param entity_text_tokens: [num_entity, max_seq_len, dim] 287 | :param mention_text_cls: [batch_size, dim] 288 | :param mention_text_tokens: [batch_size, max_sqe_len, dim] 289 | :param entity_image_cls: [num_entity, dim] 290 | :param mention_image_cls: [batch_size, dim] 291 | :param entity_image_tokens: [num_entity, num_patch, dim] 292 | :param mention_image_tokens:[num_entity, num_patch, dim] 293 | :return: 294 | """ 295 | if train_flag == True: 296 | text_cl_loss = self.weight_cl_loss(entity_text_cls, mention_text_cls) 297 | image_cl_loss = self.weight_cl_loss(entity_image_cls, mention_image_cls) 298 | cl_loss = (text_cl_loss + image_cl_loss) / 2 299 | else: 300 | cl_loss = None 301 | 302 | entity_text_cls = self.text_cls_layernorm(entity_text_cls) 303 | mention_text_cls = self.text_cls_layernorm(mention_text_cls) 304 | 305 | entity_text_tokens = self.text_tokens_layernorm(entity_text_tokens) 306 | mention_text_tokens = self.text_tokens_layernorm(mention_text_tokens) 307 | 308 | entity_image_cls = self.image_cls_layernorm(entity_image_cls) 309 | mention_image_cls = self.image_cls_layernorm(mention_image_cls) 310 | 311 | entity_image_tokens = self.image_tokens_layernorm(entity_image_tokens) 312 | mention_image_tokens = self.image_tokens_layernorm(mention_image_tokens) 313 | 314 | text_matching_score = self.timm(entity_text_cls, entity_text_tokens, mention_text_cls, mention_text_tokens) 315 | image_matching_score = self.iimm(entity_image_cls, entity_image_tokens, mention_image_cls, mention_image_tokens) 316 | 317 | if bidirection == True: 318 | text_image_matching_score, image_text_matching_score = self.cmmb(entity_text_cls, entity_text_tokens, entity_image_cls, entity_image_tokens, mention_text_cls, mention_text_tokens, mention_image_cls, mention_image_tokens) 319 | score = (text_matching_score + image_matching_score + text_image_matching_score + image_text_matching_score) / 4 320 | return score, (text_matching_score, image_matching_score, text_image_matching_score, image_text_matching_score), (cl_loss) 321 | else: 322 | image_text_matching_score = self.cmm(entity_text_cls, entity_image_tokens, mention_text_cls, 
mention_image_tokens) 323 | score = (text_matching_score + image_matching_score + image_text_matching_score) / 3 324 | return score, (text_matching_score, image_matching_score, image_text_matching_score), (cl_loss) 325 | --------------------------------------------------------------------------------
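For reference, the ranking metrics logged in `validation_epoch_end` and `test_epoch_end` are obtained from the mention-entity score matrix via the double `argsort` used in `validation_step`/`test_step`. The snippet below is a minimal, self-contained sketch of that computation; the score matrix and gold indices are toy values made up for illustration, not data from the repository.

```python
import torch

# Toy score matrix: 3 mentions (rows) scored against 5 candidate entities (columns).
scores = torch.tensor([[0.9, 0.1, 0.3, 0.2, 0.5],
                       [0.2, 0.8, 0.1, 0.7, 0.3],
                       [0.1, 0.2, 0.3, 0.9, 0.4]])
answer = torch.tensor([0, 3, 3])  # gold entity index per mention (toy values)

# Double argsort: rank[i, j] is the 1-based rank of entity j for mention i (1 = best score),
# exactly as computed in validation_step/test_step.
rank = torch.argsort(torch.argsort(scores, dim=-1, descending=True), dim=-1, descending=False) + 1
tgt_rank = rank[torch.arange(len(answer)), answer].float()

print('MR  :', tgt_rank.mean().item())           # mean rank
print('MRR :', (1. / tgt_rank).mean().item())    # mean reciprocal rank
print('H@1 :', (tgt_rank <= 1).float().mean().item())
print('H@10:', (tgt_rank <= 10).float().mean().item())
```

The inner `argsort` orders the candidates of each mention by score; the outer `argsort` recovers each candidate's position in that ordering, i.e. its rank, which is then indexed at the gold entity to obtain the target rank.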