├── models
    ├── __init__.py
    ├── utils
    │   ├── __init__.py
    │   ├── masking.py
    │   ├── metrics.py
    │   ├── tools.py
    │   └── timefeatures.py
    ├── subject_layers
    │   ├── __init__.py
    │   ├── StandardNorm.py
    │   ├── Conv_Blocks.py
    │   ├── Crossformer_EncDec.py
    │   ├── Transformer_EncDec.py
    │   ├── AutoCorrelation.py
    │   ├── Autoformer_EncDec.py
    │   ├── FourierCorrelation.py
    │   ├── Pyraformer_EncDec.py
    │   ├── Embed.py
    │   ├── ETSformer_EncDec.py
    │   └── SelfAttention_Family.py
    ├── loss.py
    └── util.py
├── imgs
    ├── encoder.png
    ├── test_acc.png
    ├── bs=16_test_acc.png
    ├── fig-framework.png
    ├── fig-genexample.png
    └── temporal_analysis.png
├── Retrieval
    ├── data_config.json
    ├── eegdatasets_joint_subjects.py
    └── eegdatasets_leaveone.py
├── Generation
    ├── data_config.json
    ├── image_adapter.ipynb
    ├── diffusion_prior.py
    └── eegdatasets_leaveone.py
├── requirements.txt
├── LICENSE
├── environment.yml
├── setup.sh
├── .gitignore
├── EEG-preprocessing
    ├── preprocessing.py
    └── preprocessing_utils.py
└── README.md

/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/models/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/models/subject_layers/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/imgs/encoder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ncclab-sustech/EEG_Image_decode/HEAD/imgs/encoder.png
--------------------------------------------------------------------------------
/imgs/test_acc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ncclab-sustech/EEG_Image_decode/HEAD/imgs/test_acc.png
--------------------------------------------------------------------------------
/imgs/bs=16_test_acc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ncclab-sustech/EEG_Image_decode/HEAD/imgs/bs=16_test_acc.png
--------------------------------------------------------------------------------
/imgs/fig-framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ncclab-sustech/EEG_Image_decode/HEAD/imgs/fig-framework.png
--------------------------------------------------------------------------------
/imgs/fig-genexample.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ncclab-sustech/EEG_Image_decode/HEAD/imgs/fig-genexample.png
--------------------------------------------------------------------------------
/imgs/temporal_analysis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ncclab-sustech/EEG_Image_decode/HEAD/imgs/temporal_analysis.png
--------------------------------------------------------------------------------
/Retrieval/data_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_path": "/home/ldy/Workspace/THINGS/Preprocessed_data_250Hz",
3 | "img_directory_training": "/home/ldy/Workspace/THINGS/images_set/training_images",
4 | "img_directory_test": 
"/home/ldy/Workspace/THINGS/images_set/test_images", 5 | "features_path":"/home/ldy/Workspace/THINGS/CLIP" 6 | } -------------------------------------------------------------------------------- /Generation/data_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_path": "/home/ldy/Workspace/THINGS/Preprocessed_data_250Hz", 3 | "img_directory_training": "/home/ldy/Workspace/THINGS/images_set/training_images", 4 | 5 | "img_directory_test": "/home/ldy/Workspace/Generation/generated_imgs", 6 | "features_path":"/home/ldy/Workspace/THINGS/CLIP" 7 | } 8 | 9 | -------------------------------------------------------------------------------- /models/utils/masking.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class TriangularCausalMask(): 5 | def __init__(self, B, L, device="cpu"): 6 | mask_shape = [B, 1, L, L] 7 | with torch.no_grad(): 8 | self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device) 9 | 10 | @property 11 | def mask(self): 12 | return self._mask 13 | 14 | 15 | class ProbMask(): 16 | def __init__(self, B, H, L, index, scores, device="cpu"): 17 | _mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1) 18 | _mask_ex = _mask[None, None, :].expand(B, H, L, scores.shape[-1]) 19 | indicator = _mask_ex[torch.arange(B)[:, None, None], 20 | torch.arange(H)[None, :, None], 21 | index, :].to(device) 22 | self._mask = indicator.view(scores.shape).to(device) 23 | 24 | @property 25 | def mask(self): 26 | return self._mask 27 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Core Deep Learning Framework 2 | torch==2.5.0 3 | torchvision==0.20.0 4 | torchaudio==2.5.0 5 | 6 | # Hugging Face Ecosystem 7 | transformers==4.36.0 8 | diffusers==0.30.0 9 | accelerate==1.5.2 10 | huggingface-hub==0.30.2 11 | 12 | # EEG Processing 13 | braindecode==0.8.1 14 | mne==1.9.0 15 | 16 | # CLIP and Vision 17 | clip @ git+https://github.com/openai/CLIP.git 18 | open-clip-torch 19 | 20 | # Image Generation 21 | dalle2-pytorch==1.15.6 22 | pytorch-msssim==1.0.0 23 | 24 | # Data Processing 25 | numpy==1.26.4 26 | pandas==2.3.3 27 | scipy==1.15.3 28 | scikit-learn==1.6.1 29 | h5py==3.13.0 30 | 31 | # Visualization 32 | matplotlib==3.10.7 33 | seaborn==0.13.2 34 | tqdm==4.67.1 35 | 36 | # Deep Learning Utilities 37 | einops==0.8.1 38 | info-nce-pytorch==0.1.0 39 | reformer-pytorch==1.4.4 40 | 41 | # Logging and Tracking 42 | wandb==0.19.10 43 | 44 | # Image Processing 45 | Pillow 46 | imageio==2.37.0 47 | kornia==0.8.0 48 | 49 | # Utilities 50 | ftfy==6.3.1 51 | regex==2024.11.6 52 | clip-retrieval 53 | -------------------------------------------------------------------------------- /models/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def RSE(pred, true): 5 | return np.sqrt(np.sum((true - pred) ** 2)) / np.sqrt(np.sum((true - true.mean()) ** 2)) 6 | 7 | 8 | def CORR(pred, true): 9 | u = ((true - true.mean(0)) * (pred - pred.mean(0))).sum(0) 10 | d = np.sqrt(((true - true.mean(0)) ** 2 * (pred - pred.mean(0)) ** 2).sum(0)) 11 | return (u / d).mean(-1) 12 | 13 | 14 | def MAE(pred, true): 15 | return np.mean(np.abs(pred - true)) 16 | 17 | 18 | def MSE(pred, true): 19 | return np.mean((pred - true) ** 2) 20 | 21 | 22 | def RMSE(pred, 
true): 23 | return np.sqrt(MSE(pred, true)) 24 | 25 | 26 | def MAPE(pred, true): 27 | return np.mean(np.abs((pred - true) / true)) 28 | 29 | 30 | def MSPE(pred, true): 31 | return np.mean(np.square((pred - true) / true)) 32 | 33 | 34 | def metric(pred, true): 35 | mae = MAE(pred, true) 36 | mse = MSE(pred, true) 37 | rmse = RMSE(pred, true) 38 | mape = MAPE(pred, true) 39 | mspe = MSPE(pred, true) 40 | 41 | return mae, mse, rmse, mape, mspe 42 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 DongyangLi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: BCI 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | 8 | dependencies: 9 | # Python version 10 | - python=3.12 11 | 12 | # Core tools 13 | - pip 14 | - ipykernel 15 | - ipython 16 | - jupyter 17 | - jupyterlab 18 | 19 | # Scientific computing (conda) 20 | - numpy=1.26.4 21 | - pandas=2.3.3 22 | - matplotlib=3.10.7 23 | - scikit-learn=1.6.1 24 | - scipy=1.15.3 25 | - seaborn=0.13.2 26 | - h5py=3.13.0 27 | 28 | # System utilities 29 | - pillow 30 | - pyyaml 31 | - tqdm=4.67.1 32 | - ffmpeg 33 | 34 | # PyTorch (CUDA 12.4) 35 | - pytorch=2.5.0 36 | - torchvision=0.20.0 37 | - torchaudio=2.5.0 38 | - pytorch-cuda=12.4 39 | 40 | # Pip dependencies 41 | - pip: 42 | # Hugging Face Ecosystem 43 | - transformers==4.36.0 44 | - diffusers==0.30.0 45 | - accelerate==1.5.2 46 | - huggingface-hub==0.30.2 47 | 48 | # EEG Processing 49 | - braindecode==0.8.1 50 | - mne==1.9.0 51 | 52 | # CLIP and Vision 53 | - git+https://github.com/openai/CLIP.git 54 | - open-clip-torch 55 | 56 | # Image Generation 57 | - dalle2-pytorch==1.15.6 58 | - pytorch-msssim==1.0.0 59 | 60 | # Deep Learning Utilities 61 | - einops==0.8.1 62 | - info-nce-pytorch==0.1.0 63 | - reformer-pytorch==1.4.4 64 | 65 | # Logging and Tracking 66 | - wandb==0.19.10 67 | 68 | # Image Processing 69 | - imageio==2.37.0 70 | - kornia==0.8.0 71 | 72 | # Utilities 73 | - ftfy==6.3.1 74 | - regex==2024.11.6 75 | - clip-retrieval 76 | -------------------------------------------------------------------------------- /models/subject_layers/StandardNorm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Normalize(nn.Module): 6 | def __init__(self, num_features: int, eps=1e-5, affine=False, subtract_last=False, non_norm=False): 7 | """ 8 | :param num_features: the number of features or channels 9 | :param eps: a value added for numerical stability 10 | :param affine: if True, RevIN has learnable affine parameters 11 | """ 12 | super(Normalize, self).__init__() 13 | self.num_features = num_features 14 | self.eps = eps 15 | self.affine = affine 16 | self.subtract_last = subtract_last 17 | self.non_norm = non_norm 18 | if self.affine: 19 | self._init_params() 20 | 21 | def forward(self, x, mode: str): 22 | if mode == 'norm': 23 | self._get_statistics(x) 24 | x = self._normalize(x) 25 | elif mode == 'denorm': 26 | x = self._denormalize(x) 27 | else: 28 | raise NotImplementedError 29 | return x 30 | 31 | def _init_params(self): 32 | # initialize RevIN params: (C,) 33 | self.affine_weight = nn.Parameter(torch.ones(self.num_features)) 34 | self.affine_bias = nn.Parameter(torch.zeros(self.num_features)) 35 | 36 | def _get_statistics(self, x): 37 | dim2reduce = tuple(range(1, x.ndim - 1)) 38 | if self.subtract_last: 39 | self.last = x[:, -1, :].unsqueeze(1) 40 | else: 41 | self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach() 42 | self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach() 43 | 44 | def _normalize(self, x): 45 | if self.non_norm: 46 | return x 47 | if self.subtract_last: 48 | x = x - self.last 49 | else: 50 | x = x - self.mean 51 | x = x / self.stdev 52 | if self.affine: 53 | x = x * self.affine_weight 54 | x = x + self.affine_bias 55 | return x 56 | 57 | def _denormalize(self, x): 
58 | if self.non_norm: 59 | return x 60 | if self.affine: 61 | x = x - self.affine_bias 62 | x = x / (self.affine_weight + self.eps * self.eps) 63 | x = x * self.stdev 64 | if self.subtract_last: 65 | x = x + self.last 66 | else: 67 | x = x + self.mean 68 | return x 69 | -------------------------------------------------------------------------------- /models/subject_layers/Conv_Blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Inception_Block_V1(nn.Module): 6 | def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True): 7 | super(Inception_Block_V1, self).__init__() 8 | self.in_channels = in_channels 9 | self.out_channels = out_channels 10 | self.num_kernels = num_kernels 11 | kernels = [] 12 | for i in range(self.num_kernels): 13 | kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=2 * i + 1, padding=i)) 14 | self.kernels = nn.ModuleList(kernels) 15 | if init_weight: 16 | self._initialize_weights() 17 | 18 | def _initialize_weights(self): 19 | for m in self.modules(): 20 | if isinstance(m, nn.Conv2d): 21 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 22 | if m.bias is not None: 23 | nn.init.constant_(m.bias, 0) 24 | 25 | def forward(self, x): 26 | res_list = [] 27 | for i in range(self.num_kernels): 28 | res_list.append(self.kernels[i](x)) 29 | res = torch.stack(res_list, dim=-1).mean(-1) 30 | return res 31 | 32 | 33 | class Inception_Block_V2(nn.Module): 34 | def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True): 35 | super(Inception_Block_V2, self).__init__() 36 | self.in_channels = in_channels 37 | self.out_channels = out_channels 38 | self.num_kernels = num_kernels 39 | kernels = [] 40 | for i in range(self.num_kernels // 2): 41 | kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=[1, 2 * i + 3], padding=[0, i + 1])) 42 | kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=[2 * i + 3, 1], padding=[i + 1, 0])) 43 | kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=1)) 44 | self.kernels = nn.ModuleList(kernels) 45 | if init_weight: 46 | self._initialize_weights() 47 | 48 | def _initialize_weights(self): 49 | for m in self.modules(): 50 | if isinstance(m, nn.Conv2d): 51 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 52 | if m.bias is not None: 53 | nn.init.constant_(m.bias, 0) 54 | 55 | def forward(self, x): 56 | res_list = [] 57 | for i in range(self.num_kernels + 1): 58 | res_list.append(self.kernels[i](x)) 59 | res = torch.stack(res_list, dim=-1).mean(-1) 60 | return res 61 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # ============================================================ 3 | # EEG Image Decode - Environment Setup Script 4 | # ============================================================ 5 | # This script creates a conda environment with all dependencies 6 | # for reproducing the experiments in the paper. 7 | # 8 | # Usage: . 
setup.sh 9 | # ============================================================ 10 | 11 | set -e 12 | 13 | ENV_NAME="BCI" 14 | PYTHON_VERSION="3.12" 15 | 16 | echo "============================================================" 17 | echo "Creating conda environment: $ENV_NAME with Python $PYTHON_VERSION" 18 | echo "============================================================" 19 | 20 | # Create conda environment 21 | conda create -n $ENV_NAME python=$PYTHON_VERSION -y 22 | conda activate $ENV_NAME 23 | 24 | echo "Installing base packages via conda..." 25 | conda install numpy matplotlib tqdm scikit-image jupyterlab -y 26 | conda install -c conda-forge accelerate -y 27 | 28 | echo "Installing PyTorch ecosystem (CUDA 12.4)..." 29 | pip install torch==2.5.0 torchvision==0.20.0 torchaudio==2.5.0 --index-url https://download.pytorch.org/whl/cu124 30 | 31 | echo "Installing Hugging Face packages..." 32 | pip install transformers==4.36.0 33 | pip install diffusers==0.30.0 34 | pip install huggingface-hub==0.30.2 35 | pip install accelerate==1.5.2 36 | 37 | echo "Installing CLIP packages..." 38 | pip install git+https://github.com/openai/CLIP.git 39 | pip install open_clip_torch 40 | pip install clip-retrieval 41 | 42 | echo "Installing EEG processing packages..." 43 | pip install braindecode==0.8.1 44 | pip install mne==1.9.0 45 | 46 | echo "Installing image generation packages..." 47 | pip install dalle2-pytorch==1.15.6 48 | pip install pytorch-msssim==1.0.0 49 | pip install kornia==0.8.0 50 | 51 | echo "Installing deep learning utilities..." 52 | pip install einops==0.8.1 53 | pip install info-nce-pytorch==0.1.0 54 | pip install reformer_pytorch==1.4.4 55 | 56 | echo "Installing logging and visualization..." 57 | pip install wandb==0.19.10 58 | pip install seaborn==0.13.2 59 | 60 | echo "Installing other utilities..." 61 | pip install ftfy==6.3.1 62 | pip install regex==2024.11.6 63 | pip install h5py==3.13.0 64 | pip install pandas==2.3.3 65 | pip install imageio==2.37.0 66 | pip install scipy==1.15.3 67 | pip install scikit-learn==1.6.1 68 | 69 | echo "============================================================" 70 | echo "Environment setup complete!" 
71 | echo "Activate with: conda activate $ENV_NAME" 72 | echo "============================================================" 73 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # ======================================== 2 | # IDE / 编辑器 3 | # ======================================== 4 | .history/ 5 | .vscode/ 6 | .idea/ 7 | *.swp 8 | *.swo 9 | *~ 10 | 11 | # ======================================== 12 | # 操作系统 13 | # ======================================== 14 | .DS_Store 15 | Thumbs.db 16 | desktop.ini 17 | 18 | # ======================================== 19 | # Python 20 | # ======================================== 21 | __pycache__/ 22 | *.py[cod] 23 | *$py.class 24 | *.so 25 | .Python 26 | *.egg 27 | *.egg-info/ 28 | dist/ 29 | build/ 30 | eggs/ 31 | .eggs/ 32 | *.manifest 33 | *.spec 34 | 35 | # 虚拟环境 36 | env/ 37 | venv/ 38 | .venv/ 39 | ENV/ 40 | 41 | # ======================================== 42 | # Jupyter Notebook 43 | # ======================================== 44 | .ipynb_checkpoints/ 45 | */.ipynb_checkpoints/ 46 | 47 | # ======================================== 48 | # 深度学习 / 模型权重 49 | # ======================================== 50 | *.ckpt 51 | *.pth 52 | *.pt 53 | *.h5 54 | *.pb 55 | *.onnx 56 | *.safetensors 57 | checkpoints/ 58 | ckpts/ 59 | weights/ 60 | 61 | # ======================================== 62 | # 实验输出 / 日志 63 | # ======================================== 64 | outputs/ 65 | old_outputs/ 66 | logs/ 67 | wandb/ 68 | runs/ 69 | mlruns/ 70 | lightning_logs/ 71 | *.log 72 | 73 | # ======================================== 74 | # 数据文件(通常太大不应上传) 75 | # ======================================== 76 | # 如需上传特定数据,可用 !data/important.npy 排除 77 | data/ 78 | datasets/ 79 | *.npy 80 | *.npz 81 | *.mat 82 | *.pkl 83 | *.pickle 84 | *.hdf5 85 | 86 | # ======================================== 87 | # 生成的图像 / 结果 88 | # ======================================== 89 | generated_imgs/ 90 | generated_imgs_tensor/ 91 | generation_metric_outputs/ 92 | tab_generation_metric_outputs/ 93 | 94 | # ======================================== 95 | # 项目特定目录 96 | # ======================================== 97 | # Generation 模块 98 | Generation/ckpts/ 99 | Generation/fintune_ckpts/ 100 | Generation/CLIP-dissect/ 101 | Generation/centralRepo/ 102 | Generation/diffusers/ 103 | Generation/generated_imgs/ 104 | Generation/generated_imgs_tensor/ 105 | Generation/Generation/ 106 | Generation/generation_metric_outputs/ 107 | Generation/LAVIS/ 108 | Generation/models/ 109 | Generation/old_outputs/ 110 | Generation/outputs/ 111 | Generation/tab_generation_metric_outputs/ 112 | Generation/THINGS/ 113 | Generation/Visualize/ 114 | Generation/wandb/ 115 | 116 | # Retrieval 模块 117 | Retrieval/centralRepo/ 118 | Retrieval/CLIP-dissect/ 119 | Retrieval/conditional_outputs/ 120 | Retrieval/LAVIS/ 121 | Retrieval/old_outputs/ 122 | Retrieval/outputs/ 123 | Retrieval/Outputs/ 124 | Retrieval/sliding_window_outputs/ 125 | Retrieval/Visualize/ 126 | Retrieval/wandb/ 127 | 128 | # ======================================== 129 | # 其他 130 | # ======================================== 131 | *.tmp 132 | *.bak 133 | *.cache 134 | .cache/ 135 | tmp/ 136 | temp/ 137 | -------------------------------------------------------------------------------- /EEG-preprocessing/preprocessing.py: -------------------------------------------------------------------------------- 1 | """ 2 | Refer to the code of Things-EEG2 but 
with a few differences.
3 | Many thanks!
4 | https://www.sciencedirect.com/science/article/pii/S1053811922008758
5 | """
6 | 
7 | """Preprocess the raw EEG data: channel selection, epoching, frequency
8 | downsampling, baseline correction, multivariate noise normalization (MVNN),
9 | sorting of the data image conditions and reshaping the data to:
10 | Image conditions × EEG repetitions × EEG channels × EEG time points.
11 | Then, the data of both test and training EEG partitions is saved.
12 | 
13 | Parameters
14 | ----------
15 | sub : int
16 |     Used subject.
17 | n_ses : int
18 |     Number of EEG sessions.
19 | sfreq : int
20 |     Downsampling frequency.
21 | mvnn_dim : str
22 |     Whether to compute the MVNN covariance matrices for each time point
23 |     ('time') or for each epoch/repetition ('epochs').
24 | project_dir : str
25 |     Directory of the project folder.
26 | 
27 | """
28 | 
29 | import argparse
30 | from preprocessing_utils import epoching
31 | from preprocessing_utils import mvnn
32 | from preprocessing_utils import save_prepr
33 | 
34 | 
35 | # =============================================================================
36 | # Input arguments
37 | # =============================================================================
38 | parser = argparse.ArgumentParser()
39 | parser.add_argument('--sub', default=10, type=int)
40 | parser.add_argument('--n_ses', default=4, type=int)
41 | parser.add_argument('--sfreq', default=250, type=int)
42 | parser.add_argument('--mvnn_dim', default='epochs', type=str)
43 | parser.add_argument('--project_dir', default='/home/Data/Things-EEG2/', type=str)
44 | args = parser.parse_args()
45 | 
46 | print('>>> EEG data preprocessing <<<')
47 | print('\nInput arguments:')
48 | for key, val in vars(args).items():
49 |     print('{:16} {}'.format(key, val))
50 | 
51 | # Set random seed for reproducible results
52 | seed = 20200220
53 | 
54 | 
55 | # =============================================================================
56 | # Epoch and sort the data
57 | # =============================================================================
58 | # Channel selection, epoching, baseline correction and frequency downsampling of
59 | # the test and training data partitions.
60 | # Then, the conditions are sorted and the EEG data is reshaped to:
61 | # Image conditions × EEG repetitions × EEG channels × EEG time points
62 | # This step is applied independently to the data of each partition and session.
63 | epoched_test, _, ch_names, times = epoching(args, 'test', seed)
64 | epoched_train, img_conditions_train, _, _ = epoching(args, 'training', seed)
65 | 
66 | 
67 | # =============================================================================
68 | # Multivariate Noise Normalization
69 | # =============================================================================
70 | # MVNN is applied independently to the data of each session.
71 | whitened_test, whitened_train = mvnn(args, epoched_test, epoched_train)
72 | del epoched_test, epoched_train
73 | 
74 | 
75 | # =============================================================================
76 | # Merge and save the preprocessed data
77 | # =============================================================================
78 | # In this step the data of all sessions is merged into the shape:
79 | # Image conditions × EEG repetitions × EEG channels × EEG time points
80 | # Then, the preprocessed data of the test and training data partitions is saved. 
81 | save_prepr(args, whitened_test, whitened_train, img_conditions_train, ch_names, 82 | times, seed) 83 | 84 | -------------------------------------------------------------------------------- /models/utils/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import matplotlib.pyplot as plt 6 | import pandas as pd 7 | 8 | plt.switch_backend('agg') 9 | 10 | 11 | def adjust_learning_rate(optimizer, epoch, args): 12 | # lr = args.learning_rate * (0.2 ** (epoch // 2)) 13 | if args.lradj == 'type1': 14 | lr_adjust = {epoch: args.learning_rate * (0.5 ** ((epoch - 1) // 1))} 15 | elif args.lradj == 'type2': 16 | lr_adjust = { 17 | 2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6, 18 | 10: 5e-7, 15: 1e-7, 20: 5e-8 19 | } 20 | if epoch in lr_adjust.keys(): 21 | lr = lr_adjust[epoch] 22 | for param_group in optimizer.param_groups: 23 | param_group['lr'] = lr 24 | print('Updating learning rate to {}'.format(lr)) 25 | 26 | 27 | class EarlyStopping: 28 | def __init__(self, patience=7, verbose=False, delta=0): 29 | self.patience = patience 30 | self.verbose = verbose 31 | self.counter = 0 32 | self.best_score = None 33 | self.early_stop = False 34 | self.val_loss_min = np.Inf 35 | self.delta = delta 36 | 37 | def __call__(self, val_loss, model, path): 38 | score = -val_loss 39 | if self.best_score is None: 40 | self.best_score = score 41 | self.save_checkpoint(val_loss, model, path) 42 | elif score < self.best_score + self.delta: 43 | self.counter += 1 44 | print(f'EarlyStopping counter: {self.counter} out of {self.patience}') 45 | if self.counter >= self.patience: 46 | self.early_stop = True 47 | else: 48 | self.best_score = score 49 | self.save_checkpoint(val_loss, model, path) 50 | self.counter = 0 51 | 52 | def save_checkpoint(self, val_loss, model, path): 53 | if self.verbose: 54 | print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). 
Saving model ...') 55 | torch.save(model.state_dict(), path + '/' + 'checkpoint.pth') 56 | self.val_loss_min = val_loss 57 | 58 | 59 | class dotdict(dict): 60 | """dot.notation access to dictionary attributes""" 61 | __getattr__ = dict.get 62 | __setattr__ = dict.__setitem__ 63 | __delattr__ = dict.__delitem__ 64 | 65 | 66 | class StandardScaler(): 67 | def __init__(self, mean, std): 68 | self.mean = mean 69 | self.std = std 70 | 71 | def transform(self, data): 72 | return (data - self.mean) / self.std 73 | 74 | def inverse_transform(self, data): 75 | return (data * self.std) + self.mean 76 | 77 | 78 | def visual(true, preds=None, name='./pic/test.pdf'): 79 | """ 80 | Results visualization 81 | """ 82 | plt.figure() 83 | plt.plot(true, label='GroundTruth', linewidth=2) 84 | if preds is not None: 85 | plt.plot(preds, label='Prediction', linewidth=2) 86 | plt.legend() 87 | plt.savefig(name, bbox_inches='tight') 88 | 89 | 90 | def adjustment(gt, pred): 91 | anomaly_state = False 92 | for i in range(len(gt)): 93 | if gt[i] == 1 and pred[i] == 1 and not anomaly_state: 94 | anomaly_state = True 95 | for j in range(i, 0, -1): 96 | if gt[j] == 0: 97 | break 98 | else: 99 | if pred[j] == 0: 100 | pred[j] = 1 101 | for j in range(i, len(gt)): 102 | if gt[j] == 0: 103 | break 104 | else: 105 | if pred[j] == 0: 106 | pred[j] = 1 107 | elif gt[i] == 0: 108 | anomaly_state = False 109 | if anomaly_state: 110 | pred[i] = 1 111 | return gt, pred 112 | 113 | 114 | def cal_accuracy(y_pred, y_true): 115 | return np.mean(y_pred == y_true) 116 | -------------------------------------------------------------------------------- /models/subject_layers/Crossformer_EncDec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange, repeat 4 | from layers.SelfAttention_Family import TwoStageAttentionLayer 5 | 6 | 7 | class SegMerging(nn.Module): 8 | def __init__(self, d_model, win_size, norm_layer=nn.LayerNorm): 9 | super().__init__() 10 | self.d_model = d_model 11 | self.win_size = win_size 12 | self.linear_trans = nn.Linear(win_size * d_model, d_model) 13 | self.norm = norm_layer(win_size * d_model) 14 | 15 | def forward(self, x): 16 | batch_size, ts_d, seg_num, d_model = x.shape 17 | pad_num = seg_num % self.win_size 18 | if pad_num != 0: 19 | pad_num = self.win_size - pad_num 20 | x = torch.cat((x, x[:, :, -pad_num:, :]), dim=-2) 21 | 22 | seg_to_merge = [] 23 | for i in range(self.win_size): 24 | seg_to_merge.append(x[:, :, i::self.win_size, :]) 25 | x = torch.cat(seg_to_merge, -1) 26 | 27 | x = self.norm(x) 28 | x = self.linear_trans(x) 29 | 30 | return x 31 | 32 | 33 | class scale_block(nn.Module): 34 | def __init__(self, configs, win_size, d_model, n_heads, d_ff, depth, dropout, \ 35 | seg_num=10, factor=10): 36 | super(scale_block, self).__init__() 37 | 38 | if win_size > 1: 39 | self.merge_layer = SegMerging(d_model, win_size, nn.LayerNorm) 40 | else: 41 | self.merge_layer = None 42 | 43 | self.encode_layers = nn.ModuleList() 44 | 45 | for i in range(depth): 46 | self.encode_layers.append(TwoStageAttentionLayer(configs, seg_num, factor, d_model, n_heads, \ 47 | d_ff, dropout)) 48 | 49 | def forward(self, x, attn_mask=None, tau=None, delta=None): 50 | _, ts_dim, _, _ = x.shape 51 | 52 | if self.merge_layer is not None: 53 | x = self.merge_layer(x) 54 | 55 | for layer in self.encode_layers: 56 | x = layer(x) 57 | 58 | return x, None 59 | 60 | 61 | class Encoder(nn.Module): 62 | def __init__(self, attn_layers): 
63 | super(Encoder, self).__init__() 64 | self.encode_blocks = nn.ModuleList(attn_layers) 65 | 66 | def forward(self, x): 67 | encode_x = [] 68 | encode_x.append(x) 69 | 70 | for block in self.encode_blocks: 71 | x, attns = block(x) 72 | encode_x.append(x) 73 | 74 | return encode_x, None 75 | 76 | 77 | class DecoderLayer(nn.Module): 78 | def __init__(self, self_attention, cross_attention, seg_len, d_model, d_ff=None, dropout=0.1): 79 | super(DecoderLayer, self).__init__() 80 | self.self_attention = self_attention 81 | self.cross_attention = cross_attention 82 | self.norm1 = nn.LayerNorm(d_model) 83 | self.norm2 = nn.LayerNorm(d_model) 84 | self.dropout = nn.Dropout(dropout) 85 | self.MLP1 = nn.Sequential(nn.Linear(d_model, d_model), 86 | nn.GELU(), 87 | nn.Linear(d_model, d_model)) 88 | self.linear_pred = nn.Linear(d_model, seg_len) 89 | 90 | def forward(self, x, cross): 91 | batch = x.shape[0] 92 | x = self.self_attention(x) 93 | x = rearrange(x, 'b ts_d out_seg_num d_model -> (b ts_d) out_seg_num d_model') 94 | 95 | cross = rearrange(cross, 'b ts_d in_seg_num d_model -> (b ts_d) in_seg_num d_model') 96 | tmp, attn = self.cross_attention(x, cross, cross, None, None, None,) 97 | x = x + self.dropout(tmp) 98 | y = x = self.norm1(x) 99 | y = self.MLP1(y) 100 | dec_output = self.norm2(x + y) 101 | 102 | dec_output = rearrange(dec_output, '(b ts_d) seg_dec_num d_model -> b ts_d seg_dec_num d_model', b=batch) 103 | layer_predict = self.linear_pred(dec_output) 104 | layer_predict = rearrange(layer_predict, 'b out_d seg_num seg_len -> b (out_d seg_num) seg_len') 105 | 106 | return dec_output, layer_predict 107 | 108 | 109 | class Decoder(nn.Module): 110 | def __init__(self, layers): 111 | super(Decoder, self).__init__() 112 | self.decode_layers = nn.ModuleList(layers) 113 | 114 | 115 | def forward(self, x, cross): 116 | final_predict = None 117 | i = 0 118 | 119 | ts_d = x.shape[1] 120 | for layer in self.decode_layers: 121 | cross_enc = cross[i] 122 | x, layer_predict = layer(x, cross_enc) 123 | if final_predict is None: 124 | final_predict = layer_predict 125 | else: 126 | final_predict = final_predict + layer_predict 127 | i += 1 128 | 129 | final_predict = rearrange(final_predict, 'b (out_d seg_num) seg_len -> b (seg_num seg_len) out_d', out_d=ts_d) 130 | 131 | return final_predict 132 | -------------------------------------------------------------------------------- /models/utils/timefeatures.py: -------------------------------------------------------------------------------- 1 | # From: gluonts/src/gluonts/time_feature/_base.py 2 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"). 5 | # You may not use this file except in compliance with the License. 6 | # A copy of the License is located at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # or in the "license" file accompanying this file. This file is distributed 11 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 12 | # express or implied. See the License for the specific language governing 13 | # permissions and limitations under the License. 
14 | 15 | from typing import List 16 | 17 | import numpy as np 18 | import pandas as pd 19 | from pandas.tseries import offsets 20 | from pandas.tseries.frequencies import to_offset 21 | 22 | 23 | class TimeFeature: 24 | def __init__(self): 25 | pass 26 | 27 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 28 | pass 29 | 30 | def __repr__(self): 31 | return self.__class__.__name__ + "()" 32 | 33 | 34 | class SecondOfMinute(TimeFeature): 35 | """Minute of hour encoded as value between [-0.5, 0.5]""" 36 | 37 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 38 | return index.second / 59.0 - 0.5 39 | 40 | 41 | class MinuteOfHour(TimeFeature): 42 | """Minute of hour encoded as value between [-0.5, 0.5]""" 43 | 44 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 45 | return index.minute / 59.0 - 0.5 46 | 47 | 48 | class HourOfDay(TimeFeature): 49 | """Hour of day encoded as value between [-0.5, 0.5]""" 50 | 51 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 52 | return index.hour / 23.0 - 0.5 53 | 54 | 55 | class DayOfWeek(TimeFeature): 56 | """Hour of day encoded as value between [-0.5, 0.5]""" 57 | 58 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 59 | return index.dayofweek / 6.0 - 0.5 60 | 61 | 62 | class DayOfMonth(TimeFeature): 63 | """Day of month encoded as value between [-0.5, 0.5]""" 64 | 65 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 66 | return (index.day - 1) / 30.0 - 0.5 67 | 68 | 69 | class DayOfYear(TimeFeature): 70 | """Day of year encoded as value between [-0.5, 0.5]""" 71 | 72 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 73 | return (index.dayofyear - 1) / 365.0 - 0.5 74 | 75 | 76 | class MonthOfYear(TimeFeature): 77 | """Month of year encoded as value between [-0.5, 0.5]""" 78 | 79 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 80 | return (index.month - 1) / 11.0 - 0.5 81 | 82 | 83 | class WeekOfYear(TimeFeature): 84 | """Week of year encoded as value between [-0.5, 0.5]""" 85 | 86 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 87 | return (index.isocalendar().week - 1) / 52.0 - 0.5 88 | 89 | 90 | def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]: 91 | """ 92 | Returns a list of time features that will be appropriate for the given frequency string. 93 | Parameters 94 | ---------- 95 | freq_str 96 | Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc. 
97 | """ 98 | 99 | features_by_offsets = { 100 | offsets.YearEnd: [], 101 | offsets.QuarterEnd: [MonthOfYear], 102 | offsets.MonthEnd: [MonthOfYear], 103 | offsets.Week: [DayOfMonth, WeekOfYear], 104 | offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear], 105 | offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear], 106 | offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear], 107 | offsets.Minute: [ 108 | MinuteOfHour, 109 | HourOfDay, 110 | DayOfWeek, 111 | DayOfMonth, 112 | DayOfYear, 113 | ], 114 | offsets.Second: [ 115 | SecondOfMinute, 116 | MinuteOfHour, 117 | HourOfDay, 118 | DayOfWeek, 119 | DayOfMonth, 120 | DayOfYear, 121 | ], 122 | } 123 | 124 | offset = to_offset(freq_str) 125 | 126 | for offset_type, feature_classes in features_by_offsets.items(): 127 | if isinstance(offset, offset_type): 128 | return [cls() for cls in feature_classes] 129 | 130 | supported_freq_msg = f""" 131 | Unsupported frequency {freq_str} 132 | The following frequencies are supported: 133 | Y - yearly 134 | alias: A 135 | M - monthly 136 | W - weekly 137 | D - daily 138 | B - business days 139 | H - hourly 140 | T - minutely 141 | alias: min 142 | S - secondly 143 | """ 144 | raise RuntimeError(supported_freq_msg) 145 | 146 | 147 | def time_features(dates, freq='h'): 148 | return np.vstack([feat(dates) for feat in time_features_from_frequency_str(freq)]) 149 | -------------------------------------------------------------------------------- /models/subject_layers/Transformer_EncDec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ConvLayer(nn.Module): 7 | def __init__(self, c_in): 8 | super(ConvLayer, self).__init__() 9 | self.downConv = nn.Conv1d(in_channels=c_in, 10 | out_channels=c_in, 11 | kernel_size=3, 12 | padding=2, 13 | padding_mode='circular') 14 | self.norm = nn.BatchNorm1d(c_in) 15 | self.activation = nn.ELU() 16 | self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) 17 | 18 | def forward(self, x): 19 | x = self.downConv(x.permute(0, 2, 1)) 20 | x = self.norm(x) 21 | x = self.activation(x) 22 | x = self.maxPool(x) 23 | x = x.transpose(1, 2) 24 | return x 25 | 26 | 27 | class EncoderLayer(nn.Module): 28 | def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"): 29 | super(EncoderLayer, self).__init__() 30 | d_ff = d_ff or 4 * d_model 31 | self.attention = attention 32 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) 33 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) 34 | self.norm1 = nn.LayerNorm(d_model) 35 | self.norm2 = nn.LayerNorm(d_model) 36 | self.dropout = nn.Dropout(dropout) 37 | self.activation = F.relu if activation == "relu" else F.gelu 38 | 39 | def forward(self, x, attn_mask=None, tau=None, delta=None): 40 | new_x, attn = self.attention( 41 | x, x, x, 42 | attn_mask=attn_mask, 43 | tau=tau, delta=delta 44 | ) 45 | x = x + self.dropout(new_x) 46 | 47 | y = x = self.norm1(x) 48 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 49 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 50 | 51 | return self.norm2(x + y), attn 52 | 53 | 54 | class Encoder(nn.Module): 55 | def __init__(self, attn_layers, conv_layers=None, norm_layer=None): 56 | super(Encoder, self).__init__() 57 | self.attn_layers = nn.ModuleList(attn_layers) 58 | self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None 59 | self.norm = norm_layer 60 
| 61 | def forward(self, x, attn_mask=None, tau=None, delta=None): 62 | # x [B, L, D] 63 | attns = [] 64 | if self.conv_layers is not None: 65 | for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)): 66 | delta = delta if i == 0 else None 67 | x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) 68 | x = conv_layer(x) 69 | attns.append(attn) 70 | x, attn = self.attn_layers[-1](x, tau=tau, delta=None) 71 | attns.append(attn) 72 | else: 73 | for attn_layer in self.attn_layers: 74 | x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) 75 | attns.append(attn) 76 | 77 | if self.norm is not None: 78 | x = self.norm(x) 79 | 80 | return x, attns 81 | 82 | 83 | class DecoderLayer(nn.Module): 84 | def __init__(self, self_attention, cross_attention, d_model, d_ff=None, 85 | dropout=0.1, activation="relu"): 86 | super(DecoderLayer, self).__init__() 87 | d_ff = d_ff or 4 * d_model 88 | self.self_attention = self_attention 89 | self.cross_attention = cross_attention 90 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) 91 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) 92 | self.norm1 = nn.LayerNorm(d_model) 93 | self.norm2 = nn.LayerNorm(d_model) 94 | self.norm3 = nn.LayerNorm(d_model) 95 | self.dropout = nn.Dropout(dropout) 96 | self.activation = F.relu if activation == "relu" else F.gelu 97 | 98 | def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None): 99 | x = x + self.dropout(self.self_attention( 100 | x, x, x, 101 | attn_mask=x_mask, 102 | tau=tau, delta=None 103 | )[0]) 104 | x = self.norm1(x) 105 | 106 | x = x + self.dropout(self.cross_attention( 107 | x, cross, cross, 108 | attn_mask=cross_mask, 109 | tau=tau, delta=delta 110 | )[0]) 111 | 112 | y = x = self.norm2(x) 113 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 114 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 115 | 116 | return self.norm3(x + y) 117 | 118 | 119 | class Decoder(nn.Module): 120 | def __init__(self, layers, norm_layer=None, projection=None): 121 | super(Decoder, self).__init__() 122 | self.layers = nn.ModuleList(layers) 123 | self.norm = norm_layer 124 | self.projection = projection 125 | 126 | def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None): 127 | for layer in self.layers: 128 | x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta) 129 | 130 | if self.norm is not None: 131 | x = self.norm(x) 132 | 133 | if self.projection is not None: 134 | x = self.projection(x) 135 | return x 136 | -------------------------------------------------------------------------------- /models/loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import logging 9 | import torch 10 | import torch.distributed.nn 11 | from torch import distributed as dist, nn as nn 12 | from torch.nn import functional as F 13 | 14 | try: 15 | import horovod.torch as hvd 16 | except ImportError: 17 | hvd = None 18 | 19 | 20 | def gather_features( 21 | image_features, 22 | text_features, 23 | local_loss=False, 24 | gather_with_grad=False, 25 | rank=0, 26 | world_size=1, 27 | use_horovod=False, 28 | ): 29 | if use_horovod: 30 | assert hvd is not None, "Please install horovod" 31 | if gather_with_grad: 32 | all_image_features = hvd.allgather(image_features) 33 | all_text_features = hvd.allgather(text_features) 34 | else: 35 | with torch.no_grad(): 36 | all_image_features = hvd.allgather(image_features) 37 | all_text_features = hvd.allgather(text_features) 38 | if not local_loss: 39 | # ensure grads for local rank when all_* features don't have a gradient 40 | gathered_image_features = list( 41 | all_image_features.chunk(world_size, dim=0) 42 | ) 43 | gathered_text_features = list( 44 | all_text_features.chunk(world_size, dim=0) 45 | ) 46 | gathered_image_features[rank] = image_features 47 | gathered_text_features[rank] = text_features 48 | all_image_features = torch.cat(gathered_image_features, dim=0) 49 | all_text_features = torch.cat(gathered_text_features, dim=0) 50 | else: 51 | # We gather tensors from all gpus 52 | if gather_with_grad: 53 | all_image_features = torch.cat( 54 | torch.distributed.nn.all_gather(image_features), dim=0 55 | ) 56 | all_text_features = torch.cat( 57 | torch.distributed.nn.all_gather(text_features), dim=0 58 | ) 59 | else: 60 | gathered_image_features = [ 61 | torch.zeros_like(image_features) for _ in range(world_size) 62 | ] 63 | gathered_text_features = [ 64 | torch.zeros_like(text_features) for _ in range(world_size) 65 | ] 66 | dist.all_gather(gathered_image_features, image_features) 67 | dist.all_gather(gathered_text_features, text_features) 68 | if not local_loss: 69 | # ensure grads for local rank when all_* features don't have a gradient 70 | gathered_image_features[rank] = image_features 71 | gathered_text_features[rank] = text_features 72 | all_image_features = torch.cat(gathered_image_features, dim=0) 73 | all_text_features = torch.cat(gathered_text_features, dim=0) 74 | 75 | return all_image_features, all_text_features 76 | 77 | 78 | class ClipLoss(nn.Module): 79 | def __init__( 80 | self, 81 | local_loss=False, 82 | gather_with_grad=False, 83 | cache_labels=False, 84 | rank=0, 85 | world_size=1, 86 | use_horovod=False, 87 | ): 88 | super().__init__() 89 | self.local_loss = local_loss 90 | self.gather_with_grad = gather_with_grad 91 | self.cache_labels = cache_labels 92 | self.rank = rank 93 | self.world_size = world_size 94 | self.use_horovod = use_horovod 95 | 96 | # cache state 97 | self.prev_num_logits = 0 98 | self.labels = {} 99 | 100 | def forward(self, image_features, text_features, logit_scale): 101 | device = image_features.device 102 | if self.world_size > 1: 103 | all_image_features, all_text_features = gather_features( 104 | image_features, 105 | text_features, 106 | self.local_loss, 107 | self.gather_with_grad, 108 | self.rank, 109 | self.world_size, 110 | self.use_horovod, 111 | ) 112 | 113 | if self.local_loss: 114 | logits_per_image = logit_scale * image_features @ all_text_features.T 115 | logits_per_text = logit_scale * 
text_features @ all_image_features.T 116 | else: 117 | logits_per_image = ( 118 | logit_scale * all_image_features @ all_text_features.T 119 | ) 120 | logits_per_text = logits_per_image.T 121 | else: 122 | logits_per_image = logit_scale * image_features @ text_features.T 123 | logits_per_text = logit_scale * text_features @ image_features.T 124 | 125 | # calculated ground-truth and cache if enabled 126 | num_logits = logits_per_image.shape[0] 127 | if self.prev_num_logits != num_logits or device not in self.labels: 128 | labels = torch.arange(num_logits, device=device, dtype=torch.long) 129 | if self.world_size > 1 and self.local_loss: 130 | labels = labels + num_logits * self.rank 131 | if self.cache_labels: 132 | self.labels[device] = labels 133 | self.prev_num_logits = num_logits 134 | else: 135 | labels = self.labels[device] 136 | 137 | total_loss = ( 138 | F.cross_entropy(logits_per_image, labels) 139 | + F.cross_entropy(logits_per_text, labels) 140 | ) / 2 141 | return total_loss 142 | -------------------------------------------------------------------------------- /Generation/image_adapter.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from functools import partial\n", 10 | "\n", 11 | "from transformers import CLIPVisionModel \n", 12 | "import torch\n", 13 | "from torch import nn\n", 14 | "from torchvision import transforms\n", 15 | "from PIL import Image\n", 16 | "\n", 17 | "import torch\n", 18 | "import torch.nn as nn\n", 19 | "from transformers import CLIPVisionModel\n", 20 | "from torchvision import transforms\n", 21 | "\n", 22 | "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", 23 | "\n" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "train_pixel_img_feature = torch.load('/root/autodl-tmp/Workspace/EEG_caption/ViT-L-14_features_GIT_train.pt')['img_features']# \n", 33 | "test_pixel_img_feature = torch.load('/root/autodl-tmp/Workspace/EEG_caption/ViT-L-14_features_GIT_test.pt')['img_features']# \n", 34 | "train_img_feature = torch.load('/root/autodl-tmp/Workspace/EEG_caption/ViT-H-14_features_train.pt')['img_features'].unsqueeze(1)# \n", 35 | "test_img_feature = torch.load('/root/autodl-tmp/Workspace/EEG_caption/ViT-H-14_features_test.pt')['img_features'].unsqueeze(1)# \n" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "test_img_feature.shape" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "import torch\n", 54 | "import torch.nn as nn\n", 55 | "import torch.optim as optim\n", 56 | "from torch.utils.data import DataLoader, TensorDataset\n", 57 | "from einops.layers.torch import Rearrange, Reduce\n", 58 | "\n", 59 | "# Define the neural network\n", 60 | "class PixelProjector(nn.Sequential):\n", 61 | " def __init__(self, proj_dim=1024):\n", 62 | " super().__init__(\n", 63 | " Rearrange('B C L->B L C'), \n", 64 | " nn.Linear(1, 257),\n", 65 | " nn.LayerNorm(257),\n", 66 | " Rearrange('B L C->B C L'),\n", 67 | " nn.Linear(1024, 1024),\n", 68 | " nn.LayerNorm(proj_dim),\n", 69 | " )\n", 70 | " \n", 71 | " \n", 72 | "\n", 73 | "# Instantiate the model, loss function, and optimizer\n", 74 | "\n", 75 | "model 
= PixelProjector(proj_dim=1024).to(torch.bfloat16).to(device)\n", 76 | "criterion = nn.MSELoss()\n", 77 | "optimizer = optim.AdamW(model.parameters(), lr=0.001)\n", 78 | "\n", 79 | "# Prepare data loaders\n", 80 | "train_dataset = TensorDataset(train_img_feature, train_pixel_img_feature)\n", 81 | "test_dataset = TensorDataset(test_img_feature, test_pixel_img_feature)\n", 82 | "\n", 83 | "train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)\n", 84 | "test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)\n", 85 | "\n", 86 | "# Training loop\n", 87 | "num_epochs = 30\n", 88 | "for epoch in range(num_epochs):\n", 89 | " model.train()\n", 90 | " running_loss = 0.0\n", 91 | " for inputs, targets in train_loader:\n", 92 | " inputs, targets = inputs.to(torch.bfloat16).to(device), targets.to(torch.bfloat16).to(device)\n", 93 | " optimizer.zero_grad()\n", 94 | " outputs = model(inputs)\n", 95 | " loss = criterion(outputs, targets)\n", 96 | " loss.backward()\n", 97 | " optimizer.step()\n", 98 | " running_loss += loss.item()\n", 99 | " \n", 100 | " print(f\"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader)}\")\n", 101 | "\n", 102 | "# Testing loop\n", 103 | "model.eval()\n", 104 | "test_loss = 0.0\n", 105 | "with torch.no_grad():\n", 106 | " for inputs, targets in test_loader:\n", 107 | " inputs, targets = inputs.to(torch.bfloat16).to(device), targets.to(torch.bfloat16).to(device)\n", 108 | " outputs = model(inputs)\n", 109 | " loss = criterion(outputs, targets)\n", 110 | " test_loss += loss.item()\n", 111 | "\n", 112 | "print(f\"Test Loss: {test_loss/len(test_loader)}\")\n", 113 | "\n", 114 | "# Save the trained model\n", 115 | "torch.save(model.state_dict(), '/root/autodl-tmp/Workspace/EEG_caption/model_weights/PixelProjector_best.bin')\n", 116 | "print(\"Model saved as PixelProjector.bin\")\n" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 8, 122 | "metadata": {}, 123 | "outputs": [ 124 | { 125 | "name": "stdout", 126 | "output_type": "stream", 127 | "text": [ 128 | "Model saved as PixelProjector.bin\n" 129 | ] 130 | } 131 | ], 132 | "source": [ 133 | "# Save the trained model\n", 134 | "torch.save(model.state_dict(), '/root/autodl-tmp/Workspace/EEG_caption/model_weights/PixelProjector_best.bin')\n", 135 | "print(\"Model saved as PixelProjector.bin\")" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [] 144 | } 145 | ], 146 | "metadata": { 147 | "kernelspec": { 148 | "display_name": "BCI", 149 | "language": "python", 150 | "name": "python3" 151 | }, 152 | "language_info": { 153 | "codemirror_mode": { 154 | "name": "ipython", 155 | "version": 3 156 | }, 157 | "file_extension": ".py", 158 | "mimetype": "text/x-python", 159 | "name": "python", 160 | "nbconvert_exporter": "python", 161 | "pygments_lexer": "ipython3", 162 | "version": "3.10.0" 163 | } 164 | }, 165 | "nbformat": 4, 166 | "nbformat_minor": 2 167 | } 168 | -------------------------------------------------------------------------------- /models/subject_layers/AutoCorrelation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import math 7 | from math import sqrt 8 | import os 9 | 10 | 11 | class AutoCorrelation(nn.Module): 12 | """ 13 | AutoCorrelation Mechanism with the following 
two phases: 14 | (1) period-based dependencies discovery 15 | (2) time delay aggregation 16 | This block can replace the self-attention family mechanism seamlessly. 17 | """ 18 | 19 | def __init__(self, mask_flag=True, factor=1, scale=None, attention_dropout=0.1, output_attention=False): 20 | super(AutoCorrelation, self).__init__() 21 | self.factor = factor 22 | self.scale = scale 23 | self.mask_flag = mask_flag 24 | self.output_attention = output_attention 25 | self.dropout = nn.Dropout(attention_dropout) 26 | 27 | def time_delay_agg_training(self, values, corr): 28 | """ 29 | SpeedUp version of Autocorrelation (a batch-normalization style design) 30 | This is for the training phase. 31 | """ 32 | head = values.shape[1] 33 | channel = values.shape[2] 34 | length = values.shape[3] 35 | # find top k 36 | top_k = int(self.factor * math.log(length)) 37 | mean_value = torch.mean(torch.mean(corr, dim=1), dim=1) 38 | index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1] 39 | weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1) 40 | # update corr 41 | tmp_corr = torch.softmax(weights, dim=-1) 42 | # aggregation 43 | tmp_values = values 44 | delays_agg = torch.zeros_like(values).float() 45 | for i in range(top_k): 46 | pattern = torch.roll(tmp_values, -int(index[i]), -1) 47 | delays_agg = delays_agg + pattern * \ 48 | (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)) 49 | return delays_agg 50 | 51 | def time_delay_agg_inference(self, values, corr): 52 | """ 53 | SpeedUp version of Autocorrelation (a batch-normalization style design) 54 | This is for the inference phase. 55 | """ 56 | batch = values.shape[0] 57 | head = values.shape[1] 58 | channel = values.shape[2] 59 | length = values.shape[3] 60 | # index init 61 | init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch, head, channel, 1).cuda() 62 | # find top k 63 | top_k = int(self.factor * math.log(length)) 64 | mean_value = torch.mean(torch.mean(corr, dim=1), dim=1) 65 | weights, delay = torch.topk(mean_value, top_k, dim=-1) 66 | # update corr 67 | tmp_corr = torch.softmax(weights, dim=-1) 68 | # aggregation 69 | tmp_values = values.repeat(1, 1, 1, 2) 70 | delays_agg = torch.zeros_like(values).float() 71 | for i in range(top_k): 72 | tmp_delay = init_index + delay[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length) 73 | pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay) 74 | delays_agg = delays_agg + pattern * \ 75 | (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)) 76 | return delays_agg 77 | 78 | def time_delay_agg_full(self, values, corr): 79 | """ 80 | Standard version of Autocorrelation 81 | """ 82 | batch = values.shape[0] 83 | head = values.shape[1] 84 | channel = values.shape[2] 85 | length = values.shape[3] 86 | # index init 87 | init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch, head, channel, 1).cuda() 88 | # find top k 89 | top_k = int(self.factor * math.log(length)) 90 | weights, delay = torch.topk(corr, top_k, dim=-1) 91 | # update corr 92 | tmp_corr = torch.softmax(weights, dim=-1) 93 | # aggregation 94 | tmp_values = values.repeat(1, 1, 1, 2) 95 | delays_agg = torch.zeros_like(values).float() 96 | for i in range(top_k): 97 | tmp_delay = init_index + delay[..., i].unsqueeze(-1) 98 | pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay) 99 | delays_agg = delays_agg + pattern * (tmp_corr[..., 
i].unsqueeze(-1)) 100 | return delays_agg 101 | 102 | def forward(self, queries, keys, values, attn_mask): 103 | B, L, H, E = queries.shape 104 | _, S, _, D = values.shape 105 | if L > S: 106 | zeros = torch.zeros_like(queries[:, :(L - S), :]).float() 107 | values = torch.cat([values, zeros], dim=1) 108 | keys = torch.cat([keys, zeros], dim=1) 109 | else: 110 | values = values[:, :L, :, :] 111 | keys = keys[:, :L, :, :] 112 | 113 | # period-based dependencies 114 | q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1) 115 | k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1) 116 | res = q_fft * torch.conj(k_fft) 117 | corr = torch.fft.irfft(res, dim=-1) 118 | 119 | # time delay agg 120 | if self.training: 121 | V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2) 122 | else: 123 | V = self.time_delay_agg_inference(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2) 124 | 125 | if self.output_attention: 126 | return (V.contiguous(), corr.permute(0, 3, 1, 2)) 127 | else: 128 | return (V.contiguous(), None) 129 | 130 | 131 | class AutoCorrelationLayer(nn.Module): 132 | def __init__(self, correlation, d_model, n_heads, d_keys=None, 133 | d_values=None): 134 | super(AutoCorrelationLayer, self).__init__() 135 | 136 | d_keys = d_keys or (d_model // n_heads) 137 | d_values = d_values or (d_model // n_heads) 138 | 139 | self.inner_correlation = correlation 140 | self.query_projection = nn.Linear(d_model, d_keys * n_heads) 141 | self.key_projection = nn.Linear(d_model, d_keys * n_heads) 142 | self.value_projection = nn.Linear(d_model, d_values * n_heads) 143 | self.out_projection = nn.Linear(d_values * n_heads, d_model) 144 | self.n_heads = n_heads 145 | 146 | def forward(self, queries, keys, values, attn_mask): 147 | B, L, _ = queries.shape 148 | _, S, _ = keys.shape 149 | H = self.n_heads 150 | 151 | queries = self.query_projection(queries).view(B, L, H, -1) 152 | keys = self.key_projection(keys).view(B, S, H, -1) 153 | values = self.value_projection(values).view(B, S, H, -1) 154 | 155 | out, attn = self.inner_correlation( 156 | queries, 157 | keys, 158 | values, 159 | attn_mask 160 | ) 161 | out = out.view(B, L, -1) 162 | 163 | return self.out_projection(out), attn 164 | -------------------------------------------------------------------------------- /models/subject_layers/Autoformer_EncDec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class my_Layernorm(nn.Module): 7 | """ 8 | Special designed layernorm for the seasonal part 9 | """ 10 | 11 | def __init__(self, channels): 12 | super(my_Layernorm, self).__init__() 13 | self.layernorm = nn.LayerNorm(channels) 14 | 15 | def forward(self, x): 16 | x_hat = self.layernorm(x) 17 | bias = torch.mean(x_hat, dim=1).unsqueeze(1).repeat(1, x.shape[1], 1) 18 | return x_hat - bias 19 | 20 | 21 | class moving_avg(nn.Module): 22 | """ 23 | Moving average block to highlight the trend of time series 24 | """ 25 | 26 | def __init__(self, kernel_size, stride): 27 | super(moving_avg, self).__init__() 28 | self.kernel_size = kernel_size 29 | self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0) 30 | 31 | def forward(self, x): 32 | # padding on the both ends of time series 33 | front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1) 34 | end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1) 35 | x = 
torch.cat([front, x, end], dim=1) 36 | x = self.avg(x.permute(0, 2, 1)) 37 | x = x.permute(0, 2, 1) 38 | return x 39 | 40 | 41 | class series_decomp(nn.Module): 42 | """ 43 | Series decomposition block 44 | """ 45 | 46 | def __init__(self, kernel_size): 47 | super(series_decomp, self).__init__() 48 | self.moving_avg = moving_avg(kernel_size, stride=1) 49 | 50 | def forward(self, x): 51 | moving_mean = self.moving_avg(x) 52 | res = x - moving_mean 53 | return res, moving_mean 54 | 55 | 56 | class series_decomp_multi(nn.Module): 57 | """ 58 | Multiple Series decomposition block from FEDformer 59 | """ 60 | 61 | def __init__(self, kernel_size): 62 | super(series_decomp_multi, self).__init__() 63 | self.kernel_size = kernel_size 64 | self.series_decomp = [series_decomp(kernel) for kernel in kernel_size] 65 | 66 | def forward(self, x): 67 | moving_mean = [] 68 | res = [] 69 | for func in self.series_decomp: 70 | sea, moving_avg = func(x) 71 | moving_mean.append(moving_avg) 72 | res.append(sea) 73 | 74 | sea = sum(res) / len(res) 75 | moving_mean = sum(moving_mean) / len(moving_mean) 76 | return sea, moving_mean 77 | 78 | 79 | class EncoderLayer(nn.Module): 80 | """ 81 | Autoformer encoder layer with the progressive decomposition architecture 82 | """ 83 | 84 | def __init__(self, attention, d_model, d_ff=None, moving_avg=25, dropout=0.1, activation="relu"): 85 | super(EncoderLayer, self).__init__() 86 | d_ff = d_ff or 4 * d_model 87 | self.attention = attention 88 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False) 89 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False) 90 | self.decomp1 = series_decomp(moving_avg) 91 | self.decomp2 = series_decomp(moving_avg) 92 | self.dropout = nn.Dropout(dropout) 93 | self.activation = F.relu if activation == "relu" else F.gelu 94 | 95 | def forward(self, x, attn_mask=None): 96 | new_x, attn = self.attention( 97 | x, x, x, 98 | attn_mask=attn_mask 99 | ) 100 | x = x + self.dropout(new_x) 101 | x, _ = self.decomp1(x) 102 | y = x 103 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 104 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 105 | res, _ = self.decomp2(x + y) 106 | return res, attn 107 | 108 | 109 | class Encoder(nn.Module): 110 | """ 111 | Autoformer encoder 112 | """ 113 | 114 | def __init__(self, attn_layers, conv_layers=None, norm_layer=None): 115 | super(Encoder, self).__init__() 116 | self.attn_layers = nn.ModuleList(attn_layers) 117 | self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None 118 | self.norm = norm_layer 119 | 120 | def forward(self, x, attn_mask=None): 121 | attns = [] 122 | if self.conv_layers is not None: 123 | for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers): 124 | x, attn = attn_layer(x, attn_mask=attn_mask) 125 | x = conv_layer(x) 126 | attns.append(attn) 127 | x, attn = self.attn_layers[-1](x) 128 | attns.append(attn) 129 | else: 130 | for attn_layer in self.attn_layers: 131 | x, attn = attn_layer(x, attn_mask=attn_mask) 132 | attns.append(attn) 133 | 134 | if self.norm is not None: 135 | x = self.norm(x) 136 | 137 | return x, attns 138 | 139 | 140 | class DecoderLayer(nn.Module): 141 | """ 142 | Autoformer decoder layer with the progressive decomposition architecture 143 | """ 144 | 145 | def __init__(self, self_attention, cross_attention, d_model, c_out, d_ff=None, 146 | moving_avg=25, dropout=0.1, activation="relu"): 147 | super(DecoderLayer, self).__init__() 148 | d_ff = d_ff 
or 4 * d_model 149 | self.self_attention = self_attention 150 | self.cross_attention = cross_attention 151 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False) 152 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False) 153 | self.decomp1 = series_decomp(moving_avg) 154 | self.decomp2 = series_decomp(moving_avg) 155 | self.decomp3 = series_decomp(moving_avg) 156 | self.dropout = nn.Dropout(dropout) 157 | self.projection = nn.Conv1d(in_channels=d_model, out_channels=c_out, kernel_size=3, stride=1, padding=1, 158 | padding_mode='circular', bias=False) 159 | self.activation = F.relu if activation == "relu" else F.gelu 160 | 161 | def forward(self, x, cross, x_mask=None, cross_mask=None): 162 | x = x + self.dropout(self.self_attention( 163 | x, x, x, 164 | attn_mask=x_mask 165 | )[0]) 166 | x, trend1 = self.decomp1(x) 167 | x = x + self.dropout(self.cross_attention( 168 | x, cross, cross, 169 | attn_mask=cross_mask 170 | )[0]) 171 | x, trend2 = self.decomp2(x) 172 | y = x 173 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 174 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 175 | x, trend3 = self.decomp3(x + y) 176 | 177 | residual_trend = trend1 + trend2 + trend3 178 | residual_trend = self.projection(residual_trend.permute(0, 2, 1)).transpose(1, 2) 179 | return x, residual_trend 180 | 181 | 182 | class Decoder(nn.Module): 183 | """ 184 | Autoformer encoder 185 | """ 186 | 187 | def __init__(self, layers, norm_layer=None, projection=None): 188 | super(Decoder, self).__init__() 189 | self.layers = nn.ModuleList(layers) 190 | self.norm = norm_layer 191 | self.projection = projection 192 | 193 | def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None): 194 | for layer in self.layers: 195 | x, residual_trend = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask) 196 | trend = trend + residual_trend 197 | 198 | if self.norm is not None: 199 | x = self.norm(x) 200 | 201 | if self.projection is not None: 202 | x = self.projection(x) 203 | return x, trend 204 | -------------------------------------------------------------------------------- /models/subject_layers/FourierCorrelation.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # author=maziqing 3 | # email=maziqing.mzq@alibaba-inc.com 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | def get_frequency_modes(seq_len, modes=64, mode_select_method='random'): 11 | """ 12 | get modes on frequency domain: 13 | 'random' means sampling randomly; 14 | 'else' means sampling the lowest modes; 15 | """ 16 | modes = min(modes, seq_len // 2) 17 | if mode_select_method == 'random': 18 | index = list(range(0, seq_len // 2)) 19 | np.random.shuffle(index) 20 | index = index[:modes] 21 | else: 22 | index = list(range(0, modes)) 23 | index.sort() 24 | return index 25 | 26 | 27 | # ########## fourier layer ############# 28 | class FourierBlock(nn.Module): 29 | def __init__(self, in_channels, out_channels, seq_len, modes=0, mode_select_method='random'): 30 | super(FourierBlock, self).__init__() 31 | print('fourier enhanced block used!') 32 | """ 33 | 1D Fourier block. It performs representation learning on frequency domain, 34 | it does FFT, linear transform, and Inverse FFT. 
35 | """ 36 | # get modes on frequency domain 37 | self.index = get_frequency_modes(seq_len, modes=modes, mode_select_method=mode_select_method) 38 | print('modes={}, index={}'.format(modes, self.index)) 39 | 40 | self.scale = (1 / (in_channels * out_channels)) 41 | self.weights1 = nn.Parameter( 42 | self.scale * torch.rand(8, in_channels // 8, out_channels // 8, len(self.index), dtype=torch.float)) 43 | self.weights2 = nn.Parameter( 44 | self.scale * torch.rand(8, in_channels // 8, out_channels // 8, len(self.index), dtype=torch.float)) 45 | 46 | # Complex multiplication 47 | def compl_mul1d(self, order, x, weights): 48 | x_flag = True 49 | w_flag = True 50 | if not torch.is_complex(x): 51 | x_flag = False 52 | x = torch.complex(x, torch.zeros_like(x).to(x.device)) 53 | if not torch.is_complex(weights): 54 | w_flag = False 55 | weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device)) 56 | if x_flag or w_flag: 57 | return torch.complex(torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag), 58 | torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real)) 59 | else: 60 | return torch.einsum(order, x.real, weights.real) 61 | 62 | def forward(self, q, k, v, mask): 63 | # size = [B, L, H, E] 64 | B, L, H, E = q.shape 65 | x = q.permute(0, 2, 3, 1) 66 | # Compute Fourier coefficients 67 | x_ft = torch.fft.rfft(x, dim=-1) 68 | # Perform Fourier neural operations 69 | out_ft = torch.zeros(B, H, E, L // 2 + 1, device=x.device, dtype=torch.cfloat) 70 | for wi, i in enumerate(self.index): 71 | if i >= x_ft.shape[3] or wi >= out_ft.shape[3]: 72 | continue 73 | out_ft[:, :, :, wi] = self.compl_mul1d("bhi,hio->bho", x_ft[:, :, :, i], 74 | torch.complex(self.weights1, self.weights2)[:, :, :, wi]) 75 | # Return to time domain 76 | x = torch.fft.irfft(out_ft, n=x.size(-1)) 77 | return (x, None) 78 | 79 | 80 | # ########## Fourier Cross Former #################### 81 | class FourierCrossAttention(nn.Module): 82 | def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes=64, mode_select_method='random', 83 | activation='tanh', policy=0, num_heads=8): 84 | super(FourierCrossAttention, self).__init__() 85 | print(' fourier enhanced cross attention used!') 86 | """ 87 | 1D Fourier Cross Attention layer. It does FFT, linear transform, attention mechanism and Inverse FFT. 
88 | """ 89 | self.activation = activation 90 | self.in_channels = in_channels 91 | self.out_channels = out_channels 92 | # get modes for queries and keys (& values) on frequency domain 93 | self.index_q = get_frequency_modes(seq_len_q, modes=modes, mode_select_method=mode_select_method) 94 | self.index_kv = get_frequency_modes(seq_len_kv, modes=modes, mode_select_method=mode_select_method) 95 | 96 | print('modes_q={}, index_q={}'.format(len(self.index_q), self.index_q)) 97 | print('modes_kv={}, index_kv={}'.format(len(self.index_kv), self.index_kv)) 98 | 99 | self.scale = (1 / (in_channels * out_channels)) 100 | self.weights1 = nn.Parameter( 101 | self.scale * torch.rand(num_heads, in_channels // num_heads, out_channels // num_heads, len(self.index_q), dtype=torch.float)) 102 | self.weights2 = nn.Parameter( 103 | self.scale * torch.rand(num_heads, in_channels // num_heads, out_channels // num_heads, len(self.index_q), dtype=torch.float)) 104 | 105 | # Complex multiplication 106 | def compl_mul1d(self, order, x, weights): 107 | x_flag = True 108 | w_flag = True 109 | if not torch.is_complex(x): 110 | x_flag = False 111 | x = torch.complex(x, torch.zeros_like(x).to(x.device)) 112 | if not torch.is_complex(weights): 113 | w_flag = False 114 | weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device)) 115 | if x_flag or w_flag: 116 | return torch.complex(torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag), 117 | torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real)) 118 | else: 119 | return torch.einsum(order, x.real, weights.real) 120 | 121 | def forward(self, q, k, v, mask): 122 | # size = [B, L, H, E] 123 | B, L, H, E = q.shape 124 | xq = q.permute(0, 2, 3, 1) # size = [B, H, E, L] 125 | xk = k.permute(0, 2, 3, 1) 126 | xv = v.permute(0, 2, 3, 1) 127 | 128 | # Compute Fourier coefficients 129 | xq_ft_ = torch.zeros(B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat) 130 | xq_ft = torch.fft.rfft(xq, dim=-1) 131 | for i, j in enumerate(self.index_q): 132 | if j >= xq_ft.shape[3]: 133 | continue 134 | xq_ft_[:, :, :, i] = xq_ft[:, :, :, j] 135 | xk_ft_ = torch.zeros(B, H, E, len(self.index_kv), device=xq.device, dtype=torch.cfloat) 136 | xk_ft = torch.fft.rfft(xk, dim=-1) 137 | for i, j in enumerate(self.index_kv): 138 | if j >= xk_ft.shape[3]: 139 | continue 140 | xk_ft_[:, :, :, i] = xk_ft[:, :, :, j] 141 | 142 | # perform attention mechanism on frequency domain 143 | xqk_ft = (self.compl_mul1d("bhex,bhey->bhxy", xq_ft_, xk_ft_)) 144 | if self.activation == 'tanh': 145 | xqk_ft = torch.complex(xqk_ft.real.tanh(), xqk_ft.imag.tanh()) 146 | elif self.activation == 'softmax': 147 | xqk_ft = torch.softmax(abs(xqk_ft), dim=-1) 148 | xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft)) 149 | else: 150 | raise Exception('{} actiation function is not implemented'.format(self.activation)) 151 | xqkv_ft = self.compl_mul1d("bhxy,bhey->bhex", xqk_ft, xk_ft_) 152 | xqkvw = self.compl_mul1d("bhex,heox->bhox", xqkv_ft, torch.complex(self.weights1, self.weights2)) 153 | out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat) 154 | for i, j in enumerate(self.index_q): 155 | if i >= xqkvw.shape[3] or j >= out_ft.shape[3]: 156 | continue 157 | out_ft[:, :, :, j] = xqkvw[:, :, :, i] 158 | # Return to time domain 159 | out = torch.fft.irfft(out_ft / self.in_channels / self.out_channels, n=xq.size(-1)) 160 | return (out, None) 161 | 
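# --- Illustrative usage sketch (added in this write-up; not part of the original file) ---
if __name__ == "__main__":
    # Hypothetical sizes: batch B=2, sequence length L=96, H=8 heads, E=64 channels per head,
    # so that H * E == in_channels == out_channels, as the weight shapes above require.
    block = FourierBlock(in_channels=512, out_channels=512, seq_len=96, modes=32)
    q = torch.randn(2, 96, 8, 64)        # queries shaped [B, L, H, E]
    out, _ = block(q, q, q, mask=None)   # FFT -> learned linear map on the selected modes -> inverse FFT
    print(out.shape)                     # torch.Size([2, 8, 64, 96]), i.e. [B, H, E, L]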
-------------------------------------------------------------------------------- /models/subject_layers/Pyraformer_EncDec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.modules.linear import Linear 5 | from layers.SelfAttention_Family import AttentionLayer, FullAttention 6 | from layers.Embed import DataEmbedding 7 | import math 8 | 9 | 10 | def get_mask(input_size, window_size, inner_size): 11 | """Get the attention mask of PAM-Naive""" 12 | # Get the size of all layers 13 | all_size = [] 14 | all_size.append(input_size) 15 | for i in range(len(window_size)): 16 | layer_size = math.floor(all_size[i] / window_size[i]) 17 | all_size.append(layer_size) 18 | 19 | seq_length = sum(all_size) 20 | mask = torch.zeros(seq_length, seq_length) 21 | 22 | # get intra-scale mask 23 | inner_window = inner_size // 2 24 | for layer_idx in range(len(all_size)): 25 | start = sum(all_size[:layer_idx]) 26 | for i in range(start, start + all_size[layer_idx]): 27 | left_side = max(i - inner_window, start) 28 | right_side = min(i + inner_window + 1, start + all_size[layer_idx]) 29 | mask[i, left_side:right_side] = 1 30 | 31 | # get inter-scale mask 32 | for layer_idx in range(1, len(all_size)): 33 | start = sum(all_size[:layer_idx]) 34 | for i in range(start, start + all_size[layer_idx]): 35 | left_side = (start - all_size[layer_idx - 1]) + \ 36 | (i - start) * window_size[layer_idx - 1] 37 | if i == (start + all_size[layer_idx] - 1): 38 | right_side = start 39 | else: 40 | right_side = ( 41 | start - all_size[layer_idx - 1]) + (i - start + 1) * window_size[layer_idx - 1] 42 | mask[i, left_side:right_side] = 1 43 | mask[left_side:right_side, i] = 1 44 | 45 | mask = (1 - mask).bool() 46 | 47 | return mask, all_size 48 | 49 | 50 | def refer_points(all_sizes, window_size): 51 | """Gather features from PAM's pyramid sequences""" 52 | input_size = all_sizes[0] 53 | indexes = torch.zeros(input_size, len(all_sizes)) 54 | 55 | for i in range(input_size): 56 | indexes[i][0] = i 57 | former_index = i 58 | for j in range(1, len(all_sizes)): 59 | start = sum(all_sizes[:j]) 60 | inner_layer_idx = former_index - (start - all_sizes[j - 1]) 61 | former_index = start + \ 62 | min(inner_layer_idx // window_size[j - 1], all_sizes[j] - 1) 63 | indexes[i][j] = former_index 64 | 65 | indexes = indexes.unsqueeze(0).unsqueeze(3) 66 | 67 | return indexes.long() 68 | 69 | 70 | class RegularMask(): 71 | def __init__(self, mask): 72 | self._mask = mask.unsqueeze(1) 73 | 74 | @property 75 | def mask(self): 76 | return self._mask 77 | 78 | 79 | class EncoderLayer(nn.Module): 80 | """ Compose with two layers """ 81 | 82 | def __init__(self, d_model, d_inner, n_head, dropout=0.1, normalize_before=True): 83 | super(EncoderLayer, self).__init__() 84 | 85 | self.slf_attn = AttentionLayer( 86 | FullAttention(mask_flag=True, factor=0, 87 | attention_dropout=dropout, output_attention=False), 88 | d_model, n_head) 89 | self.pos_ffn = PositionwiseFeedForward( 90 | d_model, d_inner, dropout=dropout, normalize_before=normalize_before) 91 | 92 | def forward(self, enc_input, slf_attn_mask=None): 93 | attn_mask = RegularMask(slf_attn_mask) 94 | enc_output, _ = self.slf_attn( 95 | enc_input, enc_input, enc_input, attn_mask=attn_mask) 96 | enc_output = self.pos_ffn(enc_output) 97 | return enc_output 98 | 99 | 100 | class Encoder(nn.Module): 101 | """ A encoder model with self attention mechanism. 
""" 102 | 103 | def __init__(self, configs, window_size, inner_size): 104 | super().__init__() 105 | 106 | d_bottleneck = configs.d_model//4 107 | 108 | self.mask, self.all_size = get_mask( 109 | configs.seq_len, window_size, inner_size) 110 | self.indexes = refer_points(self.all_size, window_size) 111 | self.layers = nn.ModuleList([ 112 | EncoderLayer(configs.d_model, configs.d_ff, configs.n_heads, dropout=configs.dropout, 113 | normalize_before=False) for _ in range(configs.e_layers) 114 | ]) # naive pyramid attention 115 | 116 | self.enc_embedding = DataEmbedding( 117 | configs.enc_in, configs.d_model, configs.dropout) 118 | self.conv_layers = Bottleneck_Construct( 119 | configs.d_model, window_size, d_bottleneck) 120 | 121 | def forward(self, x_enc, x_mark_enc): 122 | seq_enc = self.enc_embedding(x_enc, x_mark_enc) 123 | 124 | mask = self.mask.repeat(len(seq_enc), 1, 1).to(x_enc.device) 125 | seq_enc = self.conv_layers(seq_enc) 126 | 127 | for i in range(len(self.layers)): 128 | seq_enc = self.layers[i](seq_enc, mask) 129 | 130 | indexes = self.indexes.repeat(seq_enc.size( 131 | 0), 1, 1, seq_enc.size(2)).to(seq_enc.device) 132 | indexes = indexes.view(seq_enc.size(0), -1, seq_enc.size(2)) 133 | all_enc = torch.gather(seq_enc, 1, indexes) 134 | seq_enc = all_enc.view(seq_enc.size(0), self.all_size[0], -1) 135 | 136 | return seq_enc 137 | 138 | 139 | class ConvLayer(nn.Module): 140 | def __init__(self, c_in, window_size): 141 | super(ConvLayer, self).__init__() 142 | self.downConv = nn.Conv1d(in_channels=c_in, 143 | out_channels=c_in, 144 | kernel_size=window_size, 145 | stride=window_size) 146 | self.norm = nn.BatchNorm1d(c_in) 147 | self.activation = nn.ELU() 148 | 149 | def forward(self, x): 150 | x = self.downConv(x) 151 | x = self.norm(x) 152 | x = self.activation(x) 153 | return x 154 | 155 | 156 | class Bottleneck_Construct(nn.Module): 157 | """Bottleneck convolution CSCM""" 158 | 159 | def __init__(self, d_model, window_size, d_inner): 160 | super(Bottleneck_Construct, self).__init__() 161 | if not isinstance(window_size, list): 162 | self.conv_layers = nn.ModuleList([ 163 | ConvLayer(d_inner, window_size), 164 | ConvLayer(d_inner, window_size), 165 | ConvLayer(d_inner, window_size) 166 | ]) 167 | else: 168 | self.conv_layers = [] 169 | for i in range(len(window_size)): 170 | self.conv_layers.append(ConvLayer(d_inner, window_size[i])) 171 | self.conv_layers = nn.ModuleList(self.conv_layers) 172 | self.up = Linear(d_inner, d_model) 173 | self.down = Linear(d_model, d_inner) 174 | self.norm = nn.LayerNorm(d_model) 175 | 176 | def forward(self, enc_input): 177 | temp_input = self.down(enc_input).permute(0, 2, 1) 178 | all_inputs = [] 179 | for i in range(len(self.conv_layers)): 180 | temp_input = self.conv_layers[i](temp_input) 181 | all_inputs.append(temp_input) 182 | 183 | all_inputs = torch.cat(all_inputs, dim=2).transpose(1, 2) 184 | all_inputs = self.up(all_inputs) 185 | all_inputs = torch.cat([enc_input, all_inputs], dim=1) 186 | 187 | all_inputs = self.norm(all_inputs) 188 | return all_inputs 189 | 190 | 191 | class PositionwiseFeedForward(nn.Module): 192 | """ Two-layer position-wise feed-forward neural network. 
""" 193 | 194 | def __init__(self, d_in, d_hid, dropout=0.1, normalize_before=True): 195 | super().__init__() 196 | 197 | self.normalize_before = normalize_before 198 | 199 | self.w_1 = nn.Linear(d_in, d_hid) 200 | self.w_2 = nn.Linear(d_hid, d_in) 201 | 202 | self.layer_norm = nn.LayerNorm(d_in, eps=1e-6) 203 | self.dropout = nn.Dropout(dropout) 204 | 205 | def forward(self, x): 206 | residual = x 207 | if self.normalize_before: 208 | x = self.layer_norm(x) 209 | 210 | x = F.gelu(self.w_1(x)) 211 | x = self.dropout(x) 212 | x = self.w_2(x) 213 | x = self.dropout(x) 214 | x = x + residual 215 | 216 | if not self.normalize_before: 217 | x = self.layer_norm(x) 218 | return x 219 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |

Visual Decoding and Reconstruction via EEG Embeddings with Guided Diffusion

4 | 5 | 6 |
7 | 8 |

9 | 10 |

11 | 12 | 13 |

14 |

15 | 16 | 17 |
18 | 19 |
20 | 21 |
22 | 23 | 25 | 27 | 28 | 29 | 30 | 33 | 34 | 35 | Framework 36 | Framework of our proposed method. 37 | 38 | 39 | 40 | 41 | 42 | fig-genexample 43 | 44 | Some examples of using EEG to reconstruct stimulus images. 45 | 46 | 47 | ## News: 48 | - [2024/09/26] Our paper is accepted to **NeurIPS 2024**. 49 | - [2024/09/25] We have updated the [arxiv](https://arxiv.org/abs/2403.07721) paper. 50 | - [2024/08/01] Update scripts for training and inference in different tasks. 51 | - [2024/05/19] Update the dataset loading scripts. 52 | - [2024/03/12] The [arxiv](https://arxiv.org/abs/2403.07721) paper is available. 53 | 54 | 55 | 56 |

Environment setup

57 | 58 | ### Option 1: Using setup.sh (Recommended) 59 | 60 | Run the setup script to create a conda environment with all dependencies: 61 | 62 | ```bash 63 | . setup.sh 64 | conda activate BCI 65 | ``` 66 | 67 | ### Option 2: Using environment.yml 68 | 69 | ```bash 70 | conda env create -f environment.yml 71 | conda activate BCI 72 | ``` 73 | 74 | ### Option 3: Using requirements.txt 75 | 76 | ```bash 77 | conda create -n BCI python=3.12 -y 78 | conda activate BCI 79 | pip install torch==2.5.0 torchvision==0.20.0 torchaudio==2.5.0 --index-url https://download.pytorch.org/whl/cu124 80 | pip install -r requirements.txt 81 | ``` 82 | 83 | 84 | 85 |
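As an optional sanity check (not part of the original instructions), you can confirm that the core dependencies import and that PyTorch sees your GPU from a Python shell inside the `BCI` environment:

```python
# Optional post-install check; all of these packages are listed in requirements.txt.
import torch, torchvision, clip, mne, braindecode
print(torch.__version__, torch.cuda.is_available())
```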

Quick training and test

86 | 87 | If you want to quickly reproduce the results in the paper, please download the relevant ``preprocessed data`` and ``model weights`` from [Hugging Face](https://huggingface.co/datasets/LidongYang/EEG_Image_decode) first. 88 | #### 1. Image Retrieval 89 | We provide the script for training the EEG encoder and evaluating it during training. In this task, we use the **normalized CLIP embedding** to train the EEG encoder. Please modify your dataset path and run: 90 | ``` 91 | cd Retrieval/ 92 | python ATMS_retrieval.py --logger True --gpu cuda:0 --output_dir ./outputs/contrast 93 | ``` 94 | We also provide the script for ``joint subject training``, which aims to train on all subjects jointly and test on a specific subject: 95 | ``` 96 | cd Retrieval/ 97 | python ATMS_retrieval_joint_train.py --joint_train --sub sub-01 True --logger True --gpu cuda:0 --output_dir ./outputs/contrast 98 | ``` 99 | 100 | Additionally, you can replicate the results of other methods (e.g. EEGNetV4) by running: 101 | ``` 102 | cd Retrieval/ 103 | python contrast_retrieval.py --encoder_type EEGNetv4_Encoder --epochs 30 --batch_size 1024 104 | ``` 105 | 106 | #### 2. Image Reconstruction 107 | We provide quick training and inference scripts for the ``clip pipeline`` of visual reconstruction. In this task, we use **the original CLIP embedding** to train the EEG encoder. Please modify your dataset path and run zero-shot evaluation on the 200-class test dataset: 108 | ``` 109 | # Train and generate EEG features for Subject 8 110 | cd Generation/ 111 | python ATMS_reconstruction.py --insubject True --subjects sub-08 --logger True \ 112 | --gpu cuda:0 --output_dir ./outputs/contrast 113 | ``` 114 | 115 | ``` 116 | # Reconstruct images for Subject 8 117 | Generation_metrics_sub8.ipynb 118 | ``` 119 | 120 | We also provide scripts for image reconstruction combined with the ``low-level pipeline``. 121 | ``` 122 | cd Generation/ 123 | 124 | # step 1: train the VAE encoder and then generate low-level images 125 | train_vae_latent_512_low_level_no_average.py 126 | 127 | # step 2: load the low-level images and then reconstruct them 128 | 1x1024_reconstruct_sdxl.ipynb 129 | ``` 130 | 131 | 132 | We provide scripts for caption generation combined with the ``semantic-level pipeline``. 133 | ``` 134 | cd Generation/ 135 | 136 | # step 1: train the feature adapter 137 | image_adapter.ipynb 138 | 139 | # step 2: get captions from the EEG latent 140 | GIT_caption_batch.ipynb 141 | 142 | # step 3: load the text prompts and then reconstruct images 143 | 1x1024_reconstruct_sdxl.ipynb 144 | ``` 145 | 146 | To evaluate the quality of the reconstructed images, modify the paths of the reconstructed images and the original stimulus images in the notebook and run: 147 | ``` 148 | # compute metrics, adapted from MindEye 149 | Reconstruction_Metrics_ATM.ipynb 150 | ``` 151 | 152 | 153 |
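Note that the retrieval and reconstruction pipelines above differ in how the CLIP image embeddings are prepared (**normalized** vs. **original**). The snippet below only illustrates that distinction with placeholder tensors; it is not the repository's actual training code, and the embedding dimension is assumed:

```python
import torch
import torch.nn.functional as F

clip_img_features = torch.randn(16, 1024)                  # placeholder CLIP image embeddings
retrieval_target = F.normalize(clip_img_features, dim=-1)  # normalized CLIP embeddings (retrieval)
reconstruction_target = clip_img_features                  # original CLIP embeddings (reconstruction)
```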

Data availability

154 | 155 | We provide the ``preprocessed EEG`` and ``preprocessed MEG`` data used in our paper, as well as the raw image data, at [Hugging Face](https://huggingface.co/datasets/LidongYang/EEG_Image_decode). 156 | 157 | 158 | Note that the experimental paradigms of the THINGS-EEG and THINGS-MEG datasets themselves are different, so we provide the images and data for the two datasets separately. 159 | 160 | You can also download the original THINGS-EEG and THINGS-MEG datasets from their public repositories, linked below. 161 | 162 | The raw and preprocessed EEG dataset, together with the training and test images, is available on [OSF](https://osf.io/3jk45/). 163 | - ``Raw EEG data:`` `../project_directory/eeg_dataset/raw_data/`. 164 | - ``Preprocessed EEG data:`` `../project_directory/eeg_dataset/preprocessed_data/`. 165 | - ``Training and test images:`` `../project_directory/image_set/`. 166 | 167 | 168 | The raw and preprocessed MEG dataset, together with the training and test images, is available on [OpenNeuro](https://openneuro.org/datasets/ds004212/versions/2.0.0). 169 | 170 | 171 | 172 | 173 | 174 | 175 |
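If you want to inspect the preprocessed EEG files directly, the sketch below shows one way to load them. It assumes the pickle-based layout written by `EEG-preprocessing/preprocessing_utils.py` (`save_prepr`); verify that your downloaded copy follows the same format:

```python
import pickle

# Adjust the path to your own data_path (see Retrieval/data_config.json); sub-08 is just an example.
path = "/path/to/Preprocessed_data_250Hz/sub-08/preprocessed_eeg_training.npy"
with open(path, "rb") as f:
    data = pickle.load(f)  # dict with 'preprocessed_eeg_data', 'ch_names', 'times'

eeg = data["preprocessed_eeg_data"]  # image conditions x repetitions x channels x time points
print(eeg.shape, len(data["ch_names"]))
```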

EEG/MEG preprocessing

176 | 177 | 178 | Modify your paths and execute the following commands to apply the same preprocessing to the raw data as in our experiments (a minimal sketch of the underlying call sequence is shown after the TODO list below): 179 | ``` 180 | cd EEG-preprocessing/ 181 | python preprocessing.py 182 | ``` 183 | 184 | ``` 185 | cd MEG-preprocessing/ 186 | pre_possess.ipynb 187 | ``` 188 | You can also get the dataset used in this project through the BaiduNetDisk [link](https://pan.baidu.com/s/1-1hgpoi4nereLVqE4ylE_g?pwd=nid5) to run the code. 189 | 190 | ## TODO 191 | - [√] Release retrieval and reconstruction scripts. 192 | - [√] Update training scripts of the reconstruction pipeline. 193 | - [ ] Add validation sets to improve the accuracy of performance evaluation. 194 | 195 | 196 | 197 | 198 |
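For reference, the sketch below pieces together the call sequence behind `preprocessing.py` from the function signatures in `EEG-preprocessing/preprocessing_utils.py`. The `Namespace` fields and their values are assumptions for illustration; check `preprocessing.py` for the actual command-line arguments:

```python
from argparse import Namespace
from preprocessing_utils import epoching, mvnn, save_prepr

# Hypothetical arguments; the real preprocessing.py parses these from the command line.
args = Namespace(sub=8, n_ses=4, sfreq=250, mvnn_dim="epochs",
                 project_dir="../project_directory/eeg_dataset")
seed = 20200220

epoched_test, _, ch_names, times = epoching(args, "test", seed)
epoched_train, img_cond_train, ch_names, times = epoching(args, "training", seed)
whitened_test, whitened_train = mvnn(args, epoched_test, epoched_train)
save_prepr(args, whitened_test, whitened_train, img_cond_train, ch_names, times, seed)
```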

Acknowledge

199 | 200 | 1. Thanks to Y. Song et al. for their contributions to dataset preprocessing and the neural network architecture; we refer to their work:
"[Decoding Natural Images from EEG for Object Recognition](https://arxiv.org/pdf/2308.13234.pdf)".
Yonghao Song, Bingchuan Liu, Xiang Li, Nanlin Shi, Yijun Wang, and Xiaorong Gao. 201 | 202 | 2. We also thank the authors of [SDRecon](https://github.com/yu-takagi/StableDiffusionReconstruction) for providing their code and results. Some parts of the training scripts are based on [MindEye](https://medarc-ai.github.io/mindeye/) and [MindEye2](https://github.com/MedARC-AI/MindEyeV2). Thanks for these awesome research works. 203 | 204 | 3. The THINGS-EEG dataset used in our paper is described in:
"[A large and rich EEG dataset for modeling human visual object recognition](https://www.sciencedirect.com/science/article/pii/S1053811922008758?via%3Dihub)".
205 | Alessandro T. Gifford, Kshitij Dwivedi, Gemma Roig, Radoslaw M. Cichy. 206 | 207 | 208 | 4. The THINGS-MEG dataset, which we also used, is described in:
"[THINGS-data, a multimodal collection of large-scale datasets for investigating object representations in human brain and behavior.](https://elifesciences.org/articles/82580.pdf)".
Hebart, Martin N., Oliver Contier, Lina Teichmann, Adam H. Rockter, Charles Y. Zheng, Alexis Kidder, Anna Corriveau, Maryam Vaziri-Pashkam, and Chris I. Baker. 209 | 210 | 211 | 212 | 213 |

Citation

214 | 215 | ```bibtex 216 | @inproceedings{li2024visual, 217 | author = {Li, Dongyang and Wei, Chen and Li, Shiying and Zou, Jiachen and Liu, Quanying}, 218 | booktitle = {Advances in Neural Information Processing Systems}, 219 | editor = {A. Globerson and L. Mackey and D. Belgrave and A. Fan and U. Paquet and J. Tomczak and C. Zhang}, 220 | pages = {102822--102864}, 221 | publisher = {Curran Associates, Inc.}, 222 | title = {Visual Decoding and Reconstruction via EEG Embeddings with Guided Diffusion}, 223 | url = {https://proceedings.neurips.cc/paper_files/paper/2024/file/ba5f1233efa77787ff9ec015877dbd1f-Paper-Conference.pdf}, 224 | volume = {37}, 225 | year = {2024} 226 | } 227 | 228 | 229 | @article{li2024visual, 230 | title={Visual Decoding and Reconstruction via EEG Embeddings with Guided Diffusion}, 231 | author={Li, Dongyang and Wei, Chen and Li, Shiying and Zou, Jiachen and Liu, Quanying}, 232 | journal={arXiv preprint arXiv:2403.07721}, 233 | year={2024} 234 | } 235 | ``` 236 | -------------------------------------------------------------------------------- /models/subject_layers/Embed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.utils import weight_norm 5 | import math 6 | 7 | 8 | class PositionalEmbedding(nn.Module): 9 | def __init__(self, d_model, max_len=5000): 10 | super(PositionalEmbedding, self).__init__() 11 | # Compute the positional encodings once in log space. 12 | pe = torch.zeros(max_len, d_model).float() 13 | pe.require_grad = False 14 | 15 | position = torch.arange(0, max_len).float().unsqueeze(1) 16 | div_term = (torch.arange(0, d_model, 2).float() 17 | * -(math.log(10000.0) / d_model)).exp() 18 | 19 | pe[:, 0::2] = torch.sin(position * div_term) 20 | pe[:, 1::2] = torch.cos(position * div_term) 21 | 22 | pe = pe.unsqueeze(0) 23 | self.register_buffer('pe', pe) 24 | 25 | def forward(self, x): 26 | return self.pe[:, :x.size(1)] 27 | 28 | 29 | class TokenEmbedding(nn.Module): 30 | def __init__(self, c_in, d_model): 31 | super(TokenEmbedding, self).__init__() 32 | padding = 1 if torch.__version__ >= '1.5.0' else 2 33 | self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model, 34 | kernel_size=3, padding=padding, padding_mode='circular', bias=False) 35 | for m in self.modules(): 36 | if isinstance(m, nn.Conv1d): 37 | nn.init.kaiming_normal_( 38 | m.weight, mode='fan_in', nonlinearity='leaky_relu') 39 | 40 | def forward(self, x): 41 | x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2) 42 | return x 43 | 44 | 45 | class FixedEmbedding(nn.Module): 46 | def __init__(self, c_in, d_model): 47 | super(FixedEmbedding, self).__init__() 48 | 49 | w = torch.zeros(c_in, d_model).float() 50 | w.require_grad = False 51 | 52 | position = torch.arange(0, c_in).float().unsqueeze(1) 53 | div_term = (torch.arange(0, d_model, 2).float() 54 | * -(math.log(10000.0) / d_model)).exp() 55 | 56 | w[:, 0::2] = torch.sin(position * div_term) 57 | w[:, 1::2] = torch.cos(position * div_term) 58 | 59 | self.emb = nn.Embedding(c_in, d_model) 60 | self.emb.weight = nn.Parameter(w, requires_grad=False) 61 | 62 | def forward(self, x): 63 | return self.emb(x).detach() 64 | 65 | 66 | class TemporalEmbedding(nn.Module): 67 | def __init__(self, d_model, embed_type='fixed', freq='h'): 68 | super(TemporalEmbedding, self).__init__() 69 | 70 | minute_size = 4 71 | hour_size = 24 72 | weekday_size = 7 73 | day_size = 32 74 | month_size = 13 75 | 76 | Embed = 
FixedEmbedding if embed_type == 'fixed' else nn.Embedding 77 | if freq == 't': 78 | self.minute_embed = Embed(minute_size, d_model) 79 | self.hour_embed = Embed(hour_size, d_model) 80 | self.weekday_embed = Embed(weekday_size, d_model) 81 | self.day_embed = Embed(day_size, d_model) 82 | self.month_embed = Embed(month_size, d_model) 83 | 84 | def forward(self, x): 85 | x = x.long() 86 | minute_x = self.minute_embed(x[:, :, 4]) if hasattr( 87 | self, 'minute_embed') else 0. 88 | hour_x = self.hour_embed(x[:, :, 3]) 89 | weekday_x = self.weekday_embed(x[:, :, 2]) 90 | day_x = self.day_embed(x[:, :, 1]) 91 | month_x = self.month_embed(x[:, :, 0]) 92 | 93 | return hour_x + weekday_x + day_x + month_x + minute_x 94 | 95 | 96 | class TimeFeatureEmbedding(nn.Module): 97 | def __init__(self, d_model, embed_type='timeF', freq='h'): 98 | super(TimeFeatureEmbedding, self).__init__() 99 | 100 | freq_map = {'h': 4, 't': 5, 's': 6, 101 | 'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3} 102 | d_inp = freq_map[freq] 103 | self.embed = nn.Linear(d_inp, d_model, bias=False) 104 | 105 | def forward(self, x): 106 | return self.embed(x) 107 | 108 | 109 | class SubjectEmbedding(nn.Module): 110 | def __init__(self, num_subjects, d_model): 111 | super(SubjectEmbedding, self).__init__() 112 | self.subject_embedding = nn.Embedding(num_subjects, d_model) 113 | self.shared_embedding = nn.Parameter(torch.randn(1, d_model)) # Shared token for unknown subjects 114 | self.mask_embedding = nn.Parameter(torch.randn(1, d_model)) # Mask token embedding 115 | 116 | def forward(self, subject_ids): 117 | if subject_ids[0] is None or torch.any(subject_ids >= self.subject_embedding.num_embeddings): 118 | batch_size = subject_ids.size(0) 119 | return self.shared_embedding.expand(batch_size, 1, -1) 120 | else: 121 | return self.subject_embedding(subject_ids).unsqueeze(1) 122 | 123 | 124 | # class DataEmbedding(nn.Module): 125 | # def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1, num_subjects=None): 126 | # super(DataEmbedding, self).__init__() 127 | # self.value_embedding = nn.Linear(c_in, d_model) 128 | # self.position_embedding = PositionalEmbedding(d_model=d_model) 129 | # self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) 130 | # self.dropout = nn.Dropout(p=dropout) 131 | # self.subject_embedding = SubjectEmbedding(num_subjects, d_model) if num_subjects is not None else None 132 | # self.mask_token = nn.Parameter(torch.randn(1, d_model)) # Mask token embedding 133 | 134 | # def forward(self, x, x_mark, subject_ids=None, mask=None): 135 | # if x_mark is None: 136 | # x = self.value_embedding(x) 137 | # else: 138 | # x = self.value_embedding(x) + self.temporal_embedding(x_mark) + self.position_embedding(x) 139 | 140 | # if mask is not None: 141 | # x = x * (~mask.bool()) + self.mask_token * mask.float() 142 | 143 | # if self.subject_embedding is not None: 144 | # subject_emb = self.subject_embedding(subject_ids) # (batch_size, 1, d_model) 145 | # x = torch.cat([subject_emb, x], dim=1) # Concatenate along sequence dimension (batch_size, seq_len + 1, d_model) 146 | 147 | # return self.dropout(x) 148 | 149 | class DataEmbedding(nn.Module): 150 | def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1, joint_train=False, num_subjects=None): 151 | super(DataEmbedding, self).__init__() 152 | if joint_train and num_subjects is not None: 153 | 
self.value_embedding = nn.ModuleDict({ 154 | str(subject_id): nn.Linear(c_in, d_model) for subject_id in range(num_subjects) 155 | }) 156 | else: 157 | self.value_embedding = nn.Linear(c_in, d_model) # If no subjects specified, use a single value embedding 158 | 159 | self.position_embedding = PositionalEmbedding(d_model=d_model) 160 | self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) 161 | self.dropout = nn.Dropout(p=dropout) 162 | self.subject_embedding = SubjectEmbedding(num_subjects, d_model) if num_subjects is not None else None 163 | self.mask_token = nn.Parameter(torch.randn(1, d_model)) # Mask token embedding 164 | self.joint_train = joint_train 165 | 166 | def forward(self, x, x_mark, subject_ids=None, mask=None): 167 | if self.joint_train: 168 | # Use subject-specific value embedding for each subject 169 | x = torch.stack([self.value_embedding[str(subject_id.item())](x[i]) for i, subject_id in enumerate(subject_ids)]) 170 | else: 171 | x = self.value_embedding(x) 172 | 173 | if x_mark is not None: 174 | x = x + self.temporal_embedding(x_mark) + self.position_embedding(x) 175 | 176 | if mask is not None: 177 | x = x * (~mask.bool()) + self.mask_token * mask.float() 178 | 179 | if self.subject_embedding is not None: 180 | subject_emb = self.subject_embedding(subject_ids) # (batch_size, 1, d_model) 181 | x = torch.cat([subject_emb, x], dim=1) # Concatenate along sequence dimension (batch_size, seq_len + 1, d_model) 182 | 183 | return self.dropout(x) 184 | 185 | 186 | 187 | 188 | class DataEmbedding_inverted(nn.Module): 189 | def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1): 190 | super(DataEmbedding_inverted, self).__init__() 191 | self.value_embedding = nn.Linear(c_in, d_model) 192 | self.dropout = nn.Dropout(p=dropout) 193 | 194 | def forward(self, x, x_mark): 195 | x = x.permute(0, 2, 1) 196 | # x: [Batch Variate Time] 197 | if x_mark is None: 198 | x = self.value_embedding(x) 199 | else: 200 | x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1)) 201 | # x: [Batch Variate d_model] 202 | return self.dropout(x) 203 | 204 | 205 | class DataEmbedding_wo_pos(nn.Module): 206 | def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1): 207 | super(DataEmbedding_wo_pos, self).__init__() 208 | 209 | self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) 210 | self.position_embedding = PositionalEmbedding(d_model=d_model) 211 | self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, 212 | freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding( 213 | d_model=d_model, embed_type=embed_type, freq=freq) 214 | self.dropout = nn.Dropout(p=dropout) 215 | 216 | def forward(self, x, x_mark): 217 | if x_mark is None: 218 | x = self.value_embedding(x) + self.position_embedding(x) 219 | else: 220 | x = self.value_embedding(x) + self.temporal_embedding(x_mark) 221 | return self.dropout(x) 222 | 223 | 224 | class PatchEmbedding(nn.Module): 225 | def __init__(self, d_model, patch_len, stride, padding, dropout): 226 | super(PatchEmbedding, self).__init__() 227 | # Patching 228 | self.patch_len = patch_len 229 | self.stride = stride 230 | self.padding_patch_layer = nn.ReplicationPad1d((0, padding)) 231 | 232 | # Backbone, Input encoding: projection of feature vectors onto a d-dim vector space 233 | self.value_embedding = nn.Linear(patch_len, d_model, 
bias=False) 234 | 235 | # Positional embedding 236 | self.position_embedding = PositionalEmbedding(d_model) 237 | 238 | # Residual dropout 239 | self.dropout = nn.Dropout(dropout) 240 | 241 | def forward(self, x): 242 | # do patching 243 | n_vars = x.shape[1] 244 | x = self.padding_patch_layer(x) 245 | x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride) 246 | x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3])) 247 | # Input encoding 248 | x = self.value_embedding(x) + self.position_embedding(x) 249 | return self.dropout(x), n_vars 250 | -------------------------------------------------------------------------------- /EEG-preprocessing/preprocessing_utils.py: -------------------------------------------------------------------------------- 1 | def epoching(args, data_part, seed): 2 | """This function first converts the EEG data to MNE raw format, and 3 | performs channel selection, epoching, baseline correction and frequency 4 | downsampling. Then, it sorts the EEG data of each session according to the 5 | image conditions. 6 | 7 | Parameters 8 | ---------- 9 | args : Namespace 10 | Input arguments. 11 | data_part : str 12 | 'test' or 'training' data partitions. 13 | seed : int 14 | Random seed. 15 | 16 | Returns 17 | ------- 18 | epoched_data : list of float 19 | Epoched EEG data. 20 | img_conditions : list of int 21 | Unique image conditions of the epoched and sorted EEG data. 22 | ch_names : list of str 23 | EEG channel names. 24 | times : float 25 | EEG time points. 26 | 27 | """ 28 | 29 | import os 30 | import mne 31 | import numpy as np 32 | from sklearn.utils import shuffle 33 | 34 | chan_order = ['Fp1', 'Fp2', 'AF7', 'AF3', 'AFz', 'AF4', 'AF8', 'F7', 'F5', 'F3', 35 | 'F1', 'F2', 'F4', 'F6', 'F8', 'FT9', 'FT7', 'FC5', 'FC3', 'FC1', 36 | 'FCz', 'FC2', 'FC4', 'FC6', 'FT8', 'FT10', 'T7', 'C5', 'C3', 'C1', 37 | 'Cz', 'C2', 'C4', 'C6', 'T8', 'TP9', 'TP7', 'CP5', 'CP3', 'CP1', 38 | 'CPz', 'CP2', 'CP4', 'CP6', 'TP8', 'TP10', 'P7', 'P5', 'P3', 'P1', 39 | 'Pz', 'P2', 'P4', 'P6', 'P8', 'PO7', 'PO3', 'POz', 'PO4', 'PO8', 40 | 'O1', 'Oz', 'O2'] 41 | 42 | ### Loop across data collection sessions ### 43 | epoched_data = [] 44 | img_conditions = [] 45 | for s in range(args.n_ses): 46 | 47 | ### Load the EEG data and convert it to MNE raw format ### 48 | eeg_dir = os.path.join('Raw_data', 'sub-'+ 49 | format(args.sub,'02'), 'ses-'+format(s+1,'02'), 'raw_eeg_'+ 50 | data_part+'.npy') 51 | eeg_data = np.load(os.path.join(args.project_dir, eeg_dir), 52 | allow_pickle=True).item() 53 | ch_names = eeg_data['ch_names'] 54 | sfreq = eeg_data['sfreq'] 55 | ch_types = eeg_data['ch_types'] 56 | eeg_data = eeg_data['raw_eeg_data'] 57 | # Convert to MNE raw format 58 | info = mne.create_info(ch_names, sfreq, ch_types) 59 | raw = mne.io.RawArray(eeg_data, info) 60 | del eeg_data 61 | 62 | ### Get events, drop unused channels and reject target trials ### 63 | events = mne.find_events(raw, stim_channel='stim') 64 | # # Select only occipital (O) and posterior (P) channels 65 | # chan_idx = np.asarray(mne.pick_channels_regexp(raw.info['ch_names'], 66 | # '^O *|^P *')) 67 | # new_chans = [raw.info['ch_names'][c] for c in chan_idx] 68 | # raw.pick_channels(new_chans) 69 | # * chose all channels 70 | raw.pick_channels(chan_order, ordered=True) 71 | # Reject the target trials (event 99999) 72 | idx_target = np.where(events[:,2] == 99999)[0] 73 | events = np.delete(events, idx_target, 0) 74 | ### Epoching, baseline correction and resampling ### 75 | # * [0, 1.0] 76 | epochs = 
mne.Epochs(raw, events, tmin=-.2, tmax=1.0, baseline=(None,0), 77 | preload=True) 78 | # epochs = mne.Epochs(raw, events, tmin=-.2, tmax=.8, baseline=(None,0), 79 | # preload=True) 80 | del raw 81 | # Resampling 82 | if args.sfreq < 1000: 83 | epochs.resample(args.sfreq) 84 | ch_names = epochs.info['ch_names'] 85 | times = epochs.times 86 | 87 | ### Sort the data ### 88 | data = epochs.get_data() 89 | events = epochs.events[:,2] 90 | img_cond = np.unique(events) 91 | del epochs 92 | # Select only a maximum number of EEG repetitions 93 | if data_part == 'test': 94 | max_rep = 20 95 | else: 96 | max_rep = 2 97 | # Sorted data matrix of shape: 98 | # Image conditions × EEG repetitions × EEG channels × EEG time points 99 | sorted_data = np.zeros((len(img_cond),max_rep,data.shape[1], 100 | data.shape[2])) 101 | for i in range(len(img_cond)): 102 | # Find the indices of the selected image condition 103 | idx = np.where(events == img_cond[i])[0] 104 | # Randomly select only the max number of EEG repetitions 105 | idx = shuffle(idx, random_state=seed, n_samples=max_rep) 106 | sorted_data[i] = data[idx] 107 | del data 108 | epoched_data.append(sorted_data[:, :, :, 50:]) 109 | img_conditions.append(img_cond) 110 | del sorted_data 111 | 112 | ### Output ### 113 | return epoched_data, img_conditions, ch_names, times 114 | 115 | 116 | def mvnn(args, epoched_test, epoched_train): 117 | """Compute the covariance matrices of the EEG data (calculated for each 118 | time-point or epoch/repetitions of each image condition), and then average 119 | them across image conditions and data partitions. The inverse of the 120 | resulting averaged covariance matrix is used to whiten the EEG data 121 | (independently for each session). 122 | 123 | zero-score standardization also has well performance 124 | 125 | Parameters 126 | ---------- 127 | args : Namespace 128 | Input arguments. 129 | epoched_test : list of floats 130 | Epoched test EEG data. 131 | epoched_train : list of floats 132 | Epoched training EEG data. 133 | 134 | Returns 135 | ------- 136 | whitened_test : list of float 137 | Whitened test EEG data. 138 | whitened_train : list of float 139 | Whitened training EEG data. 
140 | 141 | """ 142 | 143 | import numpy as np 144 | from tqdm import tqdm 145 | from sklearn.discriminant_analysis import _cov 146 | import scipy 147 | 148 | ### Loop across data collection sessions ### 149 | whitened_test = [] 150 | whitened_train = [] 151 | for s in range(args.n_ses): 152 | session_data = [epoched_test[s], epoched_train[s]] 153 | 154 | ### Compute the covariance matrices ### 155 | # Data partitions covariance matrix of shape: 156 | # Data partitions × EEG channels × EEG channels 157 | sigma_part = np.empty((len(session_data),session_data[0].shape[2], 158 | session_data[0].shape[2])) 159 | for p in range(sigma_part.shape[0]): 160 | # Image conditions covariance matrix of shape: 161 | # Image conditions × EEG channels × EEG channels 162 | sigma_cond = np.empty((session_data[p].shape[0], 163 | session_data[0].shape[2],session_data[0].shape[2])) 164 | for i in tqdm(range(session_data[p].shape[0])): 165 | cond_data = session_data[p][i] 166 | # Compute covariace matrices at each time point, and then 167 | # average across time points 168 | if args.mvnn_dim == "time": 169 | sigma_cond[i] = np.mean([_cov(cond_data[:,:,t], 170 | shrinkage='auto') for t in range(cond_data.shape[2])], 171 | axis=0) 172 | # Compute covariace matrices at each epoch (EEG repetition), 173 | # and then average across epochs/repetitions 174 | elif args.mvnn_dim == "epochs": 175 | sigma_cond[i] = np.mean([_cov(np.transpose(cond_data[e]), 176 | shrinkage='auto') for e in range(cond_data.shape[0])], 177 | axis=0) 178 | # Average the covariance matrices across image conditions 179 | sigma_part[p] = sigma_cond.mean(axis=0) 180 | # # Average the covariance matrices across image partitions 181 | # sigma_tot = sigma_part.mean(axis=0) 182 | # ? It seems not fair to use test data for mvnn, so we change to just use training data 183 | sigma_tot = sigma_part[1] 184 | # Compute the inverse of the covariance matrix 185 | sigma_inv = scipy.linalg.fractional_matrix_power(sigma_tot, -0.5) 186 | 187 | ### Whiten the data ### 188 | whitened_test.append(np.reshape((np.reshape(session_data[0], (-1, 189 | session_data[0].shape[2],session_data[0].shape[3])).swapaxes(1, 2) 190 | @ sigma_inv).swapaxes(1, 2), session_data[0].shape)) 191 | whitened_train.append(np.reshape((np.reshape(session_data[1], (-1, 192 | session_data[1].shape[2],session_data[1].shape[3])).swapaxes(1, 2) 193 | @ sigma_inv).swapaxes(1, 2), session_data[1].shape)) 194 | 195 | ### Output ### 196 | return whitened_test, whitened_train 197 | 198 | 199 | def save_prepr(args, whitened_test, whitened_train, img_conditions_train, 200 | ch_names, times, seed): 201 | """Merge the EEG data of all sessions together, shuffle the EEG repetitions 202 | across sessions and reshaping the data to the format: 203 | Image conditions × EGG repetitions × EEG channels × EEG time points. 204 | Then, the data of both test and training EEG partitions is saved. 205 | 206 | Parameters 207 | ---------- 208 | args : Namespace 209 | Input arguments. 210 | whitened_test : list of float 211 | Whitened test EEG data. 212 | whitened_train : list of float 213 | Whitened training EEG data. 214 | img_conditions_train : list of int 215 | Unique image conditions of the epoched and sorted train EEG data. 216 | ch_names : list of str 217 | EEG channel names. 218 | times : float 219 | EEG time points. 220 | seed : int 221 | Random seed. 
222 | 223 | """ 224 | 225 | import numpy as np 226 | from sklearn.utils import shuffle 227 | import os 228 | import pickle 229 | 230 | ### Merge and save the test data ### 231 | for s in range(args.n_ses): 232 | if s == 0: 233 | merged_test = whitened_test[s] 234 | else: 235 | merged_test = np.append(merged_test, whitened_test[s], 1) 236 | del whitened_test 237 | # Shuffle the repetitions of different sessions 238 | idx = shuffle(np.arange(0, merged_test.shape[1]), random_state=seed) 239 | merged_test = merged_test[:,idx] 240 | # Insert the data into a dictionary 241 | test_dict = { 242 | 'preprocessed_eeg_data': merged_test, 243 | 'ch_names': ch_names, 244 | 'times': times 245 | } 246 | del merged_test 247 | # Saving directories 248 | save_dir = os.path.join(args.project_dir, 249 | 'Preprocessed_data_250Hz', 'sub-'+format(args.sub,'02')) 250 | file_name_test = 'preprocessed_eeg_test.npy' 251 | file_name_train = 'preprocessed_eeg_training.npy' 252 | # Create the directory if not existing and save the data 253 | if os.path.isdir(save_dir) == False: 254 | os.makedirs(save_dir) 255 | # np.save(os.path.join(save_dir, file_name_test), test_dict) 256 | save_pic = open(os.path.join(save_dir, file_name_test), 'wb') 257 | pickle.dump(test_dict, save_pic, protocol=4) 258 | save_pic.close() 259 | del test_dict 260 | 261 | ### Merge and save the training data ### 262 | for s in range(args.n_ses): 263 | if s == 0: 264 | white_data = whitened_train[s] 265 | img_cond = img_conditions_train[s] 266 | else: 267 | white_data = np.append(white_data, whitened_train[s], 0) 268 | img_cond = np.append(img_cond, img_conditions_train[s], 0) 269 | del whitened_train, img_conditions_train 270 | # Data matrix of shape: 271 | # Image conditions × EGG repetitions × EEG channels × EEG time points 272 | merged_train = np.zeros((len(np.unique(img_cond)), white_data.shape[1]*2, 273 | white_data.shape[2],white_data.shape[3])) 274 | for i in range(len(np.unique(img_cond))): 275 | # Find the indices of the selected category 276 | idx = np.where(img_cond == i+1)[0] 277 | for r in range(len(idx)): 278 | if r == 0: 279 | ordered_data = white_data[idx[r]] 280 | else: 281 | ordered_data = np.append(ordered_data, white_data[idx[r]], 0) 282 | merged_train[i] = ordered_data 283 | # Shuffle the repetitions of different sessions 284 | idx = shuffle(np.arange(0, merged_train.shape[1]), random_state=seed) 285 | merged_train = merged_train[:,idx] 286 | # Insert the data into a dictionary 287 | train_dict = { 288 | 'preprocessed_eeg_data': merged_train, 289 | 'ch_names': ch_names, 290 | 'times': times 291 | } 292 | del merged_train 293 | # Create the directory if not existing and save the data 294 | if os.path.isdir(save_dir) == False: 295 | os.makedirs(save_dir) 296 | # np.save(os.path.join(save_dir, file_name_train), 297 | # train_dict) 298 | save_pic = open(os.path.join(save_dir, file_name_train), 'wb') 299 | pickle.dump(train_dict, save_pic, protocol=4) 300 | save_pic.close() 301 | del train_dict 302 | -------------------------------------------------------------------------------- /models/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import torch 4 | import os 5 | import sys 6 | import time 7 | from torch import inf 8 | import wandb 9 | 10 | class NativeScaler: 11 | state_dict_key = "amp_scaler" 12 | 13 | def __init__(self): 14 | self._scaler = torch.cuda.amp.GradScaler() 15 | 16 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, 
create_graph=False, update_grad=True): 17 | self._scaler.scale(loss).backward(create_graph=create_graph) 18 | if update_grad: 19 | if clip_grad is not None: 20 | assert parameters is not None 21 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 22 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 23 | else: 24 | self._scaler.unscale_(optimizer) 25 | norm = get_grad_norm_(parameters) 26 | self._scaler.step(optimizer) 27 | self._scaler.update() 28 | else: 29 | norm = None 30 | return norm 31 | 32 | def state_dict(self): 33 | return self._scaler.state_dict() 34 | 35 | def load_state_dict(self, state_dict): 36 | self._scaler.load_state_dict(state_dict) 37 | 38 | 39 | 40 | def get_grad_norm_(parameters, norm_type: float = 2.0): 41 | if isinstance(parameters, torch.Tensor): 42 | parameters = [parameters] 43 | parameters = [p for p in parameters if p.grad is not None] 44 | norm_type = float(norm_type) 45 | if len(parameters) == 0: 46 | return torch.tensor(0.) 47 | device = parameters[0].grad.device 48 | if norm_type == inf: 49 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 50 | else: 51 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 52 | return total_norm 53 | 54 | def train_one_epoch(model, data_loader, optimizer, device, epoch, 55 | loss_scaler, log_writer=None, config=None, start_time=None, model_without_ddp=None, 56 | img_feature_extractor=None, preprocess=None): 57 | model.train(True) 58 | optimizer.zero_grad() 59 | total_loss = [] 60 | total_cor = [] 61 | accum_iter = config.accum_iter 62 | for data_iter_step, (data_dcit) in enumerate(data_loader): 63 | 64 | # we use a per iteration (instead of per epoch) lr scheduler 65 | # print(data_iter_step) 66 | # print(len(data_loader)) 67 | 68 | if data_iter_step % accum_iter == 0: 69 | adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, config) 70 | samples = data_dcit['eeg'] 71 | 72 | img_features = None 73 | valid_idx = None 74 | if img_feature_extractor is not None: 75 | images = data_dcit['image'] 76 | valid_idx = torch.nonzero(images.sum(dim=(1,2,3)) != 0).squeeze(1) 77 | img_feature_extractor.eval() 78 | with torch.no_grad(): 79 | img_features = img_feature_extractor(preprocess(images[valid_idx]).to(device))['layer2'] 80 | samples = samples.to(device) 81 | # img_features = img_features.to(device) 82 | 83 | optimizer.zero_grad() 84 | with torch.cuda.amp.autocast(enabled=True): 85 | loss, pred, _ = model(samples, img_features, valid_idx=valid_idx, mask_ratio=config.mask_ratio) 86 | # loss.backward() 87 | # norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.clip_grad) 88 | # optimizer.step() 89 | 90 | loss_value = loss.item() 91 | 92 | if not math.isfinite(loss_value): 93 | print(f"Loss is {loss_value}, stopping training at step {data_iter_step} epoch {epoch}") 94 | sys.exit(1) 95 | 96 | # loss /= accum_iter 97 | loss_scaler(loss, optimizer, parameters=model.parameters(), clip_grad=config.clip_grad) 98 | 99 | # if (data_iter_step + 1) % accum_iter == 0: 100 | # cal the cor 101 | pred = pred.to('cpu').detach() 102 | samples = samples.to('cpu').detach() 103 | # pred = pred.transpose(1,2) #model_without_ddp.unpatchify(pred) 104 | pred = model_without_ddp.unpatchify(pred) 105 | 106 | cor = torch.mean(torch.tensor([torch.corrcoef(torch.cat([p[0].unsqueeze(0), s[0].unsqueeze(0)],axis=0))[0,1] for p, s in zip(pred, samples)])).item() 107 | 
optimizer.zero_grad() 108 | 109 | total_loss.append(loss_value) 110 | total_cor.append(cor) 111 | if device == torch.device('cuda:0'): 112 | lr = optimizer.param_groups[0]["lr"] 113 | print('train_loss_step:', np.mean(total_loss), 'lr:', lr, 'cor', np.mean(total_cor)) 114 | 115 | if log_writer is not None: 116 | lr = optimizer.param_groups[0]["lr"] 117 | log_writer.log('train_loss_step', np.mean(total_loss), step=epoch) 118 | log_writer.log('lr', lr, step=epoch) 119 | log_writer.log('cor', np.mean(total_cor), step=epoch) 120 | if start_time is not None: 121 | log_writer.log('time (min)', (time.time() - start_time)/60.0, step=epoch) 122 | if config.local_rank == 0: 123 | print(f'[Epoch {epoch}] loss: {np.mean(total_loss)}') 124 | 125 | return np.mean(total_cor) 126 | 127 | def get_1d_sincos_pos_embed(embed_dim, length, cls_token=False): 128 | """ 129 | grid_size: int of the grid height and width 130 | return: 131 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 132 | """ 133 | grid_l = np.arange(length, dtype=float) 134 | 135 | grid_l = grid_l.reshape([1, length]) 136 | pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid_l) 137 | if cls_token: 138 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 139 | return pos_embed 140 | 141 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 142 | """ 143 | embed_dim: output dimension for each position 144 | pos: a list of positions to be encoded: size (M,) 145 | out: (M, D) 146 | """ 147 | assert embed_dim % 2 == 0 148 | omega = np.arange(embed_dim // 2, dtype=float) 149 | omega /= embed_dim / 2. 150 | omega = 1. / 10000**omega # (D/2,) 151 | 152 | pos = pos.reshape(-1) # (M,) 153 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 154 | 155 | emb_sin = np.sin(out) # (M, D/2) 156 | emb_cos = np.cos(out) # (M, D/2) 157 | 158 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 159 | return emb 160 | 161 | def interpolate_pos_embed(model, checkpoint_model): 162 | if 'pos_embed' in checkpoint_model: 163 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 164 | embedding_size = pos_embed_checkpoint.shape[-1] 165 | num_patches = model.patch_embed.num_patches 166 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches # cls token 167 | # height (== width) for the checkpoint position embedding 168 | orig_size = int(pos_embed_checkpoint.shape[-2] - num_extra_tokens) 169 | # height (== width) for the new position embedding 170 | new_size = int(num_patches) 171 | # class_token and dist_token are kept unchanged 172 | if orig_size != new_size: 173 | print("Position interpolate from %d to %d" % (orig_size, new_size)) 174 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 175 | # only the position tokens are interpolated 176 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 177 | pos_tokens = pos_tokens.reshape(-1, orig_size, embedding_size).permute(0, 2, 1) 178 | pos_tokens = torch.nn.functional.interpolate( 179 | pos_tokens, size=(new_size)) 180 | pos_tokens = pos_tokens.permute(0, 2, 1) 181 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 182 | checkpoint_model['pos_embed'] = new_pos_embed 183 | 184 | 185 | def adjust_learning_rate(optimizer, epoch, config): 186 | """Decay the learning rate with half-cycle cosine after warmup""" 187 | if epoch < config.warmup_epochs: 188 | lr = config.lr * epoch / config.warmup_epochs 189 | else: 190 | lr = config.min_lr + (config.lr - config.min_lr) * 0.5 * \ 191 | (1. 
+ math.cos(math.pi * (epoch - config.warmup_epochs) / (config.num_epoch - config.warmup_epochs))) 192 | for param_group in optimizer.param_groups: 193 | if "lr_scale" in param_group: 194 | param_group["lr"] = lr * param_group["lr_scale"] 195 | else: 196 | param_group["lr"] = lr 197 | return lr 198 | 199 | 200 | 201 | 202 | def load_model(config, model, checkpoint_path ): 203 | checkpoint = torch.load(checkpoint_path, map_location='cpu') 204 | model.load_state_dict(checkpoint['model']) 205 | print(f'Model loaded with {checkpoint_path}') 206 | 207 | 208 | def patchify(imgs, patch_size): 209 | """ 210 | imgs: (N, 1, num_voxels) 211 | x: (N, L, patch_size) 212 | """ 213 | p = patch_size 214 | assert imgs.ndim == 3 and imgs.shape[2] % p == 0 215 | 216 | h = imgs.shape[2] // p 217 | x = imgs.reshape(shape=(imgs.shape[0], h, p)) 218 | return x 219 | 220 | def unpatchify(x, patch_size): 221 | """ 222 | x: (N, L, patch_size) 223 | imgs: (N, 1, num_voxels) 224 | """ 225 | p = patch_size 226 | h = x.shape[1] 227 | 228 | imgs = x.reshape(shape=(x.shape[0], 1, h * p)) 229 | return imgs 230 | 231 | class wandb_logger: 232 | def __init__(self, config): 233 | try: 234 | wandb.init( 235 | # Set the project where this run will be logged 236 | project=config['project'], 237 | name=config['name'], 238 | config=config, 239 | entity=config['entity'], 240 | ) 241 | except: 242 | wandb.init( 243 | # Set the project where this run will be logged 244 | project=config.project, 245 | name=config.name, 246 | config=config, 247 | entity=config.entity, 248 | ) 249 | 250 | self.config = config 251 | self.step = None 252 | 253 | def log(self, data, step=None): 254 | if step is None: 255 | wandb.log(data) 256 | else: 257 | wandb.log(data, step=step) 258 | self.step = step 259 | 260 | def watch_model(self, *args, **kwargs): 261 | wandb.watch(*args, **kwargs) 262 | 263 | def log_image(self, figs): 264 | if self.step is None: 265 | wandb.log(figs) 266 | else: 267 | wandb.log(figs, step=self.step) 268 | 269 | def finish(self): 270 | wandb.finish(quiet=True) 271 | 272 | def load(self, net): 273 | path = os.path.join(self.config['path_data'], self.config['path_ckpt'], self.config['file_ckpt']) 274 | net.load_state_dict(torch.load(path)) 275 | print(f'load {path}') 276 | 277 | def save(self, net, file_name=None): 278 | path_ckpt = os.path.join(self.config['path_data'], self.config['path_ckpt']) 279 | if not os.path.exists(path_ckpt): 280 | os.makedirs(path_ckpt) 281 | print(f'{path_ckpt} created!') 282 | 283 | path = os.path.join(path_ckpt, file_name) 284 | torch.save(net.state_dict(), path) 285 | 286 | def watch(self, model, log): 287 | wandb.watch(model, log) -------------------------------------------------------------------------------- /models/subject_layers/ETSformer_EncDec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.fft as fft 5 | from einops import rearrange, reduce, repeat 6 | import math, random 7 | from scipy.fftpack import next_fast_len 8 | 9 | 10 | class Transform: 11 | def __init__(self, sigma): 12 | self.sigma = sigma 13 | 14 | @torch.no_grad() 15 | def transform(self, x): 16 | return self.jitter(self.shift(self.scale(x))) 17 | 18 | def jitter(self, x): 19 | return x + (torch.randn(x.shape).to(x.device) * self.sigma) 20 | 21 | def scale(self, x): 22 | return x * (torch.randn(x.size(-1)).to(x.device) * self.sigma + 1) 23 | 24 | def shift(self, x): 25 | return x + 
(torch.randn(x.size(-1)).to(x.device) * self.sigma) 26 | 27 | 28 | def conv1d_fft(f, g, dim=-1): 29 | N = f.size(dim) 30 | M = g.size(dim) 31 | 32 | fast_len = next_fast_len(N + M - 1) 33 | 34 | F_f = fft.rfft(f, fast_len, dim=dim) 35 | F_g = fft.rfft(g, fast_len, dim=dim) 36 | 37 | F_fg = F_f * F_g.conj() 38 | out = fft.irfft(F_fg, fast_len, dim=dim) 39 | out = out.roll((-1,), dims=(dim,)) 40 | idx = torch.as_tensor(range(fast_len - N, fast_len)).to(out.device) 41 | out = out.index_select(dim, idx) 42 | 43 | return out 44 | 45 | 46 | class ExponentialSmoothing(nn.Module): 47 | 48 | def __init__(self, dim, nhead, dropout=0.1, aux=False): 49 | super().__init__() 50 | self._smoothing_weight = nn.Parameter(torch.randn(nhead, 1)) 51 | self.v0 = nn.Parameter(torch.randn(1, 1, nhead, dim)) 52 | self.dropout = nn.Dropout(dropout) 53 | if aux: 54 | self.aux_dropout = nn.Dropout(dropout) 55 | 56 | def forward(self, values, aux_values=None): 57 | b, t, h, d = values.shape 58 | 59 | init_weight, weight = self.get_exponential_weight(t) 60 | output = conv1d_fft(self.dropout(values), weight, dim=1) 61 | output = init_weight * self.v0 + output 62 | 63 | if aux_values is not None: 64 | aux_weight = weight / (1 - self.weight) * self.weight 65 | aux_output = conv1d_fft(self.aux_dropout(aux_values), aux_weight) 66 | output = output + aux_output 67 | 68 | return output 69 | 70 | def get_exponential_weight(self, T): 71 | # Generate array [0, 1, ..., T-1] 72 | powers = torch.arange(T, dtype=torch.float, device=self.weight.device) 73 | 74 | # (1 - \alpha) * \alpha^t, for all t = T-1, T-2, ..., 0] 75 | weight = (1 - self.weight) * (self.weight ** torch.flip(powers, dims=(0,))) 76 | 77 | # \alpha^t for all t = 1, 2, ..., T 78 | init_weight = self.weight ** (powers + 1) 79 | 80 | return rearrange(init_weight, 'h t -> 1 t h 1'), \ 81 | rearrange(weight, 'h t -> 1 t h 1') 82 | 83 | @property 84 | def weight(self): 85 | return torch.sigmoid(self._smoothing_weight) 86 | 87 | 88 | class Feedforward(nn.Module): 89 | def __init__(self, d_model, dim_feedforward, dropout=0.1, activation='sigmoid'): 90 | # Implementation of Feedforward model 91 | super().__init__() 92 | self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False) 93 | self.dropout1 = nn.Dropout(dropout) 94 | self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False) 95 | self.dropout2 = nn.Dropout(dropout) 96 | self.activation = getattr(F, activation) 97 | 98 | def forward(self, x): 99 | x = self.linear2(self.dropout1(self.activation(self.linear1(x)))) 100 | return self.dropout2(x) 101 | 102 | 103 | class GrowthLayer(nn.Module): 104 | 105 | def __init__(self, d_model, nhead, d_head=None, dropout=0.1): 106 | super().__init__() 107 | self.d_head = d_head or (d_model // nhead) 108 | self.d_model = d_model 109 | self.nhead = nhead 110 | 111 | self.z0 = nn.Parameter(torch.randn(self.nhead, self.d_head)) 112 | self.in_proj = nn.Linear(self.d_model, self.d_head * self.nhead) 113 | self.es = ExponentialSmoothing(self.d_head, self.nhead, dropout=dropout) 114 | self.out_proj = nn.Linear(self.d_head * self.nhead, self.d_model) 115 | 116 | assert self.d_head * self.nhead == self.d_model, "d_model must be divisible by nhead" 117 | 118 | def forward(self, inputs): 119 | """ 120 | :param inputs: shape: (batch, seq_len, dim) 121 | :return: shape: (batch, seq_len, dim) 122 | """ 123 | b, t, d = inputs.shape 124 | values = self.in_proj(inputs).view(b, t, self.nhead, -1) 125 | values = torch.cat([repeat(self.z0, 'h d -> b 1 h d', b=b), values], dim=1) 126 | values = 
values[:, 1:] - values[:, :-1] 127 | out = self.es(values) 128 | out = torch.cat([repeat(self.es.v0, '1 1 h d -> b 1 h d', b=b), out], dim=1) 129 | out = rearrange(out, 'b t h d -> b t (h d)') 130 | return self.out_proj(out) 131 | 132 | 133 | class FourierLayer(nn.Module): 134 | 135 | def __init__(self, d_model, pred_len, k=None, low_freq=1): 136 | super().__init__() 137 | self.d_model = d_model 138 | self.pred_len = pred_len 139 | self.k = k 140 | self.low_freq = low_freq 141 | 142 | def forward(self, x): 143 | """x: (b, t, d)""" 144 | b, t, d = x.shape 145 | x_freq = fft.rfft(x, dim=1) 146 | 147 | if t % 2 == 0: 148 | x_freq = x_freq[:, self.low_freq:-1] 149 | f = fft.rfftfreq(t)[self.low_freq:-1] 150 | else: 151 | x_freq = x_freq[:, self.low_freq:] 152 | f = fft.rfftfreq(t)[self.low_freq:] 153 | 154 | x_freq, index_tuple = self.topk_freq(x_freq) 155 | f = repeat(f, 'f -> b f d', b=x_freq.size(0), d=x_freq.size(2)) 156 | f = rearrange(f[index_tuple], 'b f d -> b f () d').to(x_freq.device) 157 | 158 | return self.extrapolate(x_freq, f, t) 159 | 160 | def extrapolate(self, x_freq, f, t): 161 | x_freq = torch.cat([x_freq, x_freq.conj()], dim=1) 162 | f = torch.cat([f, -f], dim=1) 163 | t_val = rearrange(torch.arange(t + self.pred_len, dtype=torch.float), 164 | 't -> () () t ()').to(x_freq.device) 165 | 166 | amp = rearrange(x_freq.abs() / t, 'b f d -> b f () d') 167 | phase = rearrange(x_freq.angle(), 'b f d -> b f () d') 168 | 169 | x_time = amp * torch.cos(2 * math.pi * f * t_val + phase) 170 | 171 | return reduce(x_time, 'b f t d -> b t d', 'sum') 172 | 173 | def topk_freq(self, x_freq): 174 | values, indices = torch.topk(x_freq.abs(), self.k, dim=1, largest=True, sorted=True) 175 | mesh_a, mesh_b = torch.meshgrid(torch.arange(x_freq.size(0)), torch.arange(x_freq.size(2))) 176 | index_tuple = (mesh_a.unsqueeze(1), indices, mesh_b.unsqueeze(1)) 177 | x_freq = x_freq[index_tuple] 178 | 179 | return x_freq, index_tuple 180 | 181 | 182 | class LevelLayer(nn.Module): 183 | 184 | def __init__(self, d_model, c_out, dropout=0.1): 185 | super().__init__() 186 | self.d_model = d_model 187 | self.c_out = c_out 188 | 189 | self.es = ExponentialSmoothing(1, self.c_out, dropout=dropout, aux=True) 190 | self.growth_pred = nn.Linear(self.d_model, self.c_out) 191 | self.season_pred = nn.Linear(self.d_model, self.c_out) 192 | 193 | def forward(self, level, growth, season): 194 | b, t, _ = level.shape 195 | growth = self.growth_pred(growth).view(b, t, self.c_out, 1) 196 | season = self.season_pred(season).view(b, t, self.c_out, 1) 197 | growth = growth.view(b, t, self.c_out, 1) 198 | season = season.view(b, t, self.c_out, 1) 199 | level = level.view(b, t, self.c_out, 1) 200 | out = self.es(level - season, aux_values=growth) 201 | out = rearrange(out, 'b t h d -> b t (h d)') 202 | return out 203 | 204 | 205 | class EncoderLayer(nn.Module): 206 | 207 | def __init__(self, d_model, nhead, c_out, seq_len, pred_len, k, dim_feedforward=None, dropout=0.1, 208 | activation='sigmoid', layer_norm_eps=1e-5): 209 | super().__init__() 210 | self.d_model = d_model 211 | self.nhead = nhead 212 | self.c_out = c_out 213 | self.seq_len = seq_len 214 | self.pred_len = pred_len 215 | dim_feedforward = dim_feedforward or 4 * d_model 216 | self.dim_feedforward = dim_feedforward 217 | 218 | self.growth_layer = GrowthLayer(d_model, nhead, dropout=dropout) 219 | self.seasonal_layer = FourierLayer(d_model, pred_len, k=k) 220 | self.level_layer = LevelLayer(d_model, c_out, dropout=dropout) 221 | 222 | # Implementation of Feedforward 
model 223 | self.ff = Feedforward(d_model, dim_feedforward, dropout=dropout, activation=activation) 224 | self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps) 225 | self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps) 226 | 227 | self.dropout1 = nn.Dropout(dropout) 228 | self.dropout2 = nn.Dropout(dropout) 229 | 230 | def forward(self, res, level, attn_mask=None): 231 | season = self._season_block(res) 232 | res = res - season[:, :-self.pred_len] 233 | growth = self._growth_block(res) 234 | res = self.norm1(res - growth[:, 1:]) 235 | res = self.norm2(res + self.ff(res)) 236 | 237 | level = self.level_layer(level, growth[:, :-1], season[:, :-self.pred_len]) 238 | return res, level, growth, season 239 | 240 | def _growth_block(self, x): 241 | x = self.growth_layer(x) 242 | return self.dropout1(x) 243 | 244 | def _season_block(self, x): 245 | x = self.seasonal_layer(x) 246 | return self.dropout2(x) 247 | 248 | 249 | class Encoder(nn.Module): 250 | 251 | def __init__(self, layers): 252 | super().__init__() 253 | self.layers = nn.ModuleList(layers) 254 | 255 | def forward(self, res, level, attn_mask=None): 256 | growths = [] 257 | seasons = [] 258 | for layer in self.layers: 259 | res, level, growth, season = layer(res, level, attn_mask=None) 260 | growths.append(growth) 261 | seasons.append(season) 262 | 263 | return level, growths, seasons 264 | 265 | 266 | class DampingLayer(nn.Module): 267 | 268 | def __init__(self, pred_len, nhead, dropout=0.1): 269 | super().__init__() 270 | self.pred_len = pred_len 271 | self.nhead = nhead 272 | self._damping_factor = nn.Parameter(torch.randn(1, nhead)) 273 | self.dropout = nn.Dropout(dropout) 274 | 275 | def forward(self, x): 276 | x = repeat(x, 'b 1 d -> b t d', t=self.pred_len) 277 | b, t, d = x.shape 278 | 279 | powers = torch.arange(self.pred_len).to(self._damping_factor.device) + 1 280 | powers = powers.view(self.pred_len, 1) 281 | damping_factors = self.damping_factor ** powers 282 | damping_factors = damping_factors.cumsum(dim=0) 283 | x = x.view(b, t, self.nhead, -1) 284 | x = self.dropout(x) * damping_factors.unsqueeze(-1) 285 | return x.view(b, t, d) 286 | 287 | @property 288 | def damping_factor(self): 289 | return torch.sigmoid(self._damping_factor) 290 | 291 | 292 | class DecoderLayer(nn.Module): 293 | 294 | def __init__(self, d_model, nhead, c_out, pred_len, dropout=0.1): 295 | super().__init__() 296 | self.d_model = d_model 297 | self.nhead = nhead 298 | self.c_out = c_out 299 | self.pred_len = pred_len 300 | 301 | self.growth_damping = DampingLayer(pred_len, nhead, dropout=dropout) 302 | self.dropout1 = nn.Dropout(dropout) 303 | 304 | def forward(self, growth, season): 305 | growth_horizon = self.growth_damping(growth[:, -1:]) 306 | growth_horizon = self.dropout1(growth_horizon) 307 | 308 | seasonal_horizon = season[:, -self.pred_len:] 309 | return growth_horizon, seasonal_horizon 310 | 311 | 312 | class Decoder(nn.Module): 313 | 314 | def __init__(self, layers): 315 | super().__init__() 316 | self.d_model = layers[0].d_model 317 | self.c_out = layers[0].c_out 318 | self.pred_len = layers[0].pred_len 319 | self.nhead = layers[0].nhead 320 | 321 | self.layers = nn.ModuleList(layers) 322 | self.pred = nn.Linear(self.d_model, self.c_out) 323 | 324 | def forward(self, growths, seasons): 325 | growth_repr = [] 326 | season_repr = [] 327 | 328 | for idx, layer in enumerate(self.layers): 329 | growth_horizon, season_horizon = layer(growths[idx], seasons[idx]) 330 | growth_repr.append(growth_horizon) 331 | season_repr.append(season_horizon) 
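        # Added note: each decoder layer contributes a damped growth horizon and a
        # seasonal horizon; the per-layer horizons are summed below and mapped to
        # the output dimension by the shared `self.pred` linear head.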
332 | growth_repr = sum(growth_repr) 333 | season_repr = sum(season_repr) 334 | return self.pred(growth_repr), self.pred(season_repr) 335 | -------------------------------------------------------------------------------- /models/subject_layers/SelfAttention_Family.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from math import sqrt 5 | from models.utils.masking import TriangularCausalMask, ProbMask 6 | from reformer_pytorch import LSHSelfAttention 7 | from einops import rearrange, repeat 8 | 9 | 10 | class DSAttention(nn.Module): 11 | '''De-stationary Attention''' 12 | 13 | def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False): 14 | super(DSAttention, self).__init__() 15 | self.scale = scale 16 | self.mask_flag = mask_flag 17 | self.output_attention = output_attention 18 | self.dropout = nn.Dropout(attention_dropout) 19 | 20 | def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): 21 | B, L, H, E = queries.shape 22 | _, S, _, D = values.shape 23 | scale = self.scale or 1. / sqrt(E) 24 | 25 | tau = 1.0 if tau is None else tau.unsqueeze( 26 | 1).unsqueeze(1) # B x 1 x 1 x 1 27 | delta = 0.0 if delta is None else delta.unsqueeze( 28 | 1).unsqueeze(1) # B x 1 x 1 x S 29 | 30 | # De-stationary Attention, rescaling pre-softmax score with learned de-stationary factors 31 | scores = torch.einsum("blhe,bshe->bhls", queries, keys) * tau + delta 32 | 33 | if self.mask_flag: 34 | if attn_mask is None: 35 | attn_mask = TriangularCausalMask(B, L, device=queries.device) 36 | 37 | scores.masked_fill_(attn_mask.mask, -np.inf) 38 | 39 | A = self.dropout(torch.softmax(scale * scores, dim=-1)) 40 | V = torch.einsum("bhls,bshd->blhd", A, values) 41 | 42 | if self.output_attention: 43 | return V.contiguous(), A 44 | else: 45 | return V.contiguous(), None 46 | 47 | 48 | class FullAttention(nn.Module): 49 | def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False): 50 | super(FullAttention, self).__init__() 51 | self.scale = scale 52 | self.mask_flag = mask_flag 53 | self.output_attention = output_attention 54 | self.dropout = nn.Dropout(attention_dropout) 55 | 56 | def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): 57 | B, L, H, E = queries.shape 58 | _, S, _, D = values.shape 59 | scale = self.scale or 1. 
/ sqrt(E) 60 | 61 | scores = torch.einsum("blhe,bshe->bhls", queries, keys) 62 | 63 | if self.mask_flag: 64 | if attn_mask is None: 65 | attn_mask = TriangularCausalMask(B, L, device=queries.device) 66 | 67 | scores.masked_fill_(attn_mask.mask, -np.inf) 68 | 69 | A = self.dropout(torch.softmax(scale * scores, dim=-1)) 70 | V = torch.einsum("bhls,bshd->blhd", A, values) 71 | 72 | if self.output_attention: 73 | return V.contiguous(), A 74 | else: 75 | return V.contiguous(), None 76 | 77 | 78 | class ProbAttention(nn.Module): 79 | def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False): 80 | super(ProbAttention, self).__init__() 81 | self.factor = factor 82 | self.scale = scale 83 | self.mask_flag = mask_flag 84 | self.output_attention = output_attention 85 | self.dropout = nn.Dropout(attention_dropout) 86 | 87 | def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q) 88 | # Q [B, H, L, D] 89 | B, H, L_K, E = K.shape 90 | _, _, L_Q, _ = Q.shape 91 | 92 | # calculate the sampled Q_K 93 | K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E) 94 | # real U = U_part(factor*ln(L_k))*L_q 95 | index_sample = torch.randint(L_K, (L_Q, sample_k)) 96 | K_sample = K_expand[:, :, torch.arange( 97 | L_Q).unsqueeze(1), index_sample, :] 98 | Q_K_sample = torch.matmul( 99 | Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze() 100 | 101 | # find the Top_k query with sparisty measurement 102 | M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K) 103 | M_top = M.topk(n_top, sorted=False)[1] 104 | 105 | # use the reduced Q to calculate Q_K 106 | Q_reduce = Q[torch.arange(B)[:, None, None], 107 | torch.arange(H)[None, :, None], 108 | M_top, :] # factor*ln(L_q) 109 | Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k 110 | 111 | return Q_K, M_top 112 | 113 | def _get_initial_context(self, V, L_Q): 114 | B, H, L_V, D = V.shape 115 | if not self.mask_flag: 116 | # V_sum = V.sum(dim=-2) 117 | V_sum = V.mean(dim=-2) 118 | contex = V_sum.unsqueeze(-2).expand(B, H, 119 | L_Q, V_sum.shape[-1]).clone() 120 | else: # use mask 121 | # requires that L_Q == L_V, i.e. 
for self-attention only 122 | assert (L_Q == L_V) 123 | contex = V.cumsum(dim=-2) 124 | return contex 125 | 126 | def _update_context(self, context_in, V, scores, index, L_Q, attn_mask): 127 | B, H, L_V, D = V.shape 128 | 129 | if self.mask_flag: 130 | attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device) 131 | scores.masked_fill_(attn_mask.mask, -np.inf) 132 | 133 | attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores) 134 | 135 | context_in[torch.arange(B)[:, None, None], 136 | torch.arange(H)[None, :, None], 137 | index, :] = torch.matmul(attn, V).type_as(context_in) 138 | if self.output_attention: 139 | attns = (torch.ones([B, H, L_V, L_V]) / 140 | L_V).type_as(attn).to(attn.device) 141 | attns[torch.arange(B)[:, None, None], torch.arange(H)[ 142 | None, :, None], index, :] = attn 143 | return context_in, attns 144 | else: 145 | return context_in, None 146 | 147 | def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): 148 | B, L_Q, H, D = queries.shape 149 | _, L_K, _, _ = keys.shape 150 | 151 | queries = queries.transpose(2, 1) 152 | keys = keys.transpose(2, 1) 153 | values = values.transpose(2, 1) 154 | 155 | U_part = self.factor * \ 156 | np.ceil(np.log(L_K)).astype('int').item() # c*ln(L_k) 157 | u = self.factor * \ 158 | np.ceil(np.log(L_Q)).astype('int').item() # c*ln(L_q) 159 | 160 | U_part = U_part if U_part < L_K else L_K 161 | u = u if u < L_Q else L_Q 162 | 163 | scores_top, index = self._prob_QK( 164 | queries, keys, sample_k=U_part, n_top=u) 165 | 166 | # add scale factor 167 | scale = self.scale or 1. / sqrt(D) 168 | if scale is not None: 169 | scores_top = scores_top * scale 170 | # get the context 171 | context = self._get_initial_context(values, L_Q) 172 | # update the context with selected top_k queries 173 | context, attn = self._update_context( 174 | context, values, scores_top, index, L_Q, attn_mask) 175 | 176 | return context.contiguous(), attn 177 | 178 | 179 | class AttentionLayer(nn.Module): 180 | def __init__(self, attention, d_model, n_heads, d_keys=None, 181 | d_values=None): 182 | super(AttentionLayer, self).__init__() 183 | 184 | d_keys = d_keys or (d_model // n_heads) 185 | d_values = d_values or (d_model // n_heads) 186 | 187 | self.inner_attention = attention 188 | self.query_projection = nn.Linear(d_model, d_keys * n_heads) 189 | self.key_projection = nn.Linear(d_model, d_keys * n_heads) 190 | self.value_projection = nn.Linear(d_model, d_values * n_heads) 191 | self.out_projection = nn.Linear(d_values * n_heads, d_model) 192 | self.n_heads = n_heads 193 | 194 | def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): 195 | B, L, _ = queries.shape 196 | _, S, _ = keys.shape 197 | H = self.n_heads 198 | 199 | queries = self.query_projection(queries).view(B, L, H, -1) 200 | keys = self.key_projection(keys).view(B, S, H, -1) 201 | values = self.value_projection(values).view(B, S, H, -1) 202 | 203 | out, attn = self.inner_attention( 204 | queries, 205 | keys, 206 | values, 207 | attn_mask, 208 | tau=tau, 209 | delta=delta 210 | ) 211 | out = out.view(B, L, -1) 212 | 213 | return self.out_projection(out), attn 214 | 215 | 216 | class ReformerLayer(nn.Module): 217 | def __init__(self, attention, d_model, n_heads, d_keys=None, 218 | d_values=None, causal=False, bucket_size=4, n_hashes=4): 219 | super().__init__() 220 | self.bucket_size = bucket_size 221 | self.attn = LSHSelfAttention( 222 | dim=d_model, 223 | heads=n_heads, 224 | bucket_size=bucket_size, 225 | n_hashes=n_hashes, 226 | causal=causal 
227 | ) 228 | 229 | def fit_length(self, queries): 230 | # inside reformer: assert N % (bucket_size * 2) == 0 231 | B, N, C = queries.shape 232 | if N % (self.bucket_size * 2) == 0: 233 | return queries 234 | else: 235 | # fill the time series 236 | fill_len = (self.bucket_size * 2) - (N % (self.bucket_size * 2)) 237 | return torch.cat([queries, torch.zeros([B, fill_len, C]).to(queries.device)], dim=1) 238 | 239 | def forward(self, queries, keys, values, attn_mask, tau, delta): 240 | # in Reformer: defalut queries=keys 241 | B, N, C = queries.shape 242 | queries = self.attn(self.fit_length(queries))[:, :N, :] 243 | return queries, None 244 | 245 | 246 | class TwoStageAttentionLayer(nn.Module): 247 | ''' 248 | The Two Stage Attention (TSA) Layer 249 | input/output shape: [batch_size, Data_dim(D), Seg_num(L), d_model] 250 | ''' 251 | 252 | def __init__(self, configs, 253 | seg_num, factor, d_model, n_heads, d_ff=None, dropout=0.1): 254 | super(TwoStageAttentionLayer, self).__init__() 255 | d_ff = d_ff or 4 * d_model 256 | self.time_attention = AttentionLayer(FullAttention(False, configs.factor, attention_dropout=configs.dropout, 257 | output_attention=configs.output_attention), d_model, n_heads) 258 | self.dim_sender = AttentionLayer(FullAttention(False, configs.factor, attention_dropout=configs.dropout, 259 | output_attention=configs.output_attention), d_model, n_heads) 260 | self.dim_receiver = AttentionLayer(FullAttention(False, configs.factor, attention_dropout=configs.dropout, 261 | output_attention=configs.output_attention), d_model, n_heads) 262 | self.router = nn.Parameter(torch.randn(seg_num, factor, d_model)) 263 | 264 | self.dropout = nn.Dropout(dropout) 265 | 266 | self.norm1 = nn.LayerNorm(d_model) 267 | self.norm2 = nn.LayerNorm(d_model) 268 | self.norm3 = nn.LayerNorm(d_model) 269 | self.norm4 = nn.LayerNorm(d_model) 270 | 271 | self.MLP1 = nn.Sequential(nn.Linear(d_model, d_ff), 272 | nn.GELU(), 273 | nn.Linear(d_ff, d_model)) 274 | self.MLP2 = nn.Sequential(nn.Linear(d_model, d_ff), 275 | nn.GELU(), 276 | nn.Linear(d_ff, d_model)) 277 | 278 | def forward(self, x, attn_mask=None, tau=None, delta=None): 279 | # Cross Time Stage: Directly apply MSA to each dimension 280 | batch = x.shape[0] 281 | time_in = rearrange(x, 'b ts_d seg_num d_model -> (b ts_d) seg_num d_model') 282 | time_enc, attn = self.time_attention( 283 | time_in, time_in, time_in, attn_mask=None, tau=None, delta=None 284 | ) 285 | dim_in = time_in + self.dropout(time_enc) 286 | dim_in = self.norm1(dim_in) 287 | dim_in = dim_in + self.dropout(self.MLP1(dim_in)) 288 | dim_in = self.norm2(dim_in) 289 | 290 | # Cross Dimension Stage: use a small set of learnable vectors to aggregate and distribute messages to build the D-to-D connection 291 | dim_send = rearrange(dim_in, '(b ts_d) seg_num d_model -> (b seg_num) ts_d d_model', b=batch) 292 | batch_router = repeat(self.router, 'seg_num factor d_model -> (repeat seg_num) factor d_model', repeat=batch) 293 | dim_buffer, attn = self.dim_sender(batch_router, dim_send, dim_send, attn_mask=None, tau=None, delta=None) 294 | dim_receive, attn = self.dim_receiver(dim_send, dim_buffer, dim_buffer, attn_mask=None, tau=None, delta=None) 295 | dim_enc = dim_send + self.dropout(dim_receive) 296 | dim_enc = self.norm3(dim_enc) 297 | dim_enc = dim_enc + self.dropout(self.MLP2(dim_enc)) 298 | dim_enc = self.norm4(dim_enc) 299 | 300 | final_out = rearrange(dim_enc, '(b seg_num) ts_d d_model -> b ts_d seg_num d_model', b=batch) 301 | 302 | return final_out 303 | 
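
# --- Added usage sketch (illustrative only; not part of the original file). ---
# It shows how AttentionLayer wraps FullAttention for a standard multi-head
# forward pass on random (batch, seq_len, d_model) tensors; the shapes below
# are arbitrary assumptions chosen for the demo.
if __name__ == "__main__":
    B, L, d_model, n_heads = 2, 16, 64, 4
    layer = AttentionLayer(FullAttention(mask_flag=False), d_model=d_model, n_heads=n_heads)
    x = torch.randn(B, L, d_model)
    out, attn = layer(x, x, x, attn_mask=None)  # attn is None because output_attention=False
    print(out.shape)  # torch.Size([2, 16, 64])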
-------------------------------------------------------------------------------- /Generation/diffusion_prior.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | from tqdm import tqdm 7 | 8 | from diffusers.models.embeddings import Timesteps, TimestepEmbedding 9 | from torch.utils.data import Dataset 10 | 11 | 12 | class DiffusionPrior(nn.Module): 13 | 14 | def __init__( 15 | self, 16 | embed_dim=1024, 17 | cond_dim=42, 18 | hidden_dim=1024, 19 | layers_per_block=4, 20 | time_embed_dim=512, 21 | act_fn=nn.SiLU, 22 | dropout=0.0, 23 | ): 24 | super().__init__() 25 | 26 | self.embed_dim = embed_dim 27 | 28 | # 1. time embedding 29 | self.time_proj = Timesteps(time_embed_dim, True, 0) 30 | self.time_embedding = TimestepEmbedding( 31 | time_embed_dim, 32 | hidden_dim, 33 | ) 34 | 35 | # 2. conditional embedding 36 | self.cond_embedding = nn.Linear(cond_dim, hidden_dim) 37 | 38 | # 3. prior mlp 39 | 40 | # 3.1 input 41 | self.input_layer = nn.Sequential( 42 | nn.Linear(embed_dim, hidden_dim), 43 | nn.LayerNorm(hidden_dim), 44 | act_fn(), 45 | ) 46 | 47 | # 3.2 hidden 48 | self.hidden_layers = nn.ModuleList( 49 | [ 50 | nn.Sequential( 51 | nn.Linear(hidden_dim, hidden_dim), 52 | nn.LayerNorm(hidden_dim), 53 | act_fn(), 54 | nn.Dropout(dropout), 55 | ) 56 | for _ in range(layers_per_block) 57 | ] 58 | ) 59 | 60 | # 3.3 output 61 | self.output_layer = nn.Linear(hidden_dim, embed_dim) 62 | 63 | 64 | def forward(self, x, t, c=None): 65 | # x (batch_size, embed_dim) 66 | # t (batch_size, ) 67 | # c (batch_size, cond_dim) 68 | 69 | # 1. time embedding 70 | t = self.time_proj(t) # (batch_size, time_embed_dim) 71 | t = self.time_embedding(t) # (batch_size, hidden_dim) 72 | 73 | # 2. conditional embedding 74 | c = self.cond_embedding(c) if c is not None else 0 # (batch_size, hidden_dim) 75 | 76 | # 3. prior mlp 77 | 78 | # 3.1 input 79 | x = self.input_layer(x) 80 | 81 | # 3.2 hidden 82 | for layer in self.hidden_layers: 83 | x = x + t + c 84 | x = layer(x) + x 85 | 86 | # 3.3 output 87 | x = self.output_layer(x) 88 | 89 | return x 90 | 91 | 92 | class DiffusionPriorUNet(nn.Module): 93 | 94 | def __init__( 95 | self, 96 | embed_dim=1024, 97 | cond_dim=42, 98 | hidden_dim=[1024, 512, 256, 128, 64], 99 | time_embed_dim=512, 100 | act_fn=nn.SiLU, 101 | dropout=0.0, 102 | ): 103 | super().__init__() 104 | 105 | self.embed_dim = embed_dim 106 | self.cond_dim = cond_dim 107 | self.hidden_dim = hidden_dim 108 | 109 | # 1. time embedding 110 | self.time_proj = Timesteps(time_embed_dim, True, 0) 111 | 112 | # 2. conditional embedding 113 | # to 3.2, 3,3 114 | 115 | # 3. 
prior mlp 116 | 117 | # 3.1 input 118 | self.input_layer = nn.Sequential( 119 | nn.Linear(embed_dim, hidden_dim[0]), 120 | nn.LayerNorm(hidden_dim[0]), 121 | act_fn(), 122 | ) 123 | 124 | # 3.2 hidden encoder 125 | self.num_layers = len(hidden_dim) 126 | self.encode_time_embedding = nn.ModuleList( 127 | [TimestepEmbedding( 128 | time_embed_dim, 129 | hidden_dim[i], 130 | ) for i in range(self.num_layers-1)] 131 | ) # d_0, ..., d_{n-1} 132 | self.encode_cond_embedding = nn.ModuleList( 133 | [nn.Linear(cond_dim, hidden_dim[i]) for i in range(self.num_layers-1)] 134 | ) 135 | self.encode_layers = nn.ModuleList( 136 | [nn.Sequential( 137 | nn.Linear(hidden_dim[i], hidden_dim[i+1]), 138 | nn.LayerNorm(hidden_dim[i+1]), 139 | act_fn(), 140 | nn.Dropout(dropout), 141 | ) for i in range(self.num_layers-1)] 142 | ) 143 | 144 | # 3.3 hidden decoder 145 | self.decode_time_embedding = nn.ModuleList( 146 | [TimestepEmbedding( 147 | time_embed_dim, 148 | hidden_dim[i], 149 | ) for i in range(self.num_layers-1,0,-1)] 150 | ) # d_{n}, ..., d_1 151 | self.decode_cond_embedding = nn.ModuleList( 152 | [nn.Linear(cond_dim, hidden_dim[i]) for i in range(self.num_layers-1,0,-1)] 153 | ) 154 | self.decode_layers = nn.ModuleList( 155 | [nn.Sequential( 156 | nn.Linear(hidden_dim[i], hidden_dim[i-1]), 157 | nn.LayerNorm(hidden_dim[i-1]), 158 | act_fn(), 159 | nn.Dropout(dropout), 160 | ) for i in range(self.num_layers-1,0,-1)] 161 | ) 162 | 163 | # 3.4 output 164 | self.output_layer = nn.Linear(hidden_dim[0], embed_dim) 165 | 166 | 167 | def forward(self, x, t, c=None): 168 | # x (batch_size, embed_dim) 169 | # t (batch_size, ) 170 | # c (batch_size, cond_dim) 171 | 172 | # 1. time embedding 173 | t = self.time_proj(t) # (batch_size, time_embed_dim) 174 | 175 | # 2. conditional embedding 176 | # to 3.2, 3.3 177 | 178 | # 3. 
prior mlp 179 | 180 | # 3.1 input 181 | x = self.input_layer(x) 182 | 183 | # 3.2 hidden encoder 184 | hidden_activations = [] 185 | for i in range(self.num_layers-1): 186 | hidden_activations.append(x) 187 | t_emb = self.encode_time_embedding[i](t) 188 | c_emb = self.encode_cond_embedding[i](c) if c is not None else 0 189 | x = x + t_emb + c_emb 190 | x = self.encode_layers[i](x) 191 | 192 | # 3.3 hidden decoder 193 | for i in range(self.num_layers-1): 194 | t_emb = self.decode_time_embedding[i](t) 195 | c_emb = self.decode_cond_embedding[i](c) if c is not None else 0 196 | x = x + t_emb + c_emb 197 | x = self.decode_layers[i](x) 198 | x += hidden_activations[-1-i] 199 | 200 | # 3.4 output 201 | x = self.output_layer(x) 202 | 203 | return x 204 | 205 | 206 | class EmbeddingDataset(Dataset): 207 | 208 | def __init__(self, c_embeddings, h_embeddings): 209 | self.c_embeddings = c_embeddings 210 | self.h_embeddings = h_embeddings 211 | 212 | def __len__(self): 213 | return len(self.c_embeddings) 214 | 215 | def __getitem__(self, idx): 216 | return { 217 | "c_embedding": self.c_embeddings[idx], 218 | "h_embedding": self.h_embeddings[idx] 219 | } 220 | 221 | class EmbeddingDatasetVICE(Dataset): 222 | def __init__(self, path_data): 223 | image_features_dict = torch.load(os.path.join(path_data, 'openclip_emb/image_features.pt')) 224 | self.embedding_vise = torch.load(os.path.join(path_data, 'variables/embedding_vise.pt')) 225 | self.image_features = image_features_dict['image_features'] 226 | self.labels = image_features_dict['labels'] 227 | self.label2index = image_features_dict['l2i'] 228 | 229 | def __len__(self): 230 | return len(self.image_features) 231 | 232 | def __getitem__(self, idx): 233 | idx_c = self.label2index[self.labels[idx]] 234 | return { 235 | "c_embedding": self.embedding_vise[idx_c], 236 | "h_embedding": self.image_features[idx] 237 | } 238 | 239 | 240 | # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise 241 | def add_noise_with_sigma( 242 | self, 243 | original_samples: torch.FloatTensor, 244 | noise: torch.FloatTensor, 245 | timesteps: torch.FloatTensor, 246 | ) -> torch.FloatTensor: 247 | # Make sure sigmas and timesteps have the same device and dtype as original_samples 248 | sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) 249 | if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): 250 | # mps does not support float64 251 | schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) 252 | timesteps = timesteps.to(original_samples.device, dtype=torch.float32) 253 | else: 254 | schedule_timesteps = self.timesteps.to(original_samples.device) 255 | timesteps = timesteps.to(original_samples.device) 256 | 257 | step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] 258 | 259 | sigma = sigmas[step_indices].flatten() 260 | while len(sigma.shape) < len(original_samples.shape): 261 | sigma = sigma.unsqueeze(-1) 262 | 263 | noisy_samples = original_samples + noise * sigma 264 | return noisy_samples, sigma 265 | 266 | 267 | # diffusion pipe 268 | class Pipe: 269 | 270 | def __init__(self, diffusion_prior=None, scheduler=None, device='cuda'): 271 | self.diffusion_prior = diffusion_prior.to(device) 272 | 273 | if scheduler is None: 274 | from diffusers.schedulers import DDPMScheduler 275 | self.scheduler = DDPMScheduler() 276 | # self.scheduler.add_noise_with_sigma = add_noise_with_sigma.__get__(self.scheduler) 277 | else: 278 | 
self.scheduler = scheduler 279 | 280 | self.device = device 281 | 282 | def train(self, dataloader, num_epochs=10, learning_rate=1e-4): 283 | self.diffusion_prior.train() 284 | device = self.device 285 | criterion = nn.MSELoss(reduction='none') 286 | optimizer = optim.Adam(self.diffusion_prior.parameters(), lr=learning_rate) 287 | from diffusers.optimization import get_cosine_schedule_with_warmup 288 | lr_scheduler = get_cosine_schedule_with_warmup( 289 | optimizer=optimizer, 290 | num_warmup_steps=500, 291 | num_training_steps=(len(dataloader) * num_epochs), 292 | ) 293 | 294 | num_train_timesteps = self.scheduler.config.num_train_timesteps 295 | 296 | for epoch in range(num_epochs): 297 | loss_sum = 0 298 | for batch in dataloader: 299 | c_embeds = batch['c_embedding'].to(device) if 'c_embedding' in batch.keys() else None 300 | h_embeds = batch['h_embedding'].to(device) 301 | N = h_embeds.shape[0] 302 | 303 | # 1. randomly replecing c_embeds to None 304 | if torch.rand(1) < 0.1: 305 | c_embeds = None 306 | 307 | # 2. Generate noisy embeddings as input 308 | noise = torch.randn_like(h_embeds) 309 | 310 | # 3. sample timestep 311 | timesteps = torch.randint(0, num_train_timesteps, (N,), device=device) 312 | 313 | # 4. add noise to h_embedding 314 | perturbed_h_embeds = self.scheduler.add_noise( 315 | h_embeds, 316 | noise, 317 | timesteps 318 | ) # (batch_size, embed_dim), (batch_size, ) 319 | 320 | # 5. predict noise 321 | noise_pre = self.diffusion_prior(perturbed_h_embeds, timesteps, c_embeds) 322 | 323 | # 6. loss function weighted by sigma 324 | loss = criterion(noise_pre, noise) # (batch_size,) 325 | loss = (loss).mean() 326 | 327 | # 7. update parameters 328 | optimizer.zero_grad() 329 | loss.backward() 330 | torch.nn.utils.clip_grad_norm_(self.diffusion_prior.parameters(), 1.0) 331 | lr_scheduler.step() 332 | optimizer.step() 333 | 334 | loss_sum += loss.item() 335 | 336 | loss_epoch = loss_sum / len(dataloader) 337 | print(f'epoch: {epoch}, loss: {loss_epoch}') 338 | # lr_scheduler.step(loss) 339 | 340 | def generate( 341 | self, 342 | c_embeds=None, 343 | num_inference_steps=50, 344 | timesteps=None, 345 | guidance_scale=5.0, 346 | generator=None 347 | ): 348 | # c_embeds (batch_size, cond_dim) 349 | self.diffusion_prior.eval() 350 | N = c_embeds.shape[0] if c_embeds is not None else 1 351 | 352 | # 1. Prepare timesteps 353 | from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import retrieve_timesteps 354 | timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, self.device, timesteps) 355 | 356 | # 2. Prepare c_embeds 357 | if c_embeds is not None: 358 | c_embeds = c_embeds.to(self.device) 359 | 360 | # 3. Prepare noise 361 | h_t = torch.randn(N, self.diffusion_prior.embed_dim, generator=generator, device=self.device) 362 | 363 | # 4. 
denoising loop 364 | for _, t in tqdm(enumerate(timesteps)): 365 | t = torch.ones(h_t.shape[0], dtype=torch.float, device=self.device) * t 366 | # 4.1 noise prediction 367 | if guidance_scale == 0 or c_embeds is None: 368 | noise_pred = self.diffusion_prior(h_t, t) 369 | else: 370 | noise_pred_cond = self.diffusion_prior(h_t, t, c_embeds) 371 | noise_pred_uncond = self.diffusion_prior(h_t, t) 372 | # perform classifier-free guidance 373 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) 374 | 375 | # 4.2 compute the previous noisy sample h_t -> h_{t-1} 376 | h_t = self.scheduler.step(noise_pred, t.long().item(), h_t, generator=generator).prev_sample 377 | 378 | return h_t 379 | 380 | if __name__ == '__main__': 381 | import os 382 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 383 | # 1. test prior 384 | prior = DiffusionPriorUNet(cond_dim=1024) 385 | x = torch.randn(2, 1024) 386 | t = torch.randint(0, 1000, (2,)) 387 | c = torch.randn(2, 1024) 388 | y = prior(x, t, c) 389 | print(y.shape) 390 | 391 | 392 | 393 | -------------------------------------------------------------------------------- /Retrieval/eegdatasets_joint_subjects.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | import numpy as np 4 | import os 5 | import clip 6 | from torch.nn import functional as F 7 | import torch.nn as nn 8 | from torchvision import transforms 9 | from PIL import Image 10 | import requests 11 | 12 | # import os 13 | # proxy = 'http://10.16.35.10:13390' 14 | # os.environ['http_proxy'] = proxy 15 | # os.environ['https_proxy'] = proxy 16 | 17 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 18 | # vlmodel, preprocess = clip.load("ViT-B/32", device=device) 19 | model_type = 'ViT-H-14' 20 | import open_clip 21 | vlmodel, preprocess_train, feature_extractor = open_clip.create_model_and_transforms( 22 | model_type, pretrained='laion2b_s32b_b79k', precision='fp32', device=device) 23 | 24 | import json 25 | 26 | # Load the configuration from the JSON file 27 | config_path = "data_config.json" 28 | with open(config_path, "r") as config_file: 29 | config = json.load(config_file) 30 | 31 | # Access the paths from the config 32 | data_path = config["data_path"] 33 | img_directory_training = config["img_directory_training"] 34 | img_directory_test = config["img_directory_test"] 35 | 36 | 37 | class EEGDataset(): 38 | """ 39 | subjects = ['sub-01', 'sub-02', 'sub-05', 'sub-04', 'sub-03', 'sub-06', 'sub-07', 'sub-08', 'sub-09', 'sub-10'] 40 | """ 41 | def __init__(self, data_path, adap_subject=None, subjects=None, train=True, time_window=[0, 1.0], classes=None, pictures=None): 42 | self.data_path = data_path 43 | self.train = train 44 | self.subject_list = os.listdir(data_path) 45 | self.subjects = self.subject_list if subjects is None else subjects 46 | self.n_sub = len(self.subjects) 47 | self.time_window = time_window 48 | self.n_cls = 1654 if train else 200 49 | self.classes = classes 50 | self.pictures = pictures 51 | self.adap_subject = adap_subject # Save this parameter 52 | 53 | # Assert any subjects in subject_list 54 | assert any(sub in self.subject_list for sub in self.subjects) 55 | 56 | self.data, self.labels, self.text, self.img = self.load_data() 57 | 58 | self.data = self.extract_eeg(self.data, time_window) 59 | 60 | if self.classes is None and self.pictures is None: 61 | # Try to load the saved features if they exist 62 | features_filename = 
os.path.join(f'{model_type}_features_train.pt') if self.train else os.path.join(f'{model_type}_features_test.pt') 63 | 64 | if os.path.exists(features_filename): 65 | saved_features = torch.load(features_filename) 66 | self.text_features = saved_features['text_features'] 67 | self.img_features = saved_features['img_features'] 68 | else: 69 | self.text_features = self.Textencoder(self.text) 70 | self.img_features = self.ImageEncoder(self.img) 71 | torch.save({ 72 | 'text_features': self.text_features.cpu(), 73 | 'img_features': self.img_features.cpu(), 74 | }, features_filename) 75 | else: 76 | self.text_features = self.Textencoder(self.text) 77 | self.img_features = self.ImageEncoder(self.img) 78 | 79 | def load_data(self): 80 | data_list = [] 81 | label_list = [] 82 | texts = [] 83 | images = [] 84 | 85 | if self.train: 86 | directory = img_directory_training 87 | else: 88 | directory = img_directory_test 89 | # Get all directories in the path 90 | dirnames = [d for d in os.listdir(directory) if os.path.isdir(os.path.join(directory, d))] 91 | dirnames.sort() 92 | 93 | if self.classes is not None: 94 | dirnames = [dirnames[i] for i in self.classes] 95 | 96 | for dir in dirnames: 97 | # Try to find the first occurrence of '_' 98 | try: 99 | idx = dir.index('_') 100 | description = dir[idx+1:] # Get all content after the first '_' 101 | except ValueError: 102 | print(f"Skipped: {dir} due to no '_' found.") 103 | continue 104 | 105 | new_description = f"This picture is {description}" 106 | texts.append(new_description) 107 | 108 | if self.train: 109 | img_directory = img_directory_training # Replace with your new address 110 | else: 111 | img_directory = img_directory_test 112 | 113 | all_folders = [d for d in os.listdir(img_directory) if os.path.isdir(os.path.join(img_directory, d))] 114 | all_folders.sort() # Ensure the order of folders 115 | 116 | if self.classes is not None and self.pictures is not None: 117 | images = [] # Initialize images list 118 | for i in range(len(self.classes)): 119 | class_idx = self.classes[i] 120 | pic_idx = self.pictures[i] 121 | if class_idx < len(all_folders): 122 | folder = all_folders[class_idx] 123 | folder_path = os.path.join(img_directory, folder) 124 | all_images = [img for img in os.listdir(folder_path) if img.lower().endswith(('.png', '.jpg', '.jpeg'))] 125 | all_images.sort() 126 | if pic_idx < len(all_images): 127 | images.append(os.path.join(folder_path, all_images[pic_idx])) 128 | elif self.classes is not None and self.pictures is None: 129 | images = [] # Initialize images list 130 | for i in range(len(self.classes)): 131 | class_idx = self.classes[i] 132 | if class_idx < len(all_folders): 133 | folder = all_folders[class_idx] 134 | folder_path = os.path.join(img_directory, folder) 135 | all_images = [img for img in os.listdir(folder_path) if img.lower().endswith(('.png', '.jpg', '.jpeg'))] 136 | all_images.sort() 137 | images.extend(os.path.join(folder_path, img) for img in all_images) 138 | elif self.classes is None: 139 | images = [] # Initialize images list 140 | for folder in all_folders: 141 | folder_path = os.path.join(img_directory, folder) 142 | all_images = [img for img in os.listdir(folder_path) if img.lower().endswith(('.png', '.jpg', '.jpeg'))] 143 | all_images.sort() 144 | images.extend(os.path.join(folder_path, img) for img in all_images) 145 | else: 146 | # Handle other cases, such as mismatched lengths of self.classes and self.pictures 147 | print("Error") 148 | 149 | print("self.subjects", self.subjects) 150 | 
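        # Added note: per subject, the preprocessed THINGS-EEG arrays are organised
        # as [n_images, n_repetitions, n_channels, n_timepoints]
        # (training: [16540, 4, 17, 100]; test: [200, 80, 17, 100]), as documented
        # in the __main__ block at the bottom of this file.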
print("adap_subject", self.adap_subject) 151 | for subject in self.subjects: 152 | if self.train: 153 | # if subject == self.adap_subject: # Skip the excluded subject 154 | # continue 155 | # print("subject:", subject) 156 | file_name = 'preprocessed_eeg_training.npy' 157 | 158 | file_path = os.path.join(self.data_path, subject, file_name) 159 | data = np.load(file_path, allow_pickle=True) 160 | 161 | preprocessed_eeg_data = torch.from_numpy(data['preprocessed_eeg_data']).float().detach() 162 | times = torch.from_numpy(data['times']).detach()[50:] 163 | ch_names = data['ch_names'] # Keep as a Python list or encode appropriately 164 | 165 | n_classes = 1654 # Each class contains 10 images 166 | samples_per_class = 10 # Each class has ten samples 167 | 168 | if self.classes is not None and self.pictures is not None: 169 | for c, p in zip(self.classes, self.pictures): 170 | start_index = c * 1 + p 171 | if start_index < len(preprocessed_eeg_data): # Ensure index is within range 172 | preprocessed_eeg_data_class = preprocessed_eeg_data[start_index: start_index+1] # Select only one sample 173 | labels = torch.full((1,), c, dtype=torch.long).detach() # Add class label 174 | data_list.append(preprocessed_eeg_data_class) 175 | label_list.append(labels) # Add labels to the label list 176 | 177 | elif self.classes is not None and self.pictures is None: 178 | for c in self.classes: 179 | start_index = c * samples_per_class 180 | preprocessed_eeg_data_class = preprocessed_eeg_data[start_index: start_index+samples_per_class] 181 | labels = torch.full((samples_per_class,), c, dtype=torch.long).detach() # Add class label 182 | data_list.append(preprocessed_eeg_data_class) 183 | label_list.append(labels) 184 | 185 | else: 186 | for i in range(n_classes): 187 | start_index = i * samples_per_class 188 | preprocessed_eeg_data_class = preprocessed_eeg_data[start_index: start_index+samples_per_class] 189 | labels = torch.full((samples_per_class,), i, dtype=torch.long).detach() # Add class label 190 | data_list.append(preprocessed_eeg_data_class) 191 | label_list.append(labels) 192 | 193 | 194 | else: 195 | if subject == self.adap_subject or self.adap_subject == None: 196 | file_name = 'preprocessed_eeg_test.npy' 197 | file_path = os.path.join(self.data_path, subject, file_name) 198 | data = np.load(file_path, allow_pickle=True) 199 | preprocessed_eeg_data = torch.from_numpy(data['preprocessed_eeg_data']).float().detach() 200 | times = torch.from_numpy(data['times']).detach()[50:] 201 | ch_names = data['ch_names'] # Keep as a Python list or encode appropriately 202 | n_classes = 200 # Each class contains 1 image 203 | 204 | samples_per_class = 1 # Each class has one sample 205 | 206 | for i in range(n_classes): 207 | if self.classes is not None and i not in self.classes: # Skip if class not in the specified list 208 | continue 209 | start_index = i * samples_per_class # Update start_index for each class 210 | preprocessed_eeg_data_class = preprocessed_eeg_data[start_index:start_index+samples_per_class] 211 | labels = torch.full((samples_per_class,), i, dtype=torch.long).detach() # Add class labels 212 | preprocessed_eeg_data_class = torch.mean(preprocessed_eeg_data_class.squeeze(0), 0) 213 | data_list.append(preprocessed_eeg_data_class) 214 | label_list.append(labels) # Add labels to the label list 215 | else: 216 | continue 217 | # Data list: (subjects * classes) * (10 * 4 * 17 * 100) 218 | # Data tensor: (subjects * classes * 10 * 4) * 17 * 100 219 | if self.train: 220 | data_tensor = torch.cat(data_list, 
dim=0).view(-1, *data_list[0].shape[2:]) 221 | print("data_tensor", data_tensor.shape) 222 | else: 223 | data_tensor = torch.cat(data_list, dim=0).view(-1, *data_list[0].shape) 224 | label_tensor = torch.cat(label_list, dim=0) 225 | if self.train: 226 | # Label tensor: (subjects * classes * 10 * 4) 227 | label_tensor = label_tensor.repeat_interleave(4) 228 | if self.classes is not None: 229 | unique_values = list(label_tensor.numpy()) 230 | lis = [] 231 | for i in unique_values: 232 | if i not in lis: 233 | lis.append(i) 234 | unique_values = torch.tensor(lis) 235 | mapping = {val.item(): index for index, val in enumerate(unique_values)} 236 | label_tensor = torch.tensor([mapping[val.item()] for val in label_tensor], dtype=torch.long) 237 | 238 | else: 239 | pass 240 | 241 | self.times = times 242 | self.ch_names = ch_names 243 | 244 | print(f"Data tensor shape: {data_tensor.shape}, label tensor shape: {label_tensor.shape}, text length: {len(texts)}, image length: {len(images)}") 245 | 246 | return data_tensor, label_tensor, texts, images 247 | 248 | def extract_eeg(self, eeg_data, time_window): 249 | 250 | start, end = time_window 251 | 252 | # Get the indices of the times within the specified window 253 | indices = (self.times >= start) & (self.times <= end) 254 | # Use these indices to select the corresponding data 255 | extracted_data = eeg_data[..., indices] 256 | # print(f"extracted_data shape: {extracted_data.shape}") 257 | 258 | return extracted_data 259 | 260 | def Textencoder(self, text): 261 | # Use the preprocessor to convert text to the model's input format 262 | text_inputs = torch.cat([clip.tokenize(t) for t in text]).to(device) 263 | 264 | # Use the CLIP model to encode text 265 | with torch.no_grad(): 266 | text_features = vlmodel.encode_text(text_inputs) 267 | 268 | text_features = F.normalize(text_features, dim=-1).detach() 269 | 270 | return text_features 271 | 272 | def ImageEncoder(self, images): 273 | batch_size = 20 # Set to an appropriate value 274 | image_features_list = [] 275 | 276 | for i in range(0, len(images), batch_size): 277 | batch_images = images[i:i + batch_size] 278 | image_inputs = torch.stack([preprocess_train(Image.open(img).convert("RGB")) for img in batch_images]).to(device) 279 | 280 | with torch.no_grad(): 281 | batch_image_features = vlmodel.encode_image(image_inputs) 282 | batch_image_features /= batch_image_features.norm(dim=-1, keepdim=True) 283 | 284 | image_features_list.append(batch_image_features) 285 | 286 | image_features = torch.cat(image_features_list, dim=0) 287 | 288 | return image_features 289 | 290 | def __getitem__(self, index): 291 | # Get the data and label corresponding to "index" 292 | # index: (subjects * classes * 10 * 4) 293 | x = self.data[index] 294 | label = self.labels[index] 295 | 296 | if self.pictures is None: 297 | if self.classes is None: 298 | index_n_sub_train = self.n_cls * 10 * 4 299 | index_n_sub_test = self.n_cls * 1 * 80 300 | else: 301 | index_n_sub_test = len(self.classes)* 1 * 80 302 | index_n_sub_train = len(self.classes)* 10 * 4 303 | # text_index: classes 304 | if self.train: 305 | text_index = (index % index_n_sub_train) // (10 * 4) 306 | else: 307 | text_index = (index % index_n_sub_test) 308 | # img_index: classes * 10 309 | if self.train: 310 | img_index = (index % index_n_sub_train) // (4) 311 | else: 312 | img_index = (index % index_n_sub_test) 313 | else: 314 | if self.classes is None: 315 | index_n_sub_train = self.n_cls * 1 * 4 316 | index_n_sub_test = self.n_cls * 1 * 80 317 | else: 318 | 
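                # Added note: with specific pictures selected there is one image per
                # listed class, so the per-subject index space is len(classes) * 1 * 4
                # repetitions for training and len(classes) * 1 * 80 for test,
                # mirroring the classes-only branch above.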
index_n_sub_test = len(self.classes)* 1 * 80 319 | index_n_sub_train = len(self.classes)* 1 * 4 320 | # text_index: classes 321 | if self.train: 322 | text_index = (index % index_n_sub_train) // (1 * 4) 323 | else: 324 | text_index = (index % index_n_sub_test) 325 | # img_index: classes * 10 326 | if self.train: 327 | img_index = (index % index_n_sub_train) // (4) 328 | else: 329 | img_index = (index % index_n_sub_test) 330 | 331 | text = self.text[text_index] 332 | img = self.img[img_index] 333 | 334 | text_features = self.text_features[text_index] 335 | img_features = self.img_features[img_index] 336 | 337 | return x, label, text, text_features, img, img_features 338 | 339 | def __len__(self): 340 | return self.data.shape[0] # or self.labels.shape[0] which should be the same 341 | 342 | if __name__ == "__main__": 343 | # Instantiate the dataset and dataloader 344 | # data_path = "/home/ldy/Workspace/THINGS/EEG/osfstorage-archive" # Replace with the path to your data 345 | data_path = data_path 346 | train_dataset = EEGDataset(data_path, subjects=['sub-01'], train=True) 347 | test_dataset = EEGDataset(data_path, subjects=['sub-01'], train=False) 348 | # train_dataset = EEGDataset(data_path, adap_subject='sub-01', train=True) 349 | # test_dataset = EEGDataset(data_path, adap_subject='sub-01', train=False) 350 | # train_dataset = EEGDataset(data_path, train=True) 351 | # test_dataset = EEGDataset(data_path, train=False) 352 | # Training EEG data shape: torch.Size([16540, 4, 17, 100]) [Number of training images, repetition count, channels, EEG time points] 353 | # Testing EEG data shape: torch.Size([200, 80, 17, 100]) 354 | # 1 second 'times': array([-0.2 , -0.19, -0.18, ... , 0.76, 0.77, 0.78, 0.79])} 355 | # 17 channels 'ch_names': ['Pz', 'P3', 'P7', 'O1', 'Oz', 'O2', 'P4', 'P8', 'P1', 'P5', 'PO7', 'PO3', 'POz', 'PO4', 'PO8', 'P6', 'P2'] 356 | # 100 Hz 357 | train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True) 358 | test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True) 359 | 360 | i = 80*1-1 361 | x, label, text, text_features, img, img_features = test_dataset[i] 362 | print(f"Index {i}, Label: {label}, text: {text}") 363 | Image.open(img) 364 | -------------------------------------------------------------------------------- /Retrieval/eegdatasets_leaveone.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | import numpy as np 4 | import os 5 | import clip 6 | from torch.nn import functional as F 7 | import torch.nn as nn 8 | from torchvision import transforms 9 | from PIL import Image 10 | import requests 11 | 12 | import os 13 | # Note: Set http_proxy/https_proxy environment variables if needed 14 | cuda_device_count = torch.cuda.device_count() 15 | print(cuda_device_count) 16 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 17 | # vlmodel, preprocess = clip.load("ViT-B/32", device=device) 18 | model_type = 'ViT-H-14' 19 | import open_clip 20 | vlmodel, preprocess_train, feature_extractor = open_clip.create_model_and_transforms( 21 | model_type, pretrained='laion2b_s32b_b79k', precision='fp32', device = device) 22 | 23 | import json 24 | 25 | # Load the configuration from the JSON file 26 | config_path = "data_config.json" 27 | with open(config_path, "r") as config_file: 28 | config = json.load(config_file) 29 | 30 | # Access the paths from the config 31 | data_path = config["data_path"] 32 | img_directory_training = 
config["img_directory_training"] 33 | img_directory_test = config["img_directory_test"] 34 | 35 | 36 | class EEGDataset(): 37 | """ 38 | subjects = ['sub-01', 'sub-02', 'sub-05', 'sub-04', 'sub-03', 'sub-06', 'sub-07', 'sub-08', 'sub-09', 'sub-10'] 39 | """ 40 | def __init__(self, data_path, exclude_subject=None, subjects=None, train=True, time_window=[0, 1.0], classes = None, pictures = None, val_size=None): 41 | self.data_path = data_path 42 | self.train = train 43 | self.subject_list = os.listdir(data_path) 44 | self.subjects = self.subject_list if subjects is None else subjects 45 | self.n_sub = len(self.subjects) 46 | self.time_window = time_window 47 | self.n_cls = 1654 if train else 200 48 | self.classes = classes 49 | self.pictures = pictures 50 | self.exclude_subject = exclude_subject 51 | self.val_size = val_size 52 | # assert any subjects in subject_list 53 | assert any(sub in self.subject_list for sub in self.subjects) 54 | 55 | self.data, self.labels, self.text, self.img = self.load_data() 56 | 57 | self.data = self.extract_eeg(self.data, time_window) 58 | 59 | 60 | if self.classes is None and self.pictures is None: 61 | # Try to load the saved features if they exist 62 | features_filename = os.path.join(f'{model_type}_features_train.pt') if self.train else os.path.join(f'{model_type}_features_test.pt') 63 | 64 | if os.path.exists(features_filename) : 65 | saved_features = torch.load(features_filename) 66 | self.text_features = saved_features['text_features'] 67 | self.img_features = saved_features['img_features'] 68 | else: 69 | self.text_features = self.Textencoder(self.text) 70 | self.img_features = self.ImageEncoder(self.img) 71 | torch.save({ 72 | 'text_features': self.text_features.cpu(), 73 | 'img_features': self.img_features.cpu(), 74 | }, features_filename) 75 | else: 76 | self.text_features = self.Textencoder(self.text) 77 | self.img_features = self.ImageEncoder(self.img) 78 | 79 | def load_data(self): 80 | data_list = [] 81 | label_list = [] 82 | texts = [] 83 | images = [] 84 | 85 | if self.train: 86 | directory = img_directory_training 87 | else: 88 | directory = img_directory_test 89 | 90 | dirnames = [d for d in os.listdir(directory) if os.path.isdir(os.path.join(directory, d))] 91 | dirnames.sort() 92 | 93 | if self.classes is not None: 94 | dirnames = [dirnames[i] for i in self.classes] 95 | 96 | for dir in dirnames: 97 | 98 | try: 99 | idx = dir.index('_') 100 | description = dir[idx+1:] 101 | except ValueError: 102 | print(f"Skipped: {dir} due to no '_' found.") 103 | continue 104 | 105 | new_description = f"This picture is {description}" 106 | texts.append(new_description) 107 | 108 | if self.train: 109 | img_directory = img_directory_training 110 | else: 111 | img_directory = img_directory_test 112 | 113 | all_folders = [d for d in os.listdir(img_directory) if os.path.isdir(os.path.join(img_directory, d))] 114 | all_folders.sort() 115 | 116 | if self.classes is not None and self.pictures is not None: 117 | images = [] 118 | for i in range(len(self.classes)): 119 | class_idx = self.classes[i] 120 | pic_idx = self.pictures[i] 121 | if class_idx < len(all_folders): 122 | folder = all_folders[class_idx] 123 | folder_path = os.path.join(img_directory, folder) 124 | all_images = [img for img in os.listdir(folder_path) if img.lower().endswith(('.png', '.jpg', '.jpeg'))] 125 | all_images.sort() 126 | if pic_idx < len(all_images): 127 | images.append(os.path.join(folder_path, all_images[pic_idx])) 128 | elif self.classes is not None and self.pictures is None: 129 | 
images = [] 130 | for i in range(len(self.classes)): 131 | class_idx = self.classes[i] 132 | if class_idx < len(all_folders): 133 | folder = all_folders[class_idx] 134 | folder_path = os.path.join(img_directory, folder) 135 | all_images = [img for img in os.listdir(folder_path) if img.lower().endswith(('.png', '.jpg', '.jpeg'))] 136 | all_images.sort() 137 | images.extend(os.path.join(folder_path, img) for img in all_images) 138 | elif self.classes is None: 139 | images = [] 140 | for folder in all_folders: 141 | folder_path = os.path.join(img_directory, folder) 142 | all_images = [img for img in os.listdir(folder_path) if img.lower().endswith(('.png', '.jpg', '.jpeg'))] 143 | all_images.sort() 144 | images.extend(os.path.join(folder_path, img) for img in all_images) 145 | else: 146 | 147 | print("Error") 148 | 149 | print("self.subjects", self.subjects) 150 | print("exclude_subject", self.exclude_subject) 151 | for subject in self.subjects: 152 | if self.train: 153 | if subject == self.exclude_subject: 154 | continue 155 | # print("subject:", subject) 156 | file_name = 'preprocessed_eeg_training.npy' 157 | 158 | file_path = os.path.join(self.data_path, subject, file_name) 159 | data = np.load(file_path, allow_pickle=True) 160 | 161 | preprocessed_eeg_data = torch.from_numpy(data['preprocessed_eeg_data']).float().detach() 162 | times = torch.from_numpy(data['times']).detach()[50:] 163 | ch_names = data['ch_names'] 164 | 165 | n_classes = 1654 166 | samples_per_class = 10 167 | 168 | if self.classes is not None and self.pictures is not None: 169 | for c, p in zip(self.classes, self.pictures): 170 | start_index = c * 1 + p 171 | if start_index < len(preprocessed_eeg_data): 172 | preprocessed_eeg_data_class = preprocessed_eeg_data[start_index: start_index+1] 173 | labels = torch.full((1,), c, dtype=torch.long).detach() 174 | data_list.append(preprocessed_eeg_data_class) 175 | label_list.append(labels) 176 | 177 | elif self.classes is not None and self.pictures is None: 178 | for c in self.classes: 179 | start_index = c * samples_per_class 180 | preprocessed_eeg_data_class = preprocessed_eeg_data[start_index: start_index+samples_per_class] 181 | labels = torch.full((samples_per_class,), c, dtype=torch.long).detach() 182 | data_list.append(preprocessed_eeg_data_class) 183 | label_list.append(labels) 184 | 185 | else: 186 | for i in range(n_classes): 187 | start_index = i * samples_per_class 188 | # if self.exclude_subject==None: 189 | # preprocessed_eeg_data_class = preprocessed_eeg_data[start_index: start_index+samples_per_class] 190 | # else: 191 | preprocessed_eeg_data_class = preprocessed_eeg_data[start_index: start_index+samples_per_class] 192 | # print("preprocessed_eeg_data_class", preprocessed_eeg_data_class.shape) 193 | # preprocessed_eeg_data_class = torch.mean(preprocessed_eeg_data_class, 1) 194 | # preprocessed_eeg_data_class = torch.mean(preprocessed_eeg_data_class, 0) 195 | # print("preprocessed_eeg_data_class", preprocessed_eeg_data_class.shape) 196 | labels = torch.full((samples_per_class,), i, dtype=torch.long).detach() 197 | data_list.append(preprocessed_eeg_data_class) 198 | label_list.append(labels) 199 | 200 | 201 | else: 202 | if subject == self.exclude_subject or self.exclude_subject==None: 203 | file_name = 'preprocessed_eeg_test.npy' 204 | file_path = os.path.join(self.data_path, subject, file_name) 205 | data = np.load(file_path, allow_pickle=True) 206 | preprocessed_eeg_data = torch.from_numpy(data['preprocessed_eeg_data']).float().detach() 207 | times = 
torch.from_numpy(data['times']).detach()[50:] 208 | ch_names = data['ch_names'] 209 | n_classes = 200 # Each class contains 1 images 210 | 211 | samples_per_class = 1 212 | 213 | for i in range(n_classes): 214 | if self.classes is not None and i not in self.classes: # If we've defined specific classes and the current class is not in the list, skip 215 | continue 216 | start_index = i * samples_per_class # Update start_index for each class 217 | preprocessed_eeg_data_class = preprocessed_eeg_data[start_index:start_index+samples_per_class] 218 | # print("preprocessed_eeg_data_class", preprocessed_eeg_data_class.shape) 219 | labels = torch.full((samples_per_class,), i, dtype=torch.long).detach() # Add class labels 220 | preprocessed_eeg_data_class = torch.mean(preprocessed_eeg_data_class.squeeze(0), 0) 221 | # print("preprocessed_eeg_data_class", preprocessed_eeg_data_class.shape) 222 | data_list.append(preprocessed_eeg_data_class) 223 | label_list.append(labels) # Add labels to the label list 224 | else: 225 | continue 226 | # datalist: (subjects * classes) * (10 * 4 * 17 * 100) 227 | # data_tensor: (subjects * classes * 10 * 4) * 17 * 100 228 | # data_list = np.mean(data_list, ) 229 | # print("data_list", len(data_list)) 230 | if self.train: 231 | # print("data_list", *data_list[0].shape[1:]) 232 | data_tensor = torch.cat(data_list, dim=0).view(-1, *data_list[0].shape[2:]) 233 | # data_tensor = torch.cat(data_list, dim=0).view(-1, *data_list[0].shape[1:]) 234 | # data_tensor = torch.cat(data_list, dim=0).view(-1, *data_list[0].shape) 235 | # print("label_tensor", label_tensor.shape) 236 | print("data_tensor", data_tensor.shape) 237 | else: 238 | data_tensor = torch.cat(data_list, dim=0).view(-1, *data_list[0].shape) 239 | # label_tensor = torch.cat(label_list, dim=0) 240 | # print("label_tensor", label_tensor.shape) 241 | # data_tensor = torch.cat(data_list, dim=0).view(-1, *data_list[0].shape[2:]) 242 | # print("data_tensor", data_tensor.shape) 243 | # label_list: (subjects * classes) * 10 244 | # label_tensor: (subjects * classes * 10) 245 | # print("label_tensor = torch.cat(label_list, dim=0)") 246 | # print(label_list) 247 | label_tensor = torch.cat(label_list, dim=0) 248 | # label_tensor = torch.cat(label_list, dim=0) 249 | # print(label_tensor[:300]) 250 | if self.train: 251 | # label_tensor: (subjects * classes * 10 * 4) 252 | label_tensor = label_tensor.repeat_interleave(4) 253 | if self.classes is not None: 254 | unique_values = list(label_tensor.numpy()) 255 | lis = [] 256 | for i in unique_values: 257 | if i not in lis: 258 | lis.append(i) 259 | unique_values = torch.tensor(lis) 260 | mapping = {val.item(): index for index, val in enumerate(unique_values)} 261 | label_tensor = torch.tensor([mapping[val.item()] for val in label_tensor], dtype=torch.long) 262 | 263 | else: 264 | # label_tensor = label_tensor.repeat_interleave(80) 265 | # if self.classes is not None: 266 | # unique_values = torch.unique(label_tensor, sorted=False) 267 | 268 | # mapping = {val.item(): index for index, val in enumerate(torch.flip(unique_values, [0]))} 269 | # label_tensor = torch.tensor([mapping[val.item()] for val in label_tensor], dtype=torch.long) 270 | pass 271 | 272 | 273 | self.times = times 274 | self.ch_names = ch_names 275 | 276 | print(f"Data tensor shape: {data_tensor.shape}, label tensor shape: {label_tensor.shape}, text length: {len(texts)}, image length: {len(images)}") 277 | 278 | return data_tensor, label_tensor, texts, images 279 | 280 | def extract_eeg(self, eeg_data, time_window): 
281 | 282 | start, end = time_window 283 | 284 | # Get the indices of the times within the specified window 285 | indices = (self.times >= start) & (self.times <= end) 286 | # print("self.times", self.times.shape) 287 | # print("indices", indices) 288 | # print("indices", indices.shape) 289 | # print("eeg_data", eeg_data.shape) 290 | # Use these indices to select the corresponding data 291 | extracted_data = eeg_data[..., indices] 292 | # print(f"extracted_data shape: {extracted_data.shape}") 293 | 294 | return extracted_data 295 | 296 | def Textencoder(self, text): 297 | 298 | text_inputs = torch.cat([clip.tokenize(t) for t in text]).to(device) 299 | # print("text_inputs", text_inputs) 300 | 301 | with torch.no_grad(): 302 | text_features = vlmodel.encode_text(text_inputs) 303 | 304 | text_features = F.normalize(text_features, dim=-1).detach() 305 | 306 | return text_features 307 | 308 | def ImageEncoder(self,images): 309 | batch_size = 20 310 | image_features_list = [] 311 | 312 | for i in range(0, len(images), batch_size): 313 | batch_images = images[i:i + batch_size] 314 | image_inputs = torch.stack([preprocess_train(Image.open(img).convert("RGB")) for img in batch_images]).to(device) 315 | 316 | with torch.no_grad(): 317 | batch_image_features = vlmodel.encode_image(image_inputs) 318 | batch_image_features /= batch_image_features.norm(dim=-1, keepdim=True) 319 | 320 | image_features_list.append(batch_image_features) 321 | 322 | image_features = torch.cat(image_features_list, dim=0) 323 | 324 | return image_features 325 | 326 | def __getitem__(self, index): 327 | # Get the data and label corresponding to "index" 328 | # index: (subjects * classes * 10 * 4) 329 | x = self.data[index] 330 | label = self.labels[index] 331 | 332 | if self.pictures is None: 333 | if self.classes is None: 334 | index_n_sub_train = self.n_cls * 10 * 4 335 | index_n_sub_test = self.n_cls * 1 * 80 336 | else: 337 | index_n_sub_test = len(self.classes)* 1 * 80 338 | index_n_sub_train = len(self.classes)* 10 * 4 339 | # text_index: classes 340 | if self.train: 341 | text_index = (index % index_n_sub_train) // (10 * 4) 342 | else: 343 | text_index = (index % index_n_sub_test) 344 | # img_index: classes * 10 345 | if self.train: 346 | img_index = (index % index_n_sub_train) // (4) 347 | else: 348 | img_index = (index % index_n_sub_test) 349 | else: 350 | if self.classes is None: 351 | index_n_sub_train = self.n_cls * 1 * 4 352 | index_n_sub_test = self.n_cls * 1 * 80 353 | else: 354 | index_n_sub_test = len(self.classes)* 1 * 80 355 | index_n_sub_train = len(self.classes)* 1 * 4 356 | # text_index: classes 357 | if self.train: 358 | text_index = (index % index_n_sub_train) // (1 * 4) 359 | else: 360 | text_index = (index % index_n_sub_test) 361 | # img_index: classes * 10 362 | if self.train: 363 | img_index = (index % index_n_sub_train) // (4) 364 | else: 365 | img_index = (index % index_n_sub_test) 366 | # print("text_index", text_index) 367 | # print("self.text", self.text) 368 | # print("self.text", len(self.text)) 369 | text = self.text[text_index] 370 | img = self.img[img_index] 371 | 372 | text_features = self.text_features[text_index] 373 | img_features = self.img_features[img_index] 374 | 375 | return x, label, text, text_features, img, img_features 376 | 377 | def __len__(self): 378 | return self.data.shape[0] # or self.labels.shape[0] which should be the same 379 | 380 | if __name__ == "__main__": 381 | # Instantiate the dataset and dataloader 382 | # data_path = 
"/home/ldy/Workspace/THINGS/EEG/osfstorage-archive" # Replace with the path to your data 383 | data_path = data_path 384 | train_dataset = EEGDataset(data_path, subjects = ['sub-01'], train=True) 385 | test_dataset = EEGDataset(data_path, subjects = ['sub-01'], train=False) 386 | # train_dataset = EEGDataset(data_path, exclude_subject = 'sub-01', train=True) 387 | # test_dataset = EEGDataset(data_path, exclude_subject = 'sub-01', train=False) 388 | # train_dataset = EEGDataset(data_path, train=True) 389 | # test_dataset = EEGDataset(data_path, train=False) 390 | 391 | 392 | 393 | 394 | # 100 Hz 395 | train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True) 396 | test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True) 397 | 398 | i = 80*1-1 399 | x, label, text, text_features, img, img_features = test_dataset[i] 400 | print(f"Index {i}, Label: {label}, text: {text}") 401 | Image.open(img) 402 | 403 | 404 | 405 | -------------------------------------------------------------------------------- /Generation/eegdatasets_leaveone.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | import numpy as np 4 | import os 5 | import clip 6 | from torch.nn import functional as F 7 | import torch.nn as nn 8 | from torchvision import transforms 9 | from PIL import Image 10 | import requests 11 | 12 | import os 13 | # Note: Set http_proxy/https_proxy environment variables if needed 14 | cuda_device_count = torch.cuda.device_count() 15 | print(cuda_device_count) 16 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 17 | # vlmodel, preprocess = clip.load("ViT-B/32", device=device) 18 | model_type = 'ViT-H-14' 19 | import open_clip 20 | vlmodel, preprocess_train, feature_extractor = open_clip.create_model_and_transforms( 21 | model_type, pretrained='laion2b_s32b_b79k', precision='fp32', device = device) 22 | 23 | import json 24 | 25 | # Load the configuration from the JSON file 26 | config_path = "data_config.json" 27 | with open(config_path, "r") as config_file: 28 | config = json.load(config_file) 29 | 30 | # Access the paths from the config 31 | data_path = config["data_path"] 32 | img_directory_training = config["img_directory_training"] 33 | img_directory_test = config["img_directory_test"] 34 | 35 | 36 | class EEGDataset(): 37 | """ 38 | subjects = ['sub-01', 'sub-02', 'sub-05', 'sub-04', 'sub-03', 'sub-06', 'sub-07', 'sub-08', 'sub-09', 'sub-10'] 39 | """ 40 | def __init__(self, data_path, exclude_subject=None, subjects=None, train=True, time_window=[0, 1.0], classes = None, pictures = None, val_size=None): 41 | self.data_path = data_path 42 | self.train = train 43 | self.subject_list = os.listdir(data_path) 44 | self.subjects = self.subject_list if subjects is None else subjects 45 | self.n_sub = len(self.subjects) 46 | self.time_window = time_window 47 | self.n_cls = 1654 if train else 200 48 | self.classes = classes 49 | self.pictures = pictures 50 | self.exclude_subject = exclude_subject 51 | self.val_size = val_size 52 | # assert any subjects in subject_list 53 | assert any(sub in self.subject_list for sub in self.subjects) 54 | 55 | self.data, self.labels, self.text, self.img = self.load_data() 56 | 57 | self.data = self.extract_eeg(self.data, time_window) 58 | 59 | 60 | if self.classes is None and self.pictures is None: 61 | # Try to load the saved features if they exist 62 | features_filename = os.path.join(f'{model_type}_features_train.pt') if self.train else 
os.path.join(f'{model_type}_features_test.pt') 63 | 64 | if os.path.exists(features_filename) : 65 | saved_features = torch.load(features_filename) 66 | self.text_features = saved_features['text_features'] 67 | self.img_features = saved_features['img_features'] 68 | else: 69 | self.text_features = self.Textencoder(self.text) 70 | self.img_features = self.ImageEncoder(self.img) 71 | torch.save({ 72 | 'text_features': self.text_features.cpu(), 73 | 'img_features': self.img_features.cpu(), 74 | }, features_filename) 75 | else: 76 | self.text_features = self.Textencoder(self.text) 77 | self.img_features = self.ImageEncoder(self.img) 78 | 79 | def load_data(self): 80 | data_list = [] 81 | label_list = [] 82 | texts = [] 83 | images = [] 84 | 85 | if self.train: 86 | directory = img_directory_training 87 | else: 88 | directory = img_directory_test 89 | 90 | dirnames = [d for d in os.listdir(directory) if os.path.isdir(os.path.join(directory, d))] 91 | dirnames.sort() 92 | 93 | if self.classes is not None: 94 | dirnames = [dirnames[i] for i in self.classes] 95 | 96 | for dir in dirnames: 97 | 98 | try: 99 | idx = dir.index('_') 100 | description = dir[idx+1:] 101 | except ValueError: 102 | print(f"Skipped: {dir} due to no '_' found.") 103 | continue 104 | 105 | new_description = f"This picture is {description}" 106 | texts.append(new_description) 107 | 108 | if self.train: 109 | img_directory = img_directory_training 110 | else: 111 | img_directory = img_directory_test 112 | 113 | all_folders = [d for d in os.listdir(img_directory) if os.path.isdir(os.path.join(img_directory, d))] 114 | all_folders.sort() 115 | 116 | if self.classes is not None and self.pictures is not None: 117 | images = [] 118 | for i in range(len(self.classes)): 119 | class_idx = self.classes[i] 120 | pic_idx = self.pictures[i] 121 | if class_idx < len(all_folders): 122 | folder = all_folders[class_idx] 123 | folder_path = os.path.join(img_directory, folder) 124 | all_images = [img for img in os.listdir(folder_path) if img.lower().endswith(('.png', '.jpg', '.jpeg'))] 125 | all_images.sort() 126 | if pic_idx < len(all_images): 127 | images.append(os.path.join(folder_path, all_images[pic_idx])) 128 | elif self.classes is not None and self.pictures is None: 129 | images = [] 130 | for i in range(len(self.classes)): 131 | class_idx = self.classes[i] 132 | if class_idx < len(all_folders): 133 | folder = all_folders[class_idx] 134 | folder_path = os.path.join(img_directory, folder) 135 | all_images = [img for img in os.listdir(folder_path) if img.lower().endswith(('.png', '.jpg', '.jpeg'))] 136 | all_images.sort() 137 | images.extend(os.path.join(folder_path, img) for img in all_images) 138 | elif self.classes is None: 139 | images = [] 140 | for folder in all_folders: 141 | folder_path = os.path.join(img_directory, folder) 142 | all_images = [img for img in os.listdir(folder_path) if img.lower().endswith(('.png', '.jpg', '.jpeg'))] 143 | all_images.sort() 144 | images.extend(os.path.join(folder_path, img) for img in all_images) 145 | else: 146 | 147 | print("Error") 148 | 149 | print("self.subjects", self.subjects) 150 | print("exclude_subject", self.exclude_subject) 151 | for subject in self.subjects: 152 | if self.train: 153 | if subject == self.exclude_subject: 154 | continue 155 | # print("subject:", subject) 156 | file_name = 'preprocessed_eeg_training.npy' 157 | 158 | file_path = os.path.join(self.data_path, subject, file_name) 159 | data = np.load(file_path, allow_pickle=True) 160 | 161 | preprocessed_eeg_data = 
torch.from_numpy(data['preprocessed_eeg_data']).float().detach() 162 | times = torch.from_numpy(data['times']).detach()[50:] 163 | ch_names = data['ch_names'] 164 | 165 | n_classes = 1654 166 | samples_per_class = 10 167 | 168 | if self.classes is not None and self.pictures is not None: 169 | for c, p in zip(self.classes, self.pictures): 170 | start_index = c * 1 + p 171 | if start_index < len(preprocessed_eeg_data): 172 | preprocessed_eeg_data_class = preprocessed_eeg_data[start_index: start_index+1] 173 | labels = torch.full((1,), c, dtype=torch.long).detach() 174 | data_list.append(preprocessed_eeg_data_class) 175 | label_list.append(labels) 176 | 177 | elif self.classes is not None and self.pictures is None: 178 | for c in self.classes: 179 | start_index = c * samples_per_class 180 | preprocessed_eeg_data_class = preprocessed_eeg_data[start_index: start_index+samples_per_class] 181 | labels = torch.full((samples_per_class,), c, dtype=torch.long).detach() 182 | data_list.append(preprocessed_eeg_data_class) 183 | label_list.append(labels) 184 | 185 | else: 186 | for i in range(n_classes): 187 | start_index = i * samples_per_class 188 | # if self.exclude_subject==None: 189 | # preprocessed_eeg_data_class = preprocessed_eeg_data[start_index: start_index+samples_per_class] 190 | # else: 191 | preprocessed_eeg_data_class = preprocessed_eeg_data[start_index: start_index+samples_per_class] 192 | # print("preprocessed_eeg_data_class", preprocessed_eeg_data_class.shape) 193 | # preprocessed_eeg_data_class = torch.mean(preprocessed_eeg_data_class, 1) 194 | # preprocessed_eeg_data_class = torch.mean(preprocessed_eeg_data_class, 0) 195 | # print("preprocessed_eeg_data_class", preprocessed_eeg_data_class.shape) 196 | labels = torch.full((samples_per_class,), i, dtype=torch.long).detach() 197 | data_list.append(preprocessed_eeg_data_class) 198 | label_list.append(labels) 199 | 200 | 201 | else: 202 | if subject == self.exclude_subject or self.exclude_subject==None: 203 | file_name = 'preprocessed_eeg_test.npy' 204 | file_path = os.path.join(self.data_path, subject, file_name) 205 | data = np.load(file_path, allow_pickle=True) 206 | preprocessed_eeg_data = torch.from_numpy(data['preprocessed_eeg_data']).float().detach() 207 | times = torch.from_numpy(data['times']).detach()[50:] 208 | ch_names = data['ch_names'] 209 | n_classes = 200 # Each class contains 1 images 210 | 211 | samples_per_class = 1 212 | 213 | for i in range(n_classes): 214 | if self.classes is not None and i not in self.classes: # If we've defined specific classes and the current class is not in the list, skip 215 | continue 216 | start_index = i * samples_per_class # Update start_index for each class 217 | preprocessed_eeg_data_class = preprocessed_eeg_data[start_index:start_index+samples_per_class] 218 | # print("preprocessed_eeg_data_class", preprocessed_eeg_data_class.shape) 219 | labels = torch.full((samples_per_class,), i, dtype=torch.long).detach() # Add class labels 220 | preprocessed_eeg_data_class = torch.mean(preprocessed_eeg_data_class.squeeze(0), 0) 221 | # print("preprocessed_eeg_data_class", preprocessed_eeg_data_class.shape) 222 | data_list.append(preprocessed_eeg_data_class) 223 | label_list.append(labels) # Add labels to the label list 224 | else: 225 | continue 226 | # datalist: (subjects * classes) * (10 * 4 * 17 * 100) 227 | # data_tensor: (subjects * classes * 10 * 4) * 17 * 100 228 | # data_list = np.mean(data_list, ) 229 | # print("data_list", len(data_list)) 230 | if self.train: 231 | # print("data_list", 
*data_list[0].shape[1:]) 232 | data_tensor = torch.cat(data_list, dim=0).view(-1, *data_list[0].shape[2:]) 233 | # data_tensor = torch.cat(data_list, dim=0).view(-1, *data_list[0].shape[1:]) 234 | # data_tensor = torch.cat(data_list, dim=0).view(-1, *data_list[0].shape) 235 | # print("label_tensor", label_tensor.shape) 236 | print("data_tensor", data_tensor.shape) 237 | else: 238 | data_tensor = torch.cat(data_list, dim=0).view(-1, *data_list[0].shape) 239 | # label_tensor = torch.cat(label_list, dim=0) 240 | # print("label_tensor", label_tensor.shape) 241 | # data_tensor = torch.cat(data_list, dim=0).view(-1, *data_list[0].shape[2:]) 242 | # print("data_tensor", data_tensor.shape) 243 | # label_list: (subjects * classes) * 10 244 | # label_tensor: (subjects * classes * 10) 245 | # print("label_tensor = torch.cat(label_list, dim=0)") 246 | # print(label_list) 247 | label_tensor = torch.cat(label_list, dim=0) 248 | # label_tensor = torch.cat(label_list, dim=0) 249 | # print(label_tensor[:300]) 250 | if self.train: 251 | # label_tensor: (subjects * classes * 10 * 4) 252 | label_tensor = label_tensor.repeat_interleave(4) 253 | if self.classes is not None: 254 | unique_values = list(label_tensor.numpy()) 255 | lis = [] 256 | for i in unique_values: 257 | if i not in lis: 258 | lis.append(i) 259 | unique_values = torch.tensor(lis) 260 | mapping = {val.item(): index for index, val in enumerate(unique_values)} 261 | label_tensor = torch.tensor([mapping[val.item()] for val in label_tensor], dtype=torch.long) 262 | 263 | else: 264 | # label_tensor = label_tensor.repeat_interleave(80) 265 | # if self.classes is not None: 266 | # unique_values = torch.unique(label_tensor, sorted=False) 267 | 268 | # mapping = {val.item(): index for index, val in enumerate(torch.flip(unique_values, [0]))} 269 | # label_tensor = torch.tensor([mapping[val.item()] for val in label_tensor], dtype=torch.long) 270 | pass 271 | 272 | 273 | self.times = times 274 | self.ch_names = ch_names 275 | 276 | print(f"Data tensor shape: {data_tensor.shape}, label tensor shape: {label_tensor.shape}, text length: {len(texts)}, image length: {len(images)}") 277 | 278 | return data_tensor, label_tensor, texts, images 279 | 280 | def extract_eeg(self, eeg_data, time_window): 281 | 282 | start, end = time_window 283 | 284 | # Get the indices of the times within the specified window 285 | indices = (self.times >= start) & (self.times <= end) 286 | # print("self.times", self.times.shape) 287 | # print("indices", indices) 288 | # print("indices", indices.shape) 289 | # print("eeg_data", eeg_data.shape) 290 | # Use these indices to select the corresponding data 291 | extracted_data = eeg_data[..., indices] 292 | # print(f"extracted_data shape: {extracted_data.shape}") 293 | 294 | return extracted_data 295 | 296 | def Textencoder(self, text): 297 | 298 | text_inputs = torch.cat([clip.tokenize(t) for t in text]).to(device) 299 | # print("text_inputs", text_inputs) 300 | 301 | with torch.no_grad(): 302 | text_features = vlmodel.encode_text(text_inputs) 303 | 304 | text_features = F.normalize(text_features, dim=-1).detach() 305 | 306 | return text_features 307 | 308 | def ImageEncoder(self,images): 309 | batch_size = 20 310 | image_features_list = [] 311 | 312 | for i in range(0, len(images), batch_size): 313 | batch_images = images[i:i + batch_size] 314 | image_inputs = torch.stack([preprocess_train(Image.open(img).convert("RGB")) for img in batch_images]).to(device) 315 | 316 | with torch.no_grad(): 317 | batch_image_features = 
vlmodel.encode_image(image_inputs) 318 | # batch_image_features /= batch_image_features.norm(dim=-1, keepdim=True) #use unnormalized clip embedding in reconstruction training 319 | 320 | image_features_list.append(batch_image_features) 321 | 322 | image_features = torch.cat(image_features_list, dim=0) 323 | 324 | return image_features 325 | 326 | def __getitem__(self, index): 327 | # Get the data and label corresponding to "index" 328 | # index: (subjects * classes * 10 * 4) 329 | x = self.data[index] 330 | label = self.labels[index] 331 | 332 | if self.pictures is None: 333 | if self.classes is None: 334 | index_n_sub_train = self.n_cls * 10 * 4 335 | index_n_sub_test = self.n_cls * 1 * 80 336 | else: 337 | index_n_sub_test = len(self.classes)* 1 * 80 338 | index_n_sub_train = len(self.classes)* 10 * 4 339 | # text_index: classes 340 | if self.train: 341 | text_index = (index % index_n_sub_train) // (10 * 4) 342 | else: 343 | text_index = (index % index_n_sub_test) 344 | # img_index: classes * 10 345 | if self.train: 346 | img_index = (index % index_n_sub_train) // (4) 347 | else: 348 | img_index = (index % index_n_sub_test) 349 | else: 350 | if self.classes is None: 351 | index_n_sub_train = self.n_cls * 1 * 4 352 | index_n_sub_test = self.n_cls * 1 * 80 353 | else: 354 | index_n_sub_test = len(self.classes)* 1 * 80 355 | index_n_sub_train = len(self.classes)* 1 * 4 356 | # text_index: classes 357 | if self.train: 358 | text_index = (index % index_n_sub_train) // (1 * 4) 359 | else: 360 | text_index = (index % index_n_sub_test) 361 | # img_index: classes * 10 362 | if self.train: 363 | img_index = (index % index_n_sub_train) // (4) 364 | else: 365 | img_index = (index % index_n_sub_test) 366 | # print("text_index", text_index) 367 | # print("self.text", self.text) 368 | # print("self.text", len(self.text)) 369 | text = self.text[text_index] 370 | img = self.img[img_index] 371 | 372 | text_features = self.text_features[text_index] 373 | img_features = self.img_features[img_index] 374 | 375 | return x, label, text, text_features, img, img_features 376 | 377 | def __len__(self): 378 | return self.data.shape[0] # or self.labels.shape[0] which should be the same 379 | 380 | if __name__ == "__main__": 381 | # Instantiate the dataset and dataloader 382 | # data_path = "/home/ldy/Workspace/THINGS/EEG/osfstorage-archive" # Replace with the path to your data 383 | data_path = data_path 384 | train_dataset = EEGDataset(data_path, subjects = ['sub-01'], train=True) 385 | test_dataset = EEGDataset(data_path, subjects = ['sub-01'], train=False) 386 | # train_dataset = EEGDataset(data_path, exclude_subject = 'sub-01', train=True) 387 | # test_dataset = EEGDataset(data_path, exclude_subject = 'sub-01', train=False) 388 | # train_dataset = EEGDataset(data_path, train=True) 389 | # test_dataset = EEGDataset(data_path, train=False) 390 | 391 | 392 | 393 | 394 | # 100 Hz 395 | train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True) 396 | test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True) 397 | 398 | i = 80*1-1 399 | x, label, text, text_features, img, img_features = test_dataset[i] 400 | print(f"Index {i}, Label: {label}, text: {text}") 401 | Image.open(img) 402 | 403 | 404 | 405 | --------------------------------------------------------------------------------
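Editor's note: the two eegdatasets_leaveone.py modules above are near-identical copies; the Generation copy differs mainly in ImageEncoder, where the normalization step is commented out because reconstruction training uses the unnormalized CLIP image embeddings. Below is a minimal, hedged sketch of the leave-one-subject-out split the filename refers to (the commented-out lines in __main__ hint at the same pattern). It assumes the paths in data_config.json are valid and treats 'sub-01' as an arbitrary held-out subject.

from torch.utils.data import DataLoader
from eegdatasets_leaveone import EEGDataset, data_path  # module shown above; importing it also loads the open_clip model

held_out = 'sub-01'                                      # arbitrary held-out subject
# train=True skips the excluded subject; train=False keeps only that subject's test session
train_dataset = EEGDataset(data_path, exclude_subject=held_out, train=True)
test_dataset = EEGDataset(data_path, exclude_subject=held_out, train=False)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

# each item is (eeg, label, caption, caption CLIP feature, image path, image CLIP feature)
eeg, label, text, text_features, img_path, img_features = next(iter(train_loader))
print(eeg.shape, label.shape, text_features.shape, img_features.shape)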
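In the default configuration (classes=None, pictures=None), __getitem__ decodes the flat sample index back to a caption and an image with integer division. The illustrative helper below mirrors that arithmetic for the training split (1654 classes, 10 images per class, 4 repetitions per image per subject) so the mapping can be checked in isolation; on the test split the 80 repetitions of each test image are averaged into a single trial inside load_data, so for a single test subject the caption and image indices reduce to the sample index itself.

# Mirrors the train-branch index arithmetic in __getitem__ (classes=None, pictures=None).
n_cls, n_img, n_rep = 1654, 10, 4
per_subject = n_cls * n_img * n_rep            # 66,160 EEG trials per subject

def train_text_and_img_index(index):
    """Illustrative re-implementation of the index mapping used in __getitem__."""
    i = index % per_subject                    # fold a multi-subject dataset back onto one subject
    text_index = i // (n_img * n_rep)          # one caption per class
    img_index = i // n_rep                     # one image per (class, image) pair
    return text_index, img_index

# Trial 45 of the first subject belongs to class 1, second image of that class.
assert train_text_and_img_index(45) == (1, 11)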
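extract_eeg keeps only the samples whose timestamps fall inside time_window by building a boolean mask over self.times and indexing the last axis. The standalone sketch below shows the same operation on dummy data; the 250 Hz rate and the channel count are assumptions made only for illustration, so treat the exact shapes as placeholders.

import torch

sfreq = 250                                    # assumed sampling rate, for illustration only
times = torch.arange(250) / sfreq - 0.2        # dummy time axis in seconds: -0.2 s .. 0.796 s
eeg = torch.randn(64, 17, times.shape[0])      # (trials, channels, samples) -- illustrative shape only

start, end = 0.0, 1.0                          # the default time_window
mask = (times >= start) & (times <= end)       # same boolean mask as extract_eeg
cropped = eeg[..., mask]                       # keeps 200 of the 250 dummy samples
print(cropped.shape)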
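Both copies cache the encoded captions and images to f'{model_type}_features_train.pt' / f'{model_type}_features_test.pt' in the current working directory, keyed only by model type and split. The filename does not encode which image directory was used or whether the embeddings were normalized, so running the Retrieval and Generation copies from the same working directory can silently reuse each other's cached features. A small convenience snippet (an assumption about your workflow, not part of the repository) for clearing stale caches:

import os

model_type = 'ViT-H-14'                        # matches the model_type defined in both modules
for split in ('train', 'test'):
    cache = f'{model_type}_features_{split}.pt'
    if os.path.exists(cache):
        os.remove(cache)                       # forces Textencoder / ImageEncoder to re-run on the next load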