├── .gitignore
├── Graphical Abstract.png
├── config
│   └── multi_instance.json
├── requirements.txt
├── code
│   ├── mediapipe_face_detection.py
│   ├── EEG_preprocessing.ipynb
│   └── model.py
└── README.md

/.gitignore:
--------------------------------------------------------------------------------
.DS_Store
*/.DS_Store
--------------------------------------------------------------------------------
/Graphical Abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liangyubuaa/Milmer/HEAD/Graphical Abstract.png
--------------------------------------------------------------------------------
/config/multi_instance.json:
--------------------------------------------------------------------------------
{
    "device": "cuda:1",
    "epochs": 100,
    "lr": 0.0001,
    "lrf": 0.1,
    "max_lr": 0.00001,
    "batch_size": 14,
    "seed": 0,
    "num_instances": 10,
    "instance_selection_method": "attention_weighted_topk",
    "fusion_type": "cross_attention",
    "num_select": 3,
    "pretrained_model": "swin-tiny-patch4-window7-224-finetuned-face-emotion-v12"
}
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
absl-py==2.1.0
astunparse==1.6.3
cachetools==5.5.0
certifi==2024.12.14
charset-normalizer==3.4.0
filelock==3.13.1
flatbuffers==24.12.23
fsspec==2024.2.0
gast==0.4.0
google-auth==2.37.0
google-auth-oauthlib==1.0.0
google-pasta==0.2.0
grpcio==1.68.1
h5py==3.11.0
huggingface-hub==0.27.0
idna==3.10
importlib_metadata==8.5.0
Jinja2==3.1.3
joblib==1.4.2
keras==2.13.1
libclang==18.1.1
Markdown==3.7
MarkupSafe==2.1.5
mpmath==1.3.0
networkx==3.0
numpy==1.24.1
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.19.3
nvidia-nvjitlink-cu12==12.1.105
nvidia-nvtx-cu12==12.1.105
oauthlib==3.2.2
opt_einsum==3.4.0
packaging==24.2
pandas==2.0.3
pillow==10.2.0
protobuf==4.25.5
pyasn1==0.6.1
pyasn1_modules==0.4.1
python-dateutil==2.9.0.post0
pytz==2024.2
PyYAML==6.0.2
regex==2024.11.6
requests==2.32.3
requests-oauthlib==2.0.0
rsa==4.9
safetensors==0.4.5
scikit-learn==1.3.2
scipy==1.10.1
six==1.17.0
sympy==1.13.1
tensorboard==2.13.0
tensorboard-data-server==0.7.2
tensorflow==2.13.1
tensorflow-estimator==2.13.0
tensorflow-io-gcs-filesystem==0.34.0
termcolor==2.4.0
threadpoolctl==3.5.0
tokenizers==0.19.1
torch==2.2.1+cu121
torchaudio==2.2.1+cu121
torchvision==0.17.1+cu121
tqdm==4.67.1
transformers==4.41.2
triton==2.2.0
typing_extensions==4.5.0
tzdata==2024.2
urllib3==2.2.3
Werkzeug==3.0.6
wrapt==1.17.0
zipp==3.20.2
--------------------------------------------------------------------------------
/code/mediapipe_face_detection.py:
--------------------------------------------------------------------------------
import cv2
import mediapipe as mp
from tqdm import tqdm
import time
import os
import matplotlib.pyplot as plt

mp_face_detection = mp.solutions.face_detection
model = mp_face_detection.FaceDetection(
    min_detection_confidence=0.7,
    model_selection=0,
)

def process_participant_data(participant_id):
    participant_str = f's{participant_id:02}'
    input_dir = f'./photo/{participant_str}/'
    output_dir = f'./faces/{participant_str}/'

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    image_files = os.listdir(input_dir)
    valid_image_extensions = [".jpg", ".jpeg", ".png", ".bmp"]

    for image_file in image_files:
        filename, ext = os.path.splitext(image_file)
        if ext.lower() not in valid_image_extensions:
            continue

        img = cv2.imread(f'{input_dir}{image_file}')
        if img is None:
            print(f"Failed to load image at '{input_dir}{image_file}'")
            continue
        else:
            print(f"'{input_dir}{image_file}'")

        img_RGB = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        results = model.process(img_RGB)

        annotated_image = img.copy()

        # Skip frames where MediaPipe finds no face; indexing detections[0]
        # would otherwise raise because results.detections is None.
        if not results.detections:
            print(f"No face detected in '{input_dir}{image_file}', skipping")
            continue
        detection = results.detections[0]

        bboxC = detection.location_data.relative_bounding_box
        ih, iw, _ = img.shape
        bbox = [int(bboxC.xmin * iw), int(bboxC.ymin * ih), int(bboxC.width * iw), int(bboxC.height * ih)]

        # Calculate the new coordinates of the face box, expanding it by 10%.
        expansion_ratio = 0.1
        bbox_expanded = [max(0, bbox[0] - bbox[2] * expansion_ratio),   # left
                         max(0, bbox[1] - bbox[3] * expansion_ratio),   # top
                         min(iw, bbox[0] + bbox[2] * (1 + expansion_ratio)),  # right
                         min(ih, bbox[1] + bbox[3] * (1 + expansion_ratio))]  # bottom

        cv2.rectangle(annotated_image, (int(bbox_expanded[0]), int(bbox_expanded[1])), (int(bbox_expanded[2]), int(bbox_expanded[3])), (255,0,0), 2)

        face_image = img[int(bbox_expanded[1]):int(bbox_expanded[3]), int(bbox_expanded[0]):int(bbox_expanded[2])]

        cv2.imwrite('{}{}{}'.format(output_dir, filename, ext), face_image)

for participant_id in range(1, 23):
    process_participant_data(participant_id)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Milmer

This is an implementation of the Milmer model, described in the following paper:

**Milmer: a Framework for Multiple Instance Learning based Multimodal Emotion Recognition**

![Preview](./Graphical%20Abstract.png)

## Abstract

Emotions play a crucial role in human behavior and decision-making, making emotion recognition a key area of interest in human-computer interaction (HCI). This study addresses the challenges of emotion recognition by integrating facial expression analysis with electroencephalogram (EEG) signals, introducing Milmer, a novel multimodal framework. The proposed framework employs a transformer-based fusion approach to effectively integrate visual and physiological modalities. It consists of an EEG preprocessing module, a facial feature extraction and balancing module, and a cross-modal fusion module. To enhance visual feature extraction, we fine-tune a pre-trained Swin Transformer on emotion-related datasets. Additionally, a cross-attention mechanism is introduced to balance token representation across modalities, ensuring effective feature integration.
A key innovation of this work is the adoption of a multiple instance learning (MIL) approach, which extracts meaningful information from multiple facial expression images over time, capturing critical temporal dynamics often overlooked in previous studies. Extensive experiments conducted on the DEAP dataset demonstrate the superiority of the proposed framework, achieving a classification accuracy of 96.72% in the four-class emotion recognition task. Ablation studies further validate the contributions of each module, highlighting the significance of advanced feature extraction and fusion strategies in enhancing emotion recognition performance. Our code is available at https://github.com/liangyubuaa/Milmer.

## Requirements

- Python 3.8
- For other dependencies, see [requirements.txt](./requirements.txt)

## Parameters

For detailed parameter configuration, please refer to the [config](./config) folder.

- **Pretrained model**: Use swin-tiny-patch4-window7-224-finetuned-face-emotion-v12 by default.
- **LR schedule**: Cosine decay reduces lr from its initial value to lrf×lr, with lrf=0.1 by default.
- **Batch size**: Train and test with batch size 14.
- **Epochs**: Train for 100 epochs; change this in config/multi_instance.json if needed.
- **Optimizer**: AdamW with lr 1e-4 and the other PyTorch defaults.
- **MIL**: Use 10 instances per sample and select the top 3 by attention-weighted top-k.
- **Fusion**: Fusion type is cross_attention; available options are none, cross_attention, and mlp.
- **Transformer encoder**: d_model 768, nhead 12, dim_feedforward 2048, num_layers 2, dropout 0.2, CLS dropout 0.1.
- **Seed**: Random seed 0.

## Reference
```
@article{wang2025milmer,
  title={Milmer: a Framework for Multiple Instance Learning based Multimodal Emotion Recognition},
  author={Wang, Zaitian and He, Jian and Liang, Yu and Hu, Xiyuan and Peng, Tianhao and Wang, Kaixin and Wang, Jiakai and Zhang, Chenlong and Zhang, Weili and Niu, Shuang and others},
  journal={arXiv preprint arXiv:2502.00547},
  year={2025}
}
```
--------------------------------------------------------------------------------
/code/EEG_preprocessing.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy.io\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mat_data = scipy.io.loadmat('./data_preprocessed_matlab/s01.mat') \n",
    "original_data = mat_data['data']\n",
    "original_label = mat_data['labels']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sliced_data = original_data[:, :32, 384:] \n",
    "print(\"sliced_data:\", sliced_data.shape)\n",
    "eeg_data = sliced_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "valence = original_label[:,0]\n",
    "arousal = original_label[:,1]\n",
    "# HAHV--0, LAHV--1, LALV--2, HALV--3\n",
    "VA_labels = np.where((valence > 5) & (arousal > 5), 0,\n",
    "            np.where((valence >= 5) & (arousal < 5), 1,\n",
    "            np.where((valence < 5) & (arousal < 5), 2, 3)))\n",
    "print(\"V:\", valence)\n",
    "print(\"A:\", arousal)\n",
    "print(VA_labels)\n",
    "\n",
    "segment_size = 3 * 128\n",
    "\n",
    "num_segments = sliced_data.shape[2] // segment_size\n",
    "expanded_VA_labels = np.repeat(VA_labels, num_segments)\n",
    "print(expanded_VA_labels.shape)\n",
    "labels = expanded_VA_labels "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import mne\n",
    "from mne import io\n",
    "from mne.datasets import sample\n",
    "from mne.preprocessing import ICA\n",
    "import os\n",
    "\n",
    "sfreq = 128 \n",
    "channels = 32\n",
    "samples = 384\n",
    "num_each_trial = 20\n",
    "ch_names = ['Fp1', 'AF3', 'F7', 'F3', 'FC1', 'FC5', 'T7', 'C3', 'CP1', 'CP5', 'P7', 'P3', \n",
    "            'Pz', 'PO3', 'O1', 'Oz', 'O2', 'PO4', 'P4', 'P8', 'CP6', 'CP2', 'C4', 'T8', \n",
    "            'FC6', 'FC2', 'F4', 'F8', 'AF4', 'FP2', 'Fz', 'Cz']\n",
    "ch_types = ['eeg'] * channels\n",
    "\n",
    "data_list = []\n",
    "eeg_data_segments = np.split(eeg_data, 40, axis=0)\n",
    "index = 0\n",
    "for segment in eeg_data_segments:\n",
    "    index = index + 1\n",
    "    segment_2d = segment.reshape(-1, channels).T\n",
    "\n",
    "    info = mne.create_info(ch_names=ch_names, ch_types=ch_types, sfreq=sfreq)\n",
    "    raw = mne.io.RawArray(segment_2d, info=info)\n",
    "\n",
    "    raw.filter(l_freq=1.0, h_freq=50.0)\n",
    "\n",
    "    ica = ICA(n_components=channels, random_state=0, max_iter=1000) \n",
    "    ica.fit(raw)\n",
    "\n",
    "    ica.exclude = [] \n",
    "    ica.apply(raw)\n",
    "\n",
    "    data = raw.get_data().T  # (7680, 32)\n",
    "\n",
    "    data = data[:num_each_trial * samples, :]\n",
    "\n",
    "    data = data.reshape(num_each_trial, samples, channels)\n",
    "\n",
    "    data_list.append(data)\n",
    "\n",
    "\n",
    "data_array = np.concatenate(data_list, axis=0)\n",
    "data_array = np.swapaxes(data_array, 1, 2)\n",
    "eeg_data = data_array"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save('./EEGData/s01_eeg.npy', eeg_data)\n",
    "np.save('./EEGData/s01_labels.npy', labels)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
--------------------------------------------------------------------------------
/code/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiModalClassifier(nn.Module):
    """
    This class implements a multi-instance learning approach for emotion recognition
    using both EEG signals and facial images.
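    Pipeline (as implemented in forward()): every sample carries `num_instances` face
    crops, each crop is embedded with the fine-tuned Swin backbone, a MIL strategy keeps
    the most informative instances, the selected image tokens are optionally refined by
    cross-attention or an MLP, EEG features are layer-normalized and projected from
    `eeg_size` to `input_size`, token-type embeddings mark the two modalities, and a
    Transformer encoder over the concatenated sequence feeds a CLS-token classifier.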
9 | """ 10 | def __init__(self, input_size = 768, num_classes = 4, 11 | num_heads = 12, dim_feedforward = 2048, num_encoder_layers = 2, device = device, 12 | eeg_size = 384, transformer_dropout_rate = 0.2, cls_dropout_rate = 0.1, 13 | fusion_type = 'cross_attention', # options: 'none', 'cross_attention', 'mlp' 14 | instance_selection_method = 'attention_weighted_topk', # options: 'none', 'softmax', 'amil', 'attention_topk', 'attention_weighted_topk' 15 | num_select = 3, num_instances = 10): 16 | """ 17 | Args: 18 | input_size (int): Hidden dimension size for transformer layers 19 | num_classes (int): Number of output classes for classification 20 | num_heads (int): Number of attention heads in transformer 21 | dim_feedforward (int): Feedforward dimension in transformer layers 22 | num_encoder_layers (int): Number of transformer encoder layers 23 | device: Device to run the model on 24 | eeg_size (int): Input dimension of EEG data 25 | transformer_dropout_rate (float): Dropout rate for transformer layers 26 | cls_dropout_rate (float): Dropout rate for classification head 27 | fusion_type (str): Type of fusion strategy for image features 28 | instance_selection_method (str): Method for selecting instances in MIL 29 | num_select (int): Number of instances to select 30 | num_instances (int): Total number of instances available 31 | """ 32 | super().__init__() 33 | 34 | # Core hyperparameters and options 35 | self.transformer_dropout_rate = transformer_dropout_rate 36 | self.cls_dropout_rate = cls_dropout_rate 37 | self.fusion_type = fusion_type 38 | self.instance_selection_method = instance_selection_method 39 | # Swin image processor and backbone (fine-tuned) 40 | self.img_processor = swin_processor 41 | self.swin_model = swin_model 42 | for param in self.swin_model.parameters(): 43 | # Enable fine-tuning 44 | param.requires_grad = True 45 | 46 | # Token type embeddings: 0 for image tokens, 1 for EEG tokens 47 | self.token_type_embeddings = nn.Embedding(2, input_size) 48 | 49 | # Transformer encoder over concatenated image and EEG tokens 50 | self.transformer_encoder = nn.TransformerEncoder( 51 | nn.TransformerEncoderLayer( 52 | d_model = input_size, 53 | nhead = num_heads, 54 | dim_feedforward = dim_feedforward, 55 | dropout = transformer_dropout_rate, 56 | batch_first = True 57 | ), 58 | num_layers = num_encoder_layers 59 | ) 60 | 61 | # EEG projection and normalization 62 | self.eeg_proj = nn.Linear(eeg_size, input_size) 63 | self.activation = nn.ReLU() 64 | self.layernorm = nn.LayerNorm(eeg_size) 65 | 66 | # Classification head 67 | self.cls_token = nn.Parameter(torch.zeros(1, 1, input_size)).to(device) 68 | self.dropout = nn.Dropout(cls_dropout_rate) 69 | self.classifier = nn.Linear(input_size, num_classes) # Final classifier 70 | 71 | # Initialize cross-attention components only when fusion_type == 'cross_attention' 72 | if fusion_type == 'cross_attention': 73 | self.num_queries = 147 74 | self.query_tokens = nn.Parameter(torch.zeros(1, self.num_queries, input_size)) 75 | nn.init.normal_(self.query_tokens, std = 0.02) 76 | self.cross_attention = nn.MultiheadAttention( 77 | embed_dim = input_size, 78 | num_heads = num_heads, 79 | dropout = transformer_dropout_rate, 80 | batch_first = True 81 | ) 82 | # MLP up/down projection components 83 | elif fusion_type == 'mlp': 84 | self.mlp_up = nn.Linear(input_size, 4 * input_size) # 768 -> 3072 85 | self.mlp_act = nn.GELU() 86 | self.mlp_down = nn.Linear(4 * input_size, input_size) # 3072 -> 768 87 | 88 | self.num_instances = num_instances 89 | 
        self.num_select = num_select

        # Global learnable instance weights (used by 'softmax' method)
        self.instance_weights = nn.Parameter(torch.ones(1, num_instances))

        # AMIL attention projection layers
        self.amil_value_proj = nn.Linear(input_size, input_size)
        self.amil_weight_proj = nn.Linear(input_size, 1)

        # Top-K attention projection layers
        self.topk_value_proj = nn.Linear(input_size * 49, input_size * 49)
        self.topk_weight_proj = nn.Linear(input_size * 49, 1)
        self.topk_dimension_recover = nn.Linear(input_size * 49, self.num_instances)

    def select_instances(self, images_embedding):
        """
        Select instances from multiple image embeddings using different MIL methods.
        """
        if self.instance_selection_method == 'none':
            # Return the embedding of the first image
            return images_embedding[:, 0, :, :].unsqueeze(1)

        elif self.instance_selection_method == 'softmax':
            # Compute weight for each instance
            weights = F.softmax(self.instance_weights, dim = 1)
            # Select top-k instances
            _, indices = torch.topk(weights, self.num_select, dim = 1)
            selected_embeddings = []
            batch_size = images_embedding.size(0)

            for i in range(batch_size):
                # Select the k instances with the highest weights
                selected = images_embedding[i, indices[0], :]
                selected_embeddings.append(selected)

            return torch.stack(selected_embeddings)

        elif self.instance_selection_method == 'amil':
            batch_size = images_embedding.size(0)

            # Compute per-instance representation
            instance_features = images_embedding.mean(dim = 2)

            # AMIL attention scores
            hidden = torch.tanh(self.amil_value_proj(instance_features))
            weights = self.amil_weight_proj(hidden)
            weights = F.softmax(weights, dim = 1)

            # Weighted sum
            weighted_features = (instance_features * weights).sum(dim = 1)

            return weighted_features

        elif self.instance_selection_method == 'attention_topk':
            batch_size = images_embedding.size(0)

            # Compute per-instance representation
            instance_features = images_embedding.view(batch_size, self.num_instances, -1)

            # Attention scores
            hidden = torch.tanh(self.topk_value_proj(instance_features))
            weights = self.topk_weight_proj(hidden)
            weights = F.softmax(weights, dim = 1)

            # Weighted sum
            weighted_features = (instance_features * weights).sum(dim = 1)
            # Recover per-instance weights
            recovered_weights = self.topk_dimension_recover(weighted_features).view(batch_size, self.num_instances, 1)

            _, indices = torch.topk(recovered_weights, self.num_select, dim = 1)
            selected_embeddings = []

            for i in range(batch_size):
                # Select top-k instances by recovered weights
                selected = images_embedding[i, indices[i], :, :]
                selected_embeddings.append(selected)

            return torch.stack(selected_embeddings)

        elif self.instance_selection_method == 'attention_weighted_topk':
            batch_size = images_embedding.size(0)

            # Compute per-instance representation
            instance_features = images_embedding.view(batch_size, self.num_instances, -1)

            # Attention scores
            hidden = torch.tanh(self.topk_value_proj(instance_features))
            weights = self.topk_weight_proj(hidden)
            weights = F.softmax(weights, dim = 1)

            # Select top-k instances
            _, indices = torch.topk(weights, self.num_select, dim = 1)
            selected_embeddings = []

            for i in range(batch_size):
                # Select the k instances with the highest weights
                selected = images_embedding[i, indices[i].squeeze(), :, :]
                selected_embeddings.append(selected)

            return torch.stack(selected_embeddings)

    def forward(self, eeg_data, images_data):
        batch_size = images_data.size(0)

        # Process multiple images
        images_embedding = []
        for i in range(self.num_instances):
            image = images_data[:, i]
            vision_inputs = self.img_processor(image, return_tensors = "pt").to(device)
            embedding = self.swin_model(**vision_inputs).last_hidden_state
            images_embedding.append(embedding)

        images_embedding = torch.stack(images_embedding, dim = 1)

        # Select instances
        selected_embeddings = self.select_instances(images_embedding)

        selected_embeddings = selected_embeddings.view(batch_size, -1, 768)

        # Process according to fusion_type
        if self.fusion_type == 'cross_attention':
            query_tokens = self.query_tokens.expand(batch_size, -1, -1)
            image_features, _ = self.cross_attention(
                query = query_tokens,
                key = selected_embeddings,
                value = selected_embeddings
            )
            images_embedding = image_features

        elif self.fusion_type == 'mlp':
            # MLP up/down projection
            x = self.mlp_up(selected_embeddings)
            x = self.mlp_act(x)
            images_embedding = self.mlp_down(x)

        else:
            images_embedding = selected_embeddings

        eeg_data = self.layernorm(eeg_data)
        eeg_embedding = self.eeg_proj(eeg_data)
        eeg_embedding = self.activation(eeg_embedding)

        # Add token-type embeddings to distinguish modalities
        images_embedding, eeg_embedding = (
            images_embedding + self.token_type_embeddings(torch.zeros(images_embedding.shape[0], 1, dtype = torch.long, device = device)),
            eeg_embedding + self.token_type_embeddings(torch.ones(eeg_embedding.shape[0], 1, dtype = torch.long, device = device))
        )

        # Concatenate image and EEG tokens, prepend CLS, and encode
        multi_embedding = torch.cat((images_embedding, eeg_embedding), dim = 1)
        multi_embedding = torch.cat((self.cls_token.expand(multi_embedding.size(0), -1, -1), multi_embedding), dim = 1)
        multi_embedding = self.transformer_encoder(multi_embedding)

        # Take the output of the CLS token
        cls_token_output = multi_embedding[:, 0, :]
        cls_token_output = self.dropout(cls_token_output)
        x = self.classifier(cls_token_output)

        return x
--------------------------------------------------------------------------------
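
Note: the repository's `select_instances` above operates on flattened Swin patch embeddings (num_instances × 49 × 768 per sample). For readers who want to experiment with the core MIL step in isolation, below is a minimal, self-contained sketch of attention-weighted top-k instance selection. It is not code from this repository; the class name `AttentionTopK`, the pooled per-instance feature size, and the bag shapes are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionTopK(nn.Module):
    """Score each instance in a bag with a small attention head and keep the top k."""

    def __init__(self, feat_dim: int, num_select: int):
        super().__init__()
        self.num_select = num_select
        self.value_proj = nn.Linear(feat_dim, feat_dim)   # hidden projection
        self.weight_proj = nn.Linear(feat_dim, 1)         # one score per instance

    def forward(self, bag: torch.Tensor) -> torch.Tensor:
        # bag: (batch, num_instances, feat_dim)
        hidden = torch.tanh(self.value_proj(bag))
        scores = F.softmax(self.weight_proj(hidden), dim=1)   # (batch, N, 1)
        _, idx = torch.topk(scores.squeeze(-1), self.num_select, dim=1)
        # Gather the k highest-scoring instances from each bag
        idx = idx.unsqueeze(-1).expand(-1, -1, bag.size(-1))
        return torch.gather(bag, 1, idx)                      # (batch, k, feat_dim)

if __name__ == "__main__":
    bag = torch.randn(2, 10, 768)          # 2 bags, 10 instances, 768-d features
    selected = AttentionTopK(768, 3)(bag)
    print(selected.shape)                  # torch.Size([2, 3, 768])
```

The tanh-projection-then-softmax scoring mirrors the `topk_value_proj`/`topk_weight_proj` pair in `model.py`; only the input granularity (one pooled vector per instance instead of 49 patch tokens) differs.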