├── codes
│   ├── utils
│   │   ├── functions.py
│   │   └── dataset.py
│   └── model
│       ├── moe.py
│       ├── lightning_mmoe.py
│       └── modeling_mmoe.py
├── run.sh
├── config
│   ├── wikimel.yaml
│   ├── wikidiverse.yaml
│   └── richpediamel.yaml
├── main.py
└── README.md
/codes/utils/functions.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from omegaconf import OmegaConf 3 | 4 | 5 | def setup_parser(): 6 | parser = argparse.ArgumentParser(add_help=False) 7 | parser.add_argument('--config', type=str, default='config/wikidiverse.yaml') 8 | _args = parser.parse_args() 9 | args = OmegaConf.load(_args.config) 10 | return args -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | export GPU=1 2 | 3 | export CONFIG=./config/wikimel.yaml 4 | export LOG=./logs/wikimel_baseline.logs 5 | 6 | #export CONFIG=./config/wikidiverse.yaml 7 | #export LOG=./logs/wikidiverse_baseline.logs 8 | 9 | #export CONFIG=./config/richpediamel.yaml 10 | #export LOG=./logs/richpediamel_baseline.logs 11 | 12 | CUDA_VISIBLE_DEVICES=$GPU nohup python -u ./main.py --config $CONFIG \ 13 | > $LOG 2>&1 & -------------------------------------------------------------------------------- /config/wikimel.yaml: -------------------------------------------------------------------------------- 1 | run_name: WikiMEL 2 | seed: 43 3 | pretrained_model: '/checkpoint/clip-vit-base-patch32' 4 | lr: 5e-6 5 | 6 | 7 | data: 8 | num_entity: 109976 9 | kb_img_folder: /data/WikiMEL/kb_image 10 | mention_img_folder: /data/WikiMEL/mention_image 11 | qid2id: /data/WikiMEL/qid2id.json 12 | entity: /data/WikiMEL/kb_entity.json 13 | train_file: /data/WikiMEL/WIKIMEL_train.json 14 | dev_file: /data/WikiMEL/WIKIMEL_dev.json 15 | test_file: /data/WikiMEL/WIKIMEL_test.json 16 | 17 | batch_size: 32 18 | num_workers: 8 19 | text_max_length: 50 20 | visual_patch_length: 50 21 | 22 | eval_chunk_size: 6000 23 | eval_batch_size: 20 24 | embed_update_batch_size: 512 25 | 26 | 27 | model: 28 | input_hidden_dim: 512 29 | input_image_hidden_dim: 768 30 | hidden_dim: 96 31 | dv: 96 32 | dt: 512 33 | TGLU_hidden_dim: 96 34 | IDLU_hidden_dim: 96 35 | CMFU_hidden_dim: 96 36 | num_experts: 4 37 | top_experts: 3 38 | 39 | 40 | trainer: 41 | accelerator: 'gpu' 42 | devices: 1 43 | max_epochs: 20 44 | num_sanity_val_steps: 0 45 | check_val_every_n_epoch: 2 46 | log_every_n_steps: 30 -------------------------------------------------------------------------------- /config/wikidiverse.yaml: -------------------------------------------------------------------------------- 1 | run_name: WikiDiverse 2 | seed: 43 3 | pretrained_model: '/checkpoint/clip-vit-base-patch32' 4 | lr: 1e-5 5 | 6 | 7 | data: 8 | num_entity: 132460 9 | kb_img_folder: /data/WikiDiverse/kb_image 10 | mention_img_folder: /data/WikiDiverse/mention_image 11 | qid2id: /data/WikiDiverse/qid2id.json 12 | entity: /data/WikiDiverse/kb_entity.json 13 | train_file: /data/WikiDiverse/WikiDiverse_train.json 14 | dev_file: /data/WikiDiverse/WikiDiverse_dev.json 15 | test_file: /data/WikiDiverse/WikiDiverse_test.json 16 | 17 | batch_size: 32 18 | num_workers: 8 19 | text_max_length: 50 20 | visual_patch_length: 50 21 | 22 | eval_chunk_size: 6000 23 | eval_batch_size: 20 24 | embed_update_batch_size: 512 25 | 26 | 27 | model: 28 | input_hidden_dim: 512 29 | input_image_hidden_dim: 768 30 | hidden_dim: 96 31 | dv: 96 32 | dt: 512 33 | TGLU_hidden_dim: 96 34 | IDLU_hidden_dim: 96 35 |
CMFU_hidden_dim: 96 36 | num_experts: 4 37 | top_experts: 2 38 | 39 | 40 | trainer: 41 | accelerator: 'gpu' 42 | devices: 1 43 | max_epochs: 20 44 | num_sanity_val_steps: 0 45 | check_val_every_n_epoch: 2 46 | log_every_n_steps: 30 -------------------------------------------------------------------------------- /config/richpediamel.yaml: -------------------------------------------------------------------------------- 1 | run_name: RichpediaMEL 2 | seed: 43 3 | pretrained_model: '/checkpoint/clip-vit-base-patch32' 4 | lr: 1e-5 5 | 6 | 7 | data: 8 | num_entity: 160933 9 | kb_img_folder: /data/RichpediaMEL/kb_image 10 | mention_img_folder: /data/RichpediaMEL/mention_image 11 | qid2id: /data/RichpediaMEL/qid2id.json 12 | entity: /data/RichpediaMEL/kb_entity.json 13 | train_file: /data/RichpediaMEL/RichpediaMEL_train.json 14 | dev_file: /data/RichpediaMEL/RichpediaMEL_dev.json 15 | test_file: /data/RichpediaMEL/RichpediaMEL_test.json 16 | 17 | batch_size: 32 18 | num_workers: 8 19 | text_max_length: 50 20 | visual_patch_length: 50 21 | 22 | eval_chunk_size: 6000 23 | eval_batch_size: 20 24 | embed_update_batch_size: 512 25 | 26 | 27 | model: 28 | input_hidden_dim: 512 29 | input_image_hidden_dim: 768 30 | hidden_dim: 96 31 | dv: 96 32 | dt: 512 33 | TGLU_hidden_dim: 96 34 | IDLU_hidden_dim: 96 35 | CMFU_hidden_dim: 96 36 | num_experts: 4 37 | top_experts: 4 38 | 39 | 40 | trainer: 41 | accelerator: 'gpu' 42 | devices: 1 43 | max_epochs: 20 44 | num_sanity_val_steps: 0 45 | check_val_every_n_epoch: 2 46 | log_every_n_steps: 30 -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping 4 | from codes.utils.functions import setup_parser 5 | from codes.model.lightning_mmoe import LightningForMMoE 6 | from codes.utils.dataset import DataModuleForMMoE 7 | 8 | if __name__ == '__main__': 9 | args = setup_parser() 10 | pl.seed_everything(args.seed, workers=True) 11 | torch.set_num_threads(1) 12 | 13 | data_module = DataModuleForMMoE(args) 14 | lightning_model = LightningForMMoE(args) 15 | 16 | logger = pl.loggers.CSVLogger("./runs", name=args.run_name, flush_logs_every_n_steps=30) 17 | 18 | ckpt_callbacks = ModelCheckpoint(monitor='Val/mrr', save_weights_only=True, mode='max') 19 | early_stop_callback = EarlyStopping(monitor="Val/mrr", min_delta=0.00, patience=3, verbose=True, mode="max") 20 | 21 | trainer = pl.Trainer(**args.trainer, 22 | deterministic=True, logger=logger, default_root_dir="./runs", 23 | callbacks=[ckpt_callbacks, early_stop_callback]) 24 | 25 | trainer.fit(lightning_model, datamodule=data_module) 26 | trainer.test(lightning_model, datamodule=data_module, ckpt_path='best') -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multi-level Mixture of Experts for Multimodal Entity Linking 2 | #### This repo provides the source code & data of our paper: Multi-level Mixture of Experts for Multimodal Entity Linking(KDD2025). 3 | 4 | ## Dependencies 5 | * conda create -n mmoe python=3.7 -y 6 | * torch==1.11.0+cu113 7 | * transformers==4.27.1 8 | * torchmetrics==0.11.0 9 | * tokenizers==0.12.1 10 | * pytorch-lightning==1.7.7 11 | * omegaconf==2.2.3 12 | * pillow==9.3.0 13 | 14 | ## Running the code 15 | ### Dataset 16 | 1. 
Download the datasets from the [MIMIC](https://github.com/pengfei-luo/MIMIC) repository. 17 | 2. Download the data with WikiData description information from [here](https://drive.google.com/drive/folders/196zSJCy5XOmRZ995Y1SUZkGbMN922nPY?usp=sharing) and move it into the corresponding MIMIC dataset folder. 18 | 3. Create a ./data directory in the project root and place the datasets inside. 19 | 4. Download the pretrained weights from [clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32). 20 | 5. Create a ./checkpoint directory in the project root and place the pretrained weights inside. 21 | 22 | ### Training model 23 | ```shell 24 | sh run.sh 25 | ``` 26 | **Note:** run.sh contains the commands for all three datasets. Switch datasets by uncommenting the corresponding CONFIG and LOG lines. 27 | 28 | ### Training logs 29 | **Note:** Our training logs are provided in the logs directory. 30 | 31 | ## Citation 32 | If you find this code useful, please consider citing the following paper. 33 | ``` 34 | @inproceedings{MMoE, 35 | author={Zhiwei Hu and Víctor Gutiérrez-Basulto and Zhiliang Xiang and Ru Li and Jeff Z. Pan}, 36 | title={Multi-level Mixture of Experts for Multimodal Entity Linking}, 37 | booktitle={ACM SIGKDD Conference on Knowledge Discovery and Data Mining}, 38 | year={2025} 39 | } 40 | ``` 41 | ## Acknowledgement 42 | Our code builds on [MIMIC](https://github.com/pengfei-luo/MIMIC) and [MEL-M3EL](https://github.com/zhiweihu1103/MEL-M3EL). We thank the authors for their contributions. 43 | -------------------------------------------------------------------------------- /codes/model/moe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from torch import nn, Tensor 5 | from typing import Callable 6 | 7 | class GLU(nn.Module): 8 | def __init__( 9 | self, 10 | dim_in: int, 11 | dim_out: int, 12 | activation: Callable[[Tensor], Tensor], 13 | mult_bias: bool = False, 14 | ): 15 | super().__init__() 16 | self.act = activation 17 | self.proj = nn.Linear(dim_in, dim_out * 2) 18 | self.mult_bias = nn.Parameter(torch.ones(dim_out)) if mult_bias else 1.0 19 | 20 | def forward(self, x: Tensor) -> Tensor: 21 | x, gate = self.proj(x).chunk(2, dim=-1) 22 | return x * self.act(gate) * self.mult_bias 23 | 24 | class ReluSquared(nn.Module): 25 | def forward(self, x): 26 | return F.relu(x) ** 2 27 | 28 | 29 | def exists(val): 30 | return val is not None 31 | 32 | 33 | def default(val, default_val): 34 | return default_val if val is None else val 35 | 36 | 37 | def init_zero_(layer): 38 | nn.init.constant_(layer.weight, 0.0) 39 | if exists(layer.bias): 40 | nn.init.constant_(layer.bias, 0.0) 41 | 42 | 43 | class FeedForward(nn.Module): 44 | def __init__( 45 | self, 46 | dim: int, 47 | dim_out: int = None, 48 | mult=4, 49 | glu=False, 50 | glu_mult_bias=False, 51 | swish=False, 52 | relu_squared=False, 53 | post_act_ln=False, 54 | dropout: float = 0.0, 55 | no_bias=False, 56 | zero_init_output=False, 57 | ): 58 | super().__init__() 59 | inner_dim = int(dim * mult) 60 | dim_out = default(dim_out, dim) 61 | 62 | if relu_squared: 63 | activation = ReluSquared() 64 | elif swish: 65 | activation = nn.SiLU() 66 | else: 67 | activation = nn.GELU() 68 | 69 | if glu: 70 | project_in = GLU( 71 | dim, inner_dim, activation, mult_bias=glu_mult_bias 72 | ) 73 | else: 74 | project_in = nn.Sequential( 75 | nn.Linear(dim, inner_dim, bias=not no_bias), activation 76 | ) 77 | 78 | if post_act_ln: 79 | self.ff = nn.Sequential( 80 | project_in, 81 | nn.LayerNorm(inner_dim), 82 |
nn.Dropout(dropout), 83 | nn.Linear(inner_dim, dim_out, bias=not no_bias), 84 | ) 85 | else: 86 | self.ff = nn.Sequential( 87 | project_in, 88 | nn.Dropout(dropout), 89 | nn.Linear(inner_dim, dim_out, bias=not no_bias), 90 | ) 91 | 92 | # init last linear layer to 0 93 | if zero_init_output: 94 | init_zero_(self.ff[-1]) 95 | 96 | def forward(self, x): 97 | return self.ff(x) 98 | 99 | 100 | class SwitchGate(nn.Module): 101 | def __init__( 102 | self, 103 | dim, 104 | num_experts: int, 105 | top_k: int, 106 | capacity_factor: float = 1.0, 107 | epsilon: float = 1e-6 108 | ): 109 | super().__init__() 110 | self.dim = dim 111 | self.num_experts = num_experts 112 | self.top_k = top_k 113 | self.capacity_factor = capacity_factor 114 | self.epsilon = epsilon 115 | self.w_gate = nn.Linear(dim, num_experts) 116 | 117 | def forward(self, x: Tensor): 118 | gate_scores = F.softmax(self.w_gate(x), dim=-1) 119 | capacity = int(self.capacity_factor * x.size(0)) 120 | top_k_scores, top_k_indices = gate_scores.topk(self.top_k, dim=-1) 121 | mask = torch.zeros_like(gate_scores).scatter_(1, top_k_indices, 1) 122 | masked_gate_scores = gate_scores * mask 123 | denominators = masked_gate_scores.sum(0, keepdim=True) + self.epsilon 124 | gate_scores = (masked_gate_scores / denominators) * capacity 125 | 126 | return gate_scores 127 | 128 | 129 | class SwitchMoE(nn.Module): 130 | def __init__( 131 | self, 132 | dim: int, 133 | output_dim: int, 134 | num_experts: int, 135 | top_k: int, 136 | capacity_factor: float = 1.0, 137 | mult: int = 4 138 | ): 139 | super().__init__() 140 | self.dim = dim 141 | self.output_dim = output_dim 142 | self.num_experts = num_experts 143 | self.capacity_factor = capacity_factor 144 | self.mult = mult 145 | 146 | self.experts = nn.ModuleList( 147 | [ 148 | FeedForward(dim, output_dim, mult) 149 | for _ in range(num_experts) 150 | ] 151 | ) 152 | 153 | self.gate = SwitchGate( 154 | dim, 155 | num_experts, 156 | top_k, 157 | capacity_factor, 158 | ) 159 | 160 | def forward(self, x: Tensor, gate_input: Tensor): 161 | # (batch_size, seq_len, num_experts) 162 | gate_scores = self.gate(gate_input) 163 | expert_outputs = [expert(x) for expert in self.experts] 164 | 165 | if torch.isnan(gate_scores).any(): 166 | print("NaN in gate scores") 167 | gate_scores[torch.isnan(gate_scores)] = 0 168 | stacked_expert_outputs = torch.stack( 169 | expert_outputs, dim=-1 170 | ) # (batch_size, seq_len, output_dim, num_experts) 171 | if torch.isnan(stacked_expert_outputs).any(): 172 | stacked_expert_outputs[torch.isnan(stacked_expert_outputs)] = 0 173 | 174 | moe_output = torch.sum( 175 | gate_scores.unsqueeze(-2) * stacked_expert_outputs, dim=-1 176 | ) 177 | 178 | return moe_output -------------------------------------------------------------------------------- /codes/model/lightning_mmoe.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import pytorch_lightning as pl 5 | from tqdm import tqdm 6 | from codes.model.modeling_mmoe import MMoEEncoder, MMoEMatcher 7 | 8 | 9 | class LightningForMMoE(pl.LightningModule): 10 | def __init__(self, args): 11 | super(LightningForMMoE, self).__init__() 12 | self.args = args 13 | self.save_hyperparameters(args) 14 | 15 | self.encoder = MMoEEncoder(args) 16 | self.matcher = MMoEMatcher(args) 17 | self.loss_fct = torch.nn.CrossEntropyLoss() 18 | 19 | def training_step(self, batch): 20 | ent_batch = {} 21 | mention_batch = {} 22 | for k, v in batch.items(): 23 | if 
k.startswith('ent_'): 24 | ent_batch[k.replace('ent_', '')] = v 25 | else: 26 | mention_batch[k] = v 27 | entity_empty_image_flag = ent_batch.pop('empty_img_flag') # not use 28 | 29 | mention_text_embeds, mention_image_embeds, mention_text_seq_tokens, mention_image_patch_tokens = \ 30 | self.encoder(**mention_batch) 31 | entity_text_embeds, entity_image_embeds, entity_text_seq_tokens, entity_image_patch_tokens = \ 32 | self.encoder(**ent_batch) 33 | logits, (text_logits, image_logits, image_text_logits) = self.matcher(entity_text_embeds, 34 | entity_text_seq_tokens, 35 | mention_text_embeds, 36 | mention_text_seq_tokens, 37 | entity_image_embeds, 38 | entity_image_patch_tokens, 39 | mention_image_embeds, 40 | mention_image_patch_tokens) 41 | labels = torch.arange(len(mention_text_embeds)).long().to(mention_text_embeds.device) 42 | 43 | text_loss = self.loss_fct(text_logits, labels) 44 | image_loss = self.loss_fct(image_logits, labels) 45 | image_text_loss = self.loss_fct(image_text_logits, labels) 46 | overall_loss = self.loss_fct(logits, labels) 47 | 48 | loss = overall_loss + text_loss + image_loss + image_text_loss 49 | 50 | self.log('Train/loss', loss.detach().cpu().item(), on_epoch=True, prog_bar=True) 51 | return loss 52 | 53 | def validation_step(self, batch, batch_idx): 54 | answer = batch.pop('answer') 55 | batch_size = len(answer) 56 | mention_text_embeds, mention_image_embeds, mention_text_seq_tokens, mention_image_patch_tokens = \ 57 | self.encoder(**batch) 58 | scores = [] 59 | chunk_size = self.args.data.eval_chunk_size 60 | for idx in range(math.ceil(self.args.data.num_entity / chunk_size)): 61 | start_pos = idx * chunk_size 62 | end_pos = (idx + 1) * chunk_size 63 | 64 | chunk_entity_text_embeds = self.entity_text_embeds[start_pos:end_pos].to(mention_text_embeds.device) 65 | chunk_entity_image_embeds = self.entity_image_embeds[start_pos:end_pos].to(mention_text_embeds.device) 66 | chunk_entity_text_seq_tokens = self.entity_text_seq_tokens[start_pos:end_pos].to(mention_text_embeds.device) 67 | chunk_entity_image_patch_tokens = self.entity_image_patch_tokens[start_pos:end_pos].to( 68 | mention_text_embeds.device) 69 | 70 | chunk_score, _ = self.matcher(chunk_entity_text_embeds, chunk_entity_text_seq_tokens, 71 | mention_text_embeds, mention_text_seq_tokens, 72 | chunk_entity_image_embeds, chunk_entity_image_patch_tokens, 73 | mention_image_embeds, mention_image_patch_tokens) 74 | scores.append(chunk_score) 75 | 76 | scores = torch.concat(scores, dim=-1) 77 | rank = torch.argsort(torch.argsort(scores, dim=-1, descending=True), dim=-1, descending=False) + 1 78 | tgt_rank = rank[torch.arange(batch_size), answer].detach().cpu() 79 | return dict(rank=tgt_rank, all_rank=rank.detach().cpu().numpy()) 80 | 81 | def on_validation_start(self): 82 | entity_dataloader = self.trainer.datamodule.entity_dataloader() 83 | outputs_text_embed = [] 84 | outputs_image_embed = [] 85 | outputs_text_seq_tokens = [] 86 | outputs_image_patch_tokens = [] 87 | 88 | with torch.no_grad(): 89 | for batch in tqdm(entity_dataloader, desc='UpdateEmbed', total=len(entity_dataloader)): 90 | batch = pl.utilities.move_data_to_device(batch, self.device) 91 | entity_text_embeds, entity_image_embeds, entity_text_seq_tokens, entity_image_patch_tokens = \ 92 | self.encoder(**batch) 93 | outputs_text_embed.append(entity_text_embeds.cpu()) 94 | outputs_image_embed.append(entity_image_embeds.cpu()) 95 | outputs_text_seq_tokens.append(entity_text_seq_tokens.cpu()) 96 | 
outputs_image_patch_tokens.append(entity_image_patch_tokens.cpu()) 97 | 98 | self.entity_text_embeds = torch.concat(outputs_text_embed, dim=0) 99 | self.entity_image_embeds = torch.concat(outputs_image_embed, dim=0) 100 | self.entity_text_seq_tokens = torch.concat(outputs_text_seq_tokens, dim=0) 101 | self.entity_image_patch_tokens = torch.concat(outputs_image_patch_tokens, dim=0) 102 | 103 | def validation_epoch_end(self, outputs): 104 | self.entity_text_embeds = None 105 | self.entity_image_embeds = None 106 | self.entity_text_seq_tokens = None 107 | self.entity_image_patch_tokens = None 108 | 109 | ranks = np.concatenate([_['rank'] for _ in outputs]) 110 | hits20 = (ranks <= 20).mean() 111 | hits10 = (ranks <= 10).mean() 112 | hits5 = (ranks <= 5).mean() 113 | hits3 = (ranks <= 3).mean() 114 | hits1 = (ranks <= 1).mean() 115 | 116 | self.log("Val/hits20", hits20) 117 | self.log("Val/hits10", hits10) 118 | self.log("Val/hits5", hits5) 119 | self.log("Val/hits3", hits3) 120 | self.log("Val/hits1", hits1) 121 | self.log("Val/mr", ranks.mean()) 122 | self.log("Val/mrr", (1. / ranks).mean()) 123 | 124 | def test_step(self, batch, batch_idx, dataloader_idx=None): 125 | answer = batch.pop('answer') 126 | batch_size = len(answer) 127 | mention_text_embeds, mention_image_embeds, mention_text_seq_tokens, mention_image_patch_tokens = \ 128 | self.encoder(**batch) 129 | 130 | scores = [] 131 | chunk_size = self.args.data.eval_chunk_size 132 | for idx in range(math.ceil(self.args.data.num_entity / chunk_size)): 133 | start_pos = idx * chunk_size 134 | end_pos = (idx + 1) * chunk_size 135 | 136 | chunk_entity_text_embeds = self.entity_text_embeds[start_pos:end_pos].to(mention_text_embeds.device) 137 | chunk_entity_image_embeds = self.entity_image_embeds[start_pos:end_pos].to(mention_text_embeds.device) 138 | chunk_entity_text_seq_tokens = self.entity_text_seq_tokens[start_pos:end_pos].to(mention_text_embeds.device) 139 | chunk_entity_image_patch_tokens = self.entity_image_patch_tokens[start_pos:end_pos].to( 140 | mention_text_embeds.device) 141 | 142 | chunk_score, _ = self.matcher(chunk_entity_text_embeds, chunk_entity_text_seq_tokens, 143 | mention_text_embeds, mention_text_seq_tokens, 144 | chunk_entity_image_embeds, chunk_entity_image_patch_tokens, 145 | mention_image_embeds, mention_image_patch_tokens) 146 | scores.append(chunk_score) 147 | 148 | scores = torch.concat(scores, dim=-1) 149 | rank = torch.argsort(torch.argsort(scores, dim=-1, descending=True), dim=-1, descending=False) + 1 150 | tgt_rank = rank[torch.arange(batch_size), answer].detach().cpu() 151 | return dict(rank=tgt_rank, all_rank=rank.detach().cpu().numpy(), scores=scores.detach().cpu().numpy()) 152 | 153 | def on_test_start(self): 154 | entity_dataloader = self.trainer.datamodule.entity_dataloader() 155 | outputs_text_embed = [] 156 | outputs_image_embed = [] 157 | outputs_text_seq_tokens = [] 158 | outputs_image_patch_tokens = [] 159 | 160 | with torch.no_grad(): 161 | for batch in tqdm(entity_dataloader, desc='UpdateEmbed', total=len(entity_dataloader)): 162 | batch = pl.utilities.move_data_to_device(batch, self.device) 163 | entity_text_embeds, entity_image_embeds, entity_text_seq_tokens, entity_image_patch_tokens = \ 164 | self.encoder(**batch) 165 | outputs_text_embed.append(entity_text_embeds.cpu()) 166 | outputs_image_embed.append(entity_image_embeds.cpu()) 167 | outputs_text_seq_tokens.append(entity_text_seq_tokens.cpu()) 168 | outputs_image_patch_tokens.append(entity_image_patch_tokens.cpu()) 169 | 170 | 
self.entity_text_embeds = torch.concat(outputs_text_embed, dim=0) 171 | self.entity_image_embeds = torch.concat(outputs_image_embed, dim=0) 172 | self.entity_text_seq_tokens = torch.concat(outputs_text_seq_tokens, dim=0) 173 | self.entity_image_patch_tokens = torch.concat(outputs_image_patch_tokens, dim=0) 174 | 175 | def test_epoch_end(self, outputs): 176 | self.entity_text_embeds = None 177 | self.entity_image_embeds = None 178 | self.entity_text_seq_tokens = None 179 | self.entity_image_patch_tokens = None 180 | 181 | ranks = np.concatenate([_['rank'] for _ in outputs]) 182 | hits20 = (ranks <= 20).mean() 183 | hits10 = (ranks <= 10).mean() 184 | hits5 = (ranks <= 5).mean() 185 | hits3 = (ranks <= 3).mean() 186 | hits1 = (ranks <= 1).mean() 187 | 188 | self.log("Test/hits20", hits20) 189 | self.log("Test/hits10", hits10) 190 | self.log("Test/hits5", hits5) 191 | self.log("Test/hits3", hits3) 192 | self.log("Test/hits1", hits1) 193 | self.log("Test/mr", ranks.mean()) 194 | self.log("Test/mrr", (1. / ranks).mean()) 195 | 196 | def configure_optimizers(self): 197 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 198 | optimizer_grouped_params = [ 199 | {'params': [p for n, p in self.named_parameters() if not any(nd in n for nd in no_decay)], 200 | 'weight_decay': 0.0001}, 201 | {'params': [p for n, p in self.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 202 | ] 203 | optimizer = torch.optim.AdamW(optimizer_grouped_params, lr=self.args.lr, betas=(0.9, 0.999), eps=1e-4) 204 | return [optimizer] -------------------------------------------------------------------------------- /codes/utils/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import json 4 | import os.path 5 | import random 6 | import pickle 7 | 8 | import torch 9 | import pytorch_lightning as pl 10 | from PIL import Image 11 | from tqdm import tqdm 12 | from torch.utils.data import DataLoader 13 | from transformers import CLIPProcessor 14 | from urllib.parse import unquote 15 | 16 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 17 | 18 | 19 | def _load_json_file(filepath): 20 | data = [] 21 | if isinstance(filepath, str): 22 | with open(filepath, 'r', encoding='utf-8') as f: 23 | d = json.load(f) 24 | data.extend(d) 25 | elif isinstance(filepath, list): 26 | for path in filepath: 27 | with open(path, 'r', encoding='utf-8') as f: 28 | d = json.load(f) 29 | data.extend(d) 30 | return data 31 | 32 | 33 | class DataModuleForMMoE(pl.LightningDataModule): 34 | def __init__(self, args): 35 | super(DataModuleForMMoE, self).__init__() 36 | self.args = args 37 | current_directory = os.path.dirname(os.path.abspath(__file__)) 38 | base_path = current_directory[0:current_directory.rfind('/')] 39 | self.base_path = base_path[0:base_path.rfind('/')] 40 | self.tokenizer = CLIPProcessor.from_pretrained(self.base_path + self.args.pretrained_model).tokenizer 41 | self.image_processor = CLIPProcessor.from_pretrained(self.base_path + self.args.pretrained_model).feature_extractor 42 | with open(self.base_path + self.args.data.qid2id, 'r', encoding='utf-8') as f: 43 | self.qid2id = json.loads(f.readline()) 44 | self.raw_kb_entity = sorted(_load_json_file(self.base_path + self.args.data.entity), key=lambda x: x['id']) 45 | self.kb_entity = self.setup_dataset_for_entity(self.base_path + self.args.data.entity, self.raw_kb_entity) 46 | self.kb_id2entity = {raw_ent['id']: ent for raw_ent, ent in zip(self.raw_kb_entity, self.kb_entity)} 47 | 
48 | self.train_data = self.setup_dataset_for_mention(self.base_path + self.args.data.train_file, _load_json_file(self.base_path + self.args.data.train_file)) 49 | self.val_data = self.setup_dataset_for_mention(self.base_path + self.args.data.dev_file, _load_json_file(self.base_path + self.args.data.dev_file)) 50 | self.test_data = self.setup_dataset_for_mention(self.base_path + self.args.data.test_file, _load_json_file(self.base_path + self.args.data.test_file)) 51 | 52 | def setup_dataset_for_entity(self, path, data): 53 | # prepare entity information 54 | pkl_path = path[0:path.rfind('.')] + '.pkl' 55 | if os.path.exists(pkl_path): 56 | with open(pkl_path, 'rb') as file: 57 | input_data = pickle.load(file) 58 | return input_data 59 | 60 | input_data = [] 61 | for sample_dict in tqdm(data, desc='PreProcessing'): 62 | sample_type = sample_dict['type'] 63 | if sample_type == 'entity': 64 | entity, desc = unquote(sample_dict.pop('entity_name')), sample_dict.pop('desc') 65 | input_text = entity + ' [SEP] ' + desc # concat entity and sentence 66 | input_dict = self.tokenizer(input_text, padding='max_length', max_length=self.args.data.text_max_length, truncation=True) 67 | input_dict['img_list'] = sample_dict['image_list'] 68 | input_dict['sample_type'] = 0 if sample_type == 'entity' else 1 69 | if 'answer' in sample_dict.keys(): 70 | input_dict['answer'] = self.qid2id[sample_dict['answer']] 71 | input_data.append(input_dict) 72 | 73 | with open(pkl_path, 'wb') as file: 74 | pickle.dump(input_data, file) 75 | 76 | return input_data 77 | 78 | def setup_dataset_for_mention(self, path, data): 79 | # prepare mention information 80 | pkl_path = path[0:path.rfind('.')] + '.pkl' 81 | if os.path.exists(pkl_path): 82 | with open(pkl_path, 'rb') as file: 83 | input_data = pickle.load(file) 84 | return input_data 85 | 86 | input_data = [] 87 | for sample_dict in tqdm(data, desc='PreProcessing'): 88 | sample_type = 1 89 | entity, mention, text, desc = unquote(sample_dict.pop('entities')), unquote(sample_dict.pop('mentions')), sample_dict.pop('sentence'), sample_dict.pop('desc') 90 | input_text = mention + ' [SEP] ' + text + ' [SEP] ' + desc # concat entity and text 91 | 92 | input_dict = self.tokenizer(input_text, padding='max_length', max_length=self.args.data.text_max_length, truncation=True) 93 | 94 | input_dict['img_list'] = [sample_dict['imgPath']] if sample_dict['imgPath'] != '' else [] 95 | input_dict['sample_type'] = sample_type 96 | if 'answer' in sample_dict.keys(): 97 | input_dict['answer'] = self.qid2id[sample_dict['answer']] 98 | if sample_dict['answer'] == 'nil': # ignore the sample without ground truth 99 | continue 100 | input_data.append(input_dict) 101 | 102 | with open(pkl_path, 'wb') as file: 103 | pickle.dump(input_data, file) 104 | 105 | return input_data 106 | 107 | def choose_image(self, sample_type, img_list, is_eval=False): 108 | if len(img_list): 109 | img_name = random.choice(img_list) 110 | # when evaluation, we choose the first image 111 | if is_eval: 112 | img_name = img_list[0] 113 | if sample_type == 1: 114 | img_name = img_name.split('/')[-1].split('.')[0] + '.jpg' # we already convert all image to jpg format 115 | try: 116 | img_path = os.path.join( 117 | self.base_path + self.args.data.kb_img_folder if sample_type == 0 else self.base_path + self.args.data.mention_img_folder, 118 | img_name) 119 | img = Image.open(img_path).resize((224, 224), Image.Resampling.LANCZOS) 120 | pixel_values = self.image_processor(img, return_tensors='pt')['pixel_values'].squeeze() 121 | 
except: 122 | pixel_values = torch.zeros((3, 224, 224)) 123 | else: 124 | pixel_values = torch.zeros((3, 224, 224)) 125 | return pixel_values 126 | 127 | def train_collator(self, samples): 128 | cls_idx, img_list, sample_type, input_dict_list = [], [], [], [] 129 | pixel_values, gt_ent_id = [], [] 130 | 131 | # collect the metadata that need to further process 132 | for sample_idx, sample in enumerate(samples): 133 | img_list.append(sample.pop('img_list')) # mention image list 134 | sample_type.append(sample.pop('sample_type')) # input type: 0 for mention and 1 for entity 135 | input_dict_list.append(sample) # mention input dict (input_tokens, token_type_ids, attention_mask) 136 | gt_ent_id.append(sample.pop('answer')) # ground truth entity id of mentions 137 | ### 138 | # Now we process mention information 139 | # choose an image 140 | for idx, _ in enumerate(input_dict_list): 141 | pixel_values.append(self.choose_image(sample_type[idx], img_list[idx])) 142 | # pad textual input 143 | input_dict = self.tokenizer.pad(input_dict_list, 144 | padding='max_length', 145 | max_length=self.args.data.text_max_length, 146 | return_tensors='pt') 147 | # concat all images 148 | pixel_values = torch.stack(pixel_values) 149 | input_dict['pixel_values'] = pixel_values 150 | 151 | ### 152 | # now we process entity information 153 | # fetch the entities' metadata 154 | ent_info_list = [copy.deepcopy(self.kb_id2entity[idx]) for idx in gt_ent_id] 155 | ent_img_list, ent_type, ent_input_dict_list, ent_pixel_values = [], [], [], [] 156 | for ent_dict in ent_info_list: 157 | ent_img_list.append(ent_dict.pop('img_list')) # entity image list 158 | ent_type.append(ent_dict.pop('sample_type')) # input type: 0 for mention and 1 for entity 159 | ent_input_dict_list.append(ent_dict) # entity input dict (input_tokens, token_type_ids, attention_mask) 160 | # choose an image 161 | for idx, _ in enumerate(ent_input_dict_list): 162 | ent_pixel_values.append(self.choose_image(ent_type[idx], ent_img_list[idx])) 163 | # some of the entities do not have image, so we use bool flags to tag them 164 | ent_empty_img_flag = torch.tensor([True if not len(_) else False for _ in ent_img_list], dtype=torch.bool) 165 | # pad textual input 166 | ent_input_dict = self.tokenizer.pad(ent_input_dict_list, 167 | padding='max_length', 168 | max_length=self.args.data.text_max_length, 169 | return_tensors='pt') 170 | # concat all image 171 | ent_pixel_values = torch.stack(ent_pixel_values) 172 | ent_input_dict['pixel_values'] = ent_pixel_values 173 | ent_input_dict['empty_img_flag'] = ent_empty_img_flag 174 | 175 | # for the entity information, we use prefix 'ent_' to tag them 176 | for k, v in ent_input_dict.items(): 177 | input_dict[f'ent_{k}'] = v 178 | return input_dict 179 | 180 | def eval_collator(self, samples): 181 | # eval collator is similar to train collator, but only include mention information 182 | cls_idx, img_list, sample_type, input_dict_list = [], [], [], [] 183 | pixel_values, gt_ent_id = [], [] 184 | 185 | for sample_idx, sample in enumerate(samples): 186 | img_list.append(sample.pop('img_list')) 187 | sample_type.append(sample.pop('sample_type')) 188 | gt_ent_id.append(sample.pop('answer')) 189 | input_dict_list.append(sample) 190 | 191 | for idx, _ in enumerate(input_dict_list): 192 | pixel_values.append(self.choose_image(sample_type[idx], img_list[idx], is_eval=True)) 193 | 194 | input_dict = self.tokenizer.pad(input_dict_list, 195 | padding='max_length', 196 | max_length=self.args.data.text_max_length, 197 | 
return_tensors='pt') 198 | input_dict['pixel_values'] = torch.stack(pixel_values) 199 | input_dict['answer'] = torch.tensor(gt_ent_id, dtype=torch.long) 200 | return input_dict 201 | 202 | def entity_collator(self, samples): 203 | # entity collator is similar to train collator, but only include entity information 204 | pixel_values, img_list, sample_type, input_dict_list = [], [], [], [] 205 | for sample_idx, sample in enumerate(samples): 206 | img_list.append(sample.pop('img_list')) 207 | sample_type.append(sample.pop('sample_type')) 208 | input_dict_list.append(sample) 209 | for idx, input_dict in enumerate(input_dict_list): 210 | pixel_values.append(self.choose_image(sample_type[idx], img_list[idx], is_eval=True)) 211 | 212 | input_dict = self.tokenizer.pad(input_dict_list, 213 | padding='max_length', 214 | max_length=self.args.data.text_max_length, 215 | return_tensors='pt') 216 | input_dict['pixel_values'] = torch.stack(pixel_values) 217 | 218 | return input_dict 219 | 220 | def entity_dataloader(self): 221 | return DataLoader(self.kb_entity, 222 | batch_size=self.args.data.embed_update_batch_size, 223 | num_workers=self.args.data.num_workers, 224 | shuffle=False, 225 | collate_fn=self.entity_collator) 226 | 227 | def train_dataloader(self): 228 | return DataLoader(self.train_data, 229 | batch_size=self.args.data.batch_size, 230 | num_workers=self.args.data.num_workers, 231 | shuffle=True, 232 | collate_fn=self.train_collator) 233 | 234 | def val_dataloader(self): 235 | return DataLoader(self.val_data, 236 | batch_size=self.args.data.eval_batch_size, 237 | num_workers=self.args.data.num_workers, 238 | shuffle=False, 239 | collate_fn=self.eval_collator) 240 | 241 | def test_dataloader(self): 242 | return DataLoader(self.test_data, 243 | batch_size=self.args.data.eval_batch_size, 244 | num_workers=self.args.data.num_workers, 245 | shuffle=False, 246 | collate_fn=self.eval_collator) -------------------------------------------------------------------------------- /codes/model/modeling_mmoe.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from transformers import CLIPModel 8 | from codes.model.moe import SwitchMoE 9 | 10 | 11 | class MMoEEncoder(nn.Module): 12 | def __init__(self, args): 13 | super(MMoEEncoder, self).__init__() 14 | self.args = args 15 | current_directory = os.path.dirname(os.path.abspath(__file__)) 16 | base_path = current_directory[0:current_directory.rfind('/')] 17 | self.base_path = base_path[0:base_path.rfind('/')] 18 | self.clip = CLIPModel.from_pretrained(self.base_path + self.args.pretrained_model) 19 | 20 | self.image_cls_fc = nn.Linear(self.args.model.input_hidden_dim, self.args.model.dv) 21 | self.image_tokens_fc = nn.Linear(self.args.model.input_image_hidden_dim, self.args.model.dv) 22 | 23 | def forward(self, 24 | input_ids=None, 25 | attention_mask=None, 26 | token_type_ids=None, 27 | pixel_values=None): 28 | clip_output = self.clip(input_ids=input_ids, 29 | attention_mask=attention_mask, 30 | pixel_values=pixel_values) 31 | 32 | text_embeds = clip_output.text_embeds 33 | image_embeds = clip_output.image_embeds 34 | 35 | text_seq_tokens = clip_output.text_model_output[0] 36 | image_patch_tokens = clip_output.vision_model_output[0] 37 | 38 | image_embeds = self.image_cls_fc(image_embeds) 39 | image_patch_tokens = self.image_tokens_fc(image_patch_tokens) 40 | return text_embeds, image_embeds, text_seq_tokens, 
image_patch_tokens 41 | 42 | class TextUnit(nn.Module): 43 | def __init__(self, args): 44 | super(TextUnit, self).__init__() 45 | self.args = args 46 | self.fc_query = nn.Linear(self.args.model.TGLU_hidden_dim, self.args.model.TGLU_hidden_dim) 47 | self.fc_key = nn.Linear(self.args.model.TGLU_hidden_dim, self.args.model.TGLU_hidden_dim) 48 | self.fc_value = nn.Linear(self.args.model.TGLU_hidden_dim, self.args.model.TGLU_hidden_dim) 49 | self.layer_norm = nn.LayerNorm(self.args.model.TGLU_hidden_dim) 50 | 51 | self.moe_layer = SwitchMoE( 52 | dim=self.args.model.input_hidden_dim, 53 | output_dim=self.args.model.TGLU_hidden_dim, 54 | num_experts=self.args.model.num_experts, 55 | top_k=self.args.model.top_experts 56 | ) 57 | 58 | def forward(self, 59 | entity_text_cls, 60 | entity_text_tokens, 61 | mention_text_cls, 62 | mention_text_tokens): 63 | """ 64 | :param entity_text_cls: [num_entity, dim] 65 | :param entity_text_tokens: [num_entity, max_seq_len, dim] 66 | :param mention_text_cls: [batch_size, dim] 67 | :param mention_text_tokens: [batch_size, max_sqe_len, dim] 68 | :return: 69 | """ 70 | entity_text_cls_tokens_features = torch.cat([entity_text_cls.unsqueeze(dim=1), entity_text_tokens], dim=1) 71 | entity_text_moe = self.moe_layer(entity_text_cls_tokens_features, entity_text_cls_tokens_features) 72 | entity_text_cls_moe = entity_text_moe[:, 0, :] 73 | entity_cls_fc = entity_text_cls_moe.unsqueeze(dim=1) 74 | entity_text_tokens = entity_text_moe[:, 1:, :] 75 | 76 | mention_text_cls_tokens_features = torch.cat([mention_text_cls.unsqueeze(dim=1), mention_text_tokens], dim=1) 77 | mention_text_moe = self.moe_layer(mention_text_cls_tokens_features, mention_text_cls_tokens_features) 78 | mention_text_tokens = mention_text_moe[:, 1:, :] 79 | 80 | query = self.fc_query(entity_text_tokens) # [num_entity, max_seq_len, dim] 81 | key = self.fc_key(mention_text_tokens) # [batch_size, max_sqe_len, dim] 82 | value = self.fc_value(mention_text_tokens) # [batch_size, max_sqe_len, dim] 83 | 84 | query = query.unsqueeze(dim=1) # [num_entity, 1, max_seq_len, dim] 85 | key = key.unsqueeze(dim=0) # [1, batch_size, max_sqe_len, dim] 86 | value = value.unsqueeze(dim=0) # [1, batch_size, max_sqe_len, dim] 87 | 88 | attention_scores = torch.matmul(query, key.transpose(-1, -2)) # [num_entity, batach_size, max_seq_len, max_seq_len] 89 | 90 | attention_scores = attention_scores / math.sqrt(self.args.model.TGLU_hidden_dim) 91 | attention_probs = nn.Softmax(dim=-1)(attention_scores) # [num_entity, batch_size, max_seq_len, max_seq_len] 92 | 93 | context = torch.matmul(attention_probs, value) # [num_entity, batch_size, max_seq_len, dim] 94 | context = torch.mean(context, dim=-2) # [num_entity, batch_size, dim] 95 | context = self.layer_norm(context) 96 | 97 | g2l_matching_score = torch.sum(entity_cls_fc * context, dim=-1) # [num_entity, batch_size] 98 | g2l_matching_score = g2l_matching_score.transpose(0, 1) # [batch_size, num_entity] 99 | g2g_matching_score = torch.matmul(mention_text_cls, entity_text_cls.transpose(-1, -2)) 100 | 101 | matching_score = (g2l_matching_score + g2g_matching_score) / 2 102 | return matching_score 103 | 104 | 105 | class VisionUnit(nn.Module): 106 | def __init__(self, args): 107 | super(VisionUnit, self).__init__() 108 | self.args = args 109 | self.fc_query = nn.Linear(self.args.model.dv, self.args.model.IDLU_hidden_dim) 110 | self.fc_key = nn.Linear(self.args.model.dv, self.args.model.IDLU_hidden_dim) 111 | self.fc_value = nn.Linear(self.args.model.dv, self.args.model.IDLU_hidden_dim) 
112 | self.layer_norm = nn.LayerNorm(self.args.model.IDLU_hidden_dim) 113 | 114 | self.moe_layer = SwitchMoE( 115 | dim=self.args.model.dv, 116 | output_dim=self.args.model.IDLU_hidden_dim, 117 | num_experts=self.args.model.num_experts, 118 | top_k=self.args.model.top_experts 119 | ) 120 | 121 | def forward(self, 122 | entity_image_cls, 123 | entity_image_tokens, 124 | mention_image_cls, 125 | mention_image_tokens): 126 | """ 127 | :param entity_image_cls: [num_entity, dim] 128 | :param entity_image_tokens: [num_entity, num_patch, dim] 129 | :param mention_image_cls: [batch_size, dim] 130 | :param mention_image_tokens: [batch_size, num_patch, dim] 131 | :return: 132 | """ 133 | entity_image_cls_tokens_features = torch.cat([entity_image_cls.unsqueeze(dim=1), entity_image_tokens], dim=1) 134 | entity_image_moe = self.moe_layer(entity_image_cls_tokens_features, entity_image_cls_tokens_features) 135 | entity_image_cls_moe = entity_image_moe[:, 0, :] 136 | entity_cls_fc = entity_image_cls_moe.unsqueeze(dim=1) 137 | entity_image_tokens = entity_image_moe[:, 1:, :] 138 | 139 | mention_image_cls_tokens_features = torch.cat([mention_image_cls.unsqueeze(dim=1), mention_image_tokens], dim=1) 140 | mention_image_moe = self.moe_layer(mention_image_cls_tokens_features, mention_image_cls_tokens_features) 141 | mention_image_tokens = mention_image_moe[:, 1:, :] 142 | 143 | query = self.fc_query(entity_image_tokens) # [num_entity, num_patch, dim] 144 | key = self.fc_key(mention_image_tokens) # [batch_size, num_patch, dim] 145 | value = self.fc_value(mention_image_tokens) # [batch_size, num_patch, dim] 146 | 147 | query = query.unsqueeze(dim=1) # [num_entity, 1, num_patch, dim] 148 | key = key.unsqueeze(dim=0) # [1, batch_size, num_patch, dim] 149 | value = value.unsqueeze(dim=0) # [1, batch_size, num_patch, dim] 150 | 151 | attention_scores = torch.matmul(query, key.transpose(-1, -2)) # [num_entity, batch_size, num_patch, num_patch] 152 | 153 | attention_scores = attention_scores / math.sqrt(self.args.model.IDLU_hidden_dim) 154 | attention_probs = nn.Softmax(dim=-1)(attention_scores) # [num_entity, batch_size, num_patch, num_patch] 155 | 156 | context = torch.matmul(attention_probs, value) # [num_entity, batch_size, num_patch, dim] 157 | context = torch.mean(context, dim=-2) # [num_entity, batch_size, dim] 158 | context = self.layer_norm(context) 159 | 160 | g2l_matching_score = torch.sum(entity_cls_fc * context, dim=-1) # [num_entity, batch_size] 161 | g2l_matching_score = g2l_matching_score.transpose(0, 1) # [batch_size, num_entity] 162 | g2g_matching_score = torch.matmul(mention_image_cls, entity_image_cls.transpose(-1, -2)) 163 | 164 | matching_score = (g2l_matching_score + g2g_matching_score) / 2 165 | return matching_score 166 | 167 | class CrossUnit(nn.Module): 168 | def __init__(self, args): 169 | super(CrossUnit, self).__init__() 170 | self.args = args 171 | self.text_fc = nn.Linear(self.args.model.input_hidden_dim, self.args.model.CMFU_hidden_dim) 172 | self.image_fc = nn.Linear(self.args.model.dv, self.args.model.CMFU_hidden_dim) 173 | self.gate_fc = nn.Linear(self.args.model.CMFU_hidden_dim, 1) 174 | self.gate_act = nn.Tanh() 175 | self.gate_layer_norm = nn.LayerNorm(self.args.model.CMFU_hidden_dim) 176 | self.context_layer_norm = nn.LayerNorm(self.args.model.CMFU_hidden_dim) 177 | 178 | self.moe_layer = SwitchMoE( 179 | dim=self.args.model.CMFU_hidden_dim, 180 | output_dim=self.args.model.CMFU_hidden_dim, 181 | num_experts=self.args.model.num_experts, 182 | top_k=self.args.model.top_experts 
183 | ) 184 | 185 | def forward(self, 186 | entity_text_cls, 187 | entity_text_tokens, 188 | mention_text_cls, 189 | mention_text_tokens, 190 | entity_image_cls, 191 | entity_image_tokens, 192 | mention_image_cls, 193 | mention_image_tokens): 194 | """ 195 | :param entity_text_cls: [num_entity, dim] 196 | :param entity_image_tokens: [num_entity, num_patch, dim] 197 | :param mention_text_cls: [batch_size, dim] 198 | :param mention_image_tokens: [batch_size, num_patch, dim] 199 | :return: 200 | """ 201 | entity_text_cls = self.text_fc(entity_text_cls) # [num_entity, dim] 202 | entity_text_cls_ori = entity_text_cls 203 | entity_text_tokens = self.text_fc(entity_text_tokens) # [num_entity, dim] 204 | mention_text_cls = self.text_fc(mention_text_cls) # [num_entity, dim] 205 | mention_text_cls_ori = mention_text_cls 206 | mention_text_tokens = self.text_fc(mention_text_tokens) # [batch_size, dim] 207 | 208 | entity_image_cls = self.image_fc(entity_image_cls) # [num_entity, num_patch, dim] 209 | entity_image_cls_ori = entity_image_cls 210 | entity_image_tokens = self.image_fc(entity_image_tokens) # [num_entity, num_patch, dim] 211 | mention_image_cls = self.image_fc(mention_image_cls) # [batch_size, num_patch, dim] 212 | mention_image_cls_ori = mention_image_cls 213 | mention_image_tokens = self.image_fc(mention_image_tokens) # [batch_size, num_patch, dim] 214 | 215 | entity_text_cls_image_tokens_features = torch.cat([entity_text_cls.unsqueeze(dim=1), entity_image_tokens], dim=1) 216 | entity_image_cls_text_tokens_features = torch.cat([entity_image_cls.unsqueeze(dim=1), entity_text_tokens], dim=1) 217 | mention_text_cls_image_tokens_features = torch.cat([mention_text_cls.unsqueeze(dim=1), mention_image_tokens], dim=1) 218 | mention_image_cls_text_tokens_features = torch.cat([mention_image_cls.unsqueeze(dim=1), mention_text_tokens], dim=1) 219 | 220 | entity_text_cls_image_tokens_moe = self.moe_layer(entity_text_cls_image_tokens_features, entity_image_cls_text_tokens_features) 221 | entity_text_cls = entity_text_cls_image_tokens_moe[:, 0, :] 222 | entity_image_tokens = entity_text_cls_image_tokens_moe[:, 1:, :] 223 | mention_text_cls_image_tokens_moe = self.moe_layer(mention_text_cls_image_tokens_features, mention_image_cls_text_tokens_features) 224 | mention_text_cls = mention_text_cls_image_tokens_moe[:, 0, :] 225 | mention_image_tokens = mention_text_cls_image_tokens_moe[:, 1:, :] 226 | 227 | entity_image_cls_text_tokens_moe = self.moe_layer(entity_image_cls_text_tokens_features, entity_text_cls_image_tokens_features) 228 | entity_image_cls = entity_image_cls_text_tokens_moe[:, 0, :] 229 | entity_text_tokens = entity_image_cls_text_tokens_moe[:, 1:, :] 230 | mention_image_cls_text_tokens_moe = self.moe_layer(mention_image_cls_text_tokens_features, mention_text_cls_image_tokens_features) 231 | mention_image_cls = mention_image_cls_text_tokens_moe[:, 0, :] 232 | mention_text_tokens = mention_image_cls_text_tokens_moe[:, 1:, :] 233 | 234 | entity_text_cls = entity_text_cls.unsqueeze(dim=1) # [num_entity, 1, dim] 235 | entity_text_image_cross_modal_score = torch.matmul(entity_text_cls, entity_image_tokens.transpose(-1, -2)) 236 | entity_text_image_cross_modal_probs = nn.Softmax(dim=-1)(entity_text_image_cross_modal_score) # [num_entity, 1, num_patch] 237 | entity_text_image_context = torch.matmul(entity_text_image_cross_modal_probs, entity_image_tokens).squeeze() # [num_entity, 1, dim] 238 | entity_text_image_gate_score = self.gate_act(self.gate_fc(entity_text_image_context)) 239 | 
entity_text_image_context = self.gate_layer_norm((entity_text_cls_ori * entity_text_image_gate_score) + entity_text_image_context) 240 | 241 | mention_text_cls = mention_text_cls.unsqueeze(dim=1) # [batch_size, 1, dim] 242 | mention_text_image_cross_modal_score = torch.matmul(mention_text_cls, mention_image_tokens.transpose(-1, -2)) 243 | mention_text_image_cross_modal_probs = nn.Softmax(dim=-1)(mention_text_image_cross_modal_score) 244 | mention_text_image_context = torch.matmul(mention_text_image_cross_modal_probs, mention_image_tokens).squeeze() 245 | mention_text_image_gate_score = self.gate_act(self.gate_fc(mention_text_cls_ori)) 246 | mention_text_image_context = self.gate_layer_norm((mention_text_cls_ori * mention_text_image_gate_score) + mention_text_image_context) 247 | 248 | score_text_image = torch.matmul(mention_text_image_context, entity_text_image_context.transpose(-1, -2)) 249 | 250 | entity_image_cls = entity_image_cls.unsqueeze(dim=1) # [num_entity, 1, dim] 251 | entity_image_text_cross_modal_score = torch.matmul(entity_image_cls, entity_text_tokens.transpose(-1, -2)) 252 | entity_image_text_cross_modal_probs = nn.Softmax(dim=-1)(entity_image_text_cross_modal_score) # [num_entity, 1, num_patch] 253 | entity_image_text_context = torch.matmul(entity_image_text_cross_modal_probs, entity_text_tokens).squeeze() # [num_entity, 1, dim] 254 | entity_image_text_gate_score = self.gate_act(self.gate_fc(entity_image_text_context)) 255 | entity_image_text_context = self.gate_layer_norm((entity_image_cls_ori * entity_image_text_gate_score) + entity_image_text_context) 256 | 257 | mention_image_cls = mention_image_cls.unsqueeze(dim=1) # [batch_size, 1, dim] 258 | mention_image_text_cross_modal_score = torch.matmul(mention_image_cls, mention_text_tokens.transpose(-1, -2)) 259 | mention_image_text_cross_modal_probs = nn.Softmax(dim=-1)(mention_image_text_cross_modal_score) 260 | mention_image_text_context = torch.matmul(mention_image_text_cross_modal_probs, mention_text_tokens).squeeze() 261 | mention_image_text_gate_score = self.gate_act(self.gate_fc(mention_image_cls_ori)) 262 | mention_image_text_context = self.gate_layer_norm((mention_image_cls_ori * mention_image_text_gate_score) + mention_image_text_context) 263 | 264 | score_image_text = torch.matmul(mention_image_text_context, entity_image_text_context.transpose(-1, -2)) 265 | 266 | score = (score_text_image + score_image_text) / 2 267 | 268 | return score 269 | 270 | 271 | class MMoEMatcher(nn.Module): 272 | def __init__(self, args): 273 | super(MMoEMatcher, self).__init__() 274 | self.args = args 275 | self.text_module = TextUnit(self.args) 276 | self.vision_module = VisionUnit(self.args) 277 | self.cross_module = CrossUnit(self.args) 278 | 279 | self.text_cls_layernorm = nn.LayerNorm(self.args.model.dt) 280 | self.text_tokens_layernorm = nn.LayerNorm(self.args.model.dt) 281 | self.image_cls_layernorm = nn.LayerNorm(self.args.model.dv) 282 | self.image_tokens_layernorm = nn.LayerNorm(self.args.model.dv) 283 | 284 | def forward(self, 285 | entity_text_cls, entity_text_tokens, 286 | mention_text_cls, mention_text_tokens, 287 | entity_image_cls, entity_image_tokens, 288 | mention_image_cls, mention_image_tokens): 289 | """ 290 | 291 | :param entity_text_cls: [num_entity, dim] 292 | :param entity_text_tokens: [num_entity, max_seq_len, dim] 293 | :param mention_text_cls: [batch_size, dim] 294 | :param mention_text_tokens: [batch_size, max_sqe_len, dim] 295 | :param entity_image_cls: [num_entity, dim] 296 | :param mention_image_cls: 
[batch_size, dim] 297 | :param entity_image_tokens: [num_entity, num_patch, dim] 298 | :param mention_image_tokens: [batch_size, num_patch, dim] 299 | :return: 300 | """ 301 | entity_text_cls = self.text_cls_layernorm(entity_text_cls) 302 | mention_text_cls = self.text_cls_layernorm(mention_text_cls) 303 | 304 | entity_text_tokens = self.text_tokens_layernorm(entity_text_tokens) 305 | mention_text_tokens = self.text_tokens_layernorm(mention_text_tokens) 306 | 307 | entity_image_cls = self.image_cls_layernorm(entity_image_cls) 308 | mention_image_cls = self.image_cls_layernorm(mention_image_cls) 309 | 310 | entity_image_tokens = self.image_tokens_layernorm(entity_image_tokens) 311 | mention_image_tokens = self.image_tokens_layernorm(mention_image_tokens) 312 | 313 | text_matching_score = self.text_module(entity_text_cls, entity_text_tokens, 314 | mention_text_cls, mention_text_tokens) 315 | image_matching_score = self.vision_module(entity_image_cls, entity_image_tokens, 316 | mention_image_cls, mention_image_tokens) 317 | image_text_matching_score = self.cross_module(entity_text_cls, entity_text_tokens, 318 | mention_text_cls, mention_text_tokens, 319 | entity_image_cls, entity_image_tokens, 320 | mention_image_cls, mention_image_tokens) 321 | 322 | score = (text_matching_score + image_matching_score + image_text_matching_score) / 3 323 | return score, (text_matching_score, image_matching_score, image_text_matching_score) --------------------------------------------------------------------------------
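
For readers of the evaluation code above, the following standalone sketch (illustrative only, not a file in this repo; scores and gold ids are made up) shows how the double-argsort call in `validation_step`/`test_step` turns a mention-by-entity score matrix into 1-based ranks, and how the Hits@k and MRR values logged in `validation_epoch_end`/`test_epoch_end` follow from those ranks:

```python
import torch

# toy scores for 2 mentions over 4 candidate entities (made-up numbers)
scores = torch.tensor([[0.2, 0.9, 0.1, 0.4],
                       [0.7, 0.3, 0.8, 0.5]])
answer = torch.tensor([1, 0])  # gold entity index per mention (made up)

# double argsort: rank[i, j] is the 1-based rank of entity j for mention i
rank = torch.argsort(torch.argsort(scores, dim=-1, descending=True), dim=-1) + 1
tgt_rank = rank[torch.arange(scores.size(0)), answer].numpy()  # ranks of the gold entities

print(tgt_rank)                 # [1 2]
print((tgt_rank <= 1).mean())   # Hits@1 -> 0.5
print((1.0 / tgt_rank).mean())  # MRR    -> 0.75
```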
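
Similarly, a minimal shape check for the `SwitchMoE` layer shared by `TextUnit`, `VisionUnit`, and `CrossUnit` (again illustrative: the 512-to-96 projection and the 1 + 50 token length mirror the TGLU settings in the configs, but the tensors are random):

```python
import torch
from codes.model.moe import SwitchMoE

# experts project each token from the CLIP text width (512) to the unit width (96)
moe = SwitchMoE(dim=512, output_dim=96, num_experts=4, top_k=2)

x = torch.randn(8, 51, 512)  # [batch, 1 cls token + 50 sequence tokens, dim]
out = moe(x, x)              # second argument is the gating input; CrossUnit passes
                             # the other modality's features there instead of x itself
print(out.shape)             # torch.Size([8, 51, 96])
```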