├── models
│   ├── __init__.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── masking.py
│   │   ├── metrics.py
│   │   ├── tools.py
│   │   └── timefeatures.py
│   ├── subject_layers
│   │   ├── __init__.py
│   │   ├── StandardNorm.py
│   │   ├── Conv_Blocks.py
│   │   ├── Crossformer_EncDec.py
│   │   ├── Transformer_EncDec.py
│   │   ├── AutoCorrelation.py
│   │   ├── Autoformer_EncDec.py
│   │   ├── FourierCorrelation.py
│   │   ├── Pyraformer_EncDec.py
│   │   ├── Embed.py
│   │   ├── ETSformer_EncDec.py
│   │   └── SelfAttention_Family.py
│   ├── loss.py
│   └── util.py
├── imgs
│   ├── encoder.png
│   ├── test_acc.png
│   ├── bs=16_test_acc.png
│   ├── fig-framework.png
│   ├── fig-genexample.png
│   └── temporal_analysis.png
├── Retrieval
│   ├── data_config.json
│   ├── eegdatasets_joint_subjects.py
│   └── eegdatasets_leaveone.py
├── Generation
│   ├── data_config.json
│   ├── image_adapter.ipynb
│   ├── diffusion_prior.py
│   └── eegdatasets_leaveone.py
├── requirements.txt
├── LICENSE
├── environment.yml
├── setup.sh
├── .gitignore
├── EEG-preprocessing
│   ├── preprocessing.py
│   └── preprocessing_utils.py
└── README.md
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/subject_layers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/imgs/encoder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ncclab-sustech/EEG_Image_decode/HEAD/imgs/encoder.png
--------------------------------------------------------------------------------
/imgs/test_acc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ncclab-sustech/EEG_Image_decode/HEAD/imgs/test_acc.png
--------------------------------------------------------------------------------
/imgs/bs=16_test_acc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ncclab-sustech/EEG_Image_decode/HEAD/imgs/bs=16_test_acc.png
--------------------------------------------------------------------------------
/imgs/fig-framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ncclab-sustech/EEG_Image_decode/HEAD/imgs/fig-framework.png
--------------------------------------------------------------------------------
/imgs/fig-genexample.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ncclab-sustech/EEG_Image_decode/HEAD/imgs/fig-genexample.png
--------------------------------------------------------------------------------
/imgs/temporal_analysis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ncclab-sustech/EEG_Image_decode/HEAD/imgs/temporal_analysis.png
--------------------------------------------------------------------------------
/Retrieval/data_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_path": "/home/ldy/Workspace/THINGS/Preprocessed_data_250Hz",
3 | "img_directory_training": "/home/ldy/Workspace/THINGS/images_set/training_images",
4 | "img_directory_test": "/home/ldy/Workspace/THINGS/images_set/test_images",
5 | "features_path":"/home/ldy/Workspace/THINGS/CLIP"
6 | }
--------------------------------------------------------------------------------
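
This config just maps the THINGS-EEG data locations to absolute paths. A minimal sketch of how a dataset script might read it (the actual loaders live in eegdatasets_joint_subjects.py / eegdatasets_leaveone.py, which are not reproduced here, so the key handling below is an assumption):

import json

# Hypothetical loader sketch; adjust the path to wherever the config sits.
with open("Retrieval/data_config.json") as f:
    cfg = json.load(f)

data_path = cfg["data_path"]                    # preprocessed EEG (250 Hz)
train_img_dir = cfg["img_directory_training"]   # THINGS training image folders
test_img_dir = cfg["img_directory_test"]        # THINGS test image folders
features_path = cfg["features_path"]            # precomputed CLIP features
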
/Generation/data_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_path": "/home/ldy/Workspace/THINGS/Preprocessed_data_250Hz",
3 | "img_directory_training": "/home/ldy/Workspace/THINGS/images_set/training_images",
4 |
5 | "img_directory_test": "/home/ldy/Workspace/Generation/generated_imgs",
6 | "features_path":"/home/ldy/Workspace/THINGS/CLIP"
7 | }
8 |
9 |
--------------------------------------------------------------------------------
/models/utils/masking.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class TriangularCausalMask():
5 | def __init__(self, B, L, device="cpu"):
6 | mask_shape = [B, 1, L, L]
7 | with torch.no_grad():
8 | self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device)
9 |
10 | @property
11 | def mask(self):
12 | return self._mask
13 |
14 |
15 | class ProbMask():
16 | def __init__(self, B, H, L, index, scores, device="cpu"):
17 | _mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1)
18 | _mask_ex = _mask[None, None, :].expand(B, H, L, scores.shape[-1])
19 | indicator = _mask_ex[torch.arange(B)[:, None, None],
20 | torch.arange(H)[None, :, None],
21 | index, :].to(device)
22 | self._mask = indicator.view(scores.shape).to(device)
23 |
24 | @property
25 | def mask(self):
26 | return self._mask
27 |
--------------------------------------------------------------------------------
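
A minimal sketch of how TriangularCausalMask is typically applied to attention scores (shapes are illustrative; the real consumers are the attention modules in SelfAttention_Family.py):

import torch

from models.utils.masking import TriangularCausalMask  # import path assumes the repo root is on PYTHONPATH

B, H, L = 2, 4, 16                                      # batch size, heads, sequence length
scores = torch.randn(B, H, L, L)                        # raw attention scores
mask = TriangularCausalMask(B, L)                       # boolean upper-triangular mask, shape [B, 1, L, L]
scores = scores.masked_fill(mask.mask, float("-inf"))   # forbid attending to future positions
attn = torch.softmax(scores, dim=-1)
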
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Core Deep Learning Framework
2 | torch==2.5.0
3 | torchvision==0.20.0
4 | torchaudio==2.5.0
5 |
6 | # Hugging Face Ecosystem
7 | transformers==4.36.0
8 | diffusers==0.30.0
9 | accelerate==1.5.2
10 | huggingface-hub==0.30.2
11 |
12 | # EEG Processing
13 | braindecode==0.8.1
14 | mne==1.9.0
15 |
16 | # CLIP and Vision
17 | clip @ git+https://github.com/openai/CLIP.git
18 | open-clip-torch
19 |
20 | # Image Generation
21 | dalle2-pytorch==1.15.6
22 | pytorch-msssim==1.0.0
23 |
24 | # Data Processing
25 | numpy==1.26.4
26 | pandas==2.3.3
27 | scipy==1.15.3
28 | scikit-learn==1.6.1
29 | h5py==3.13.0
30 |
31 | # Visualization
32 | matplotlib==3.10.7
33 | seaborn==0.13.2
34 | tqdm==4.67.1
35 |
36 | # Deep Learning Utilities
37 | einops==0.8.1
38 | info-nce-pytorch==0.1.0
39 | reformer-pytorch==1.4.4
40 |
41 | # Logging and Tracking
42 | wandb==0.19.10
43 |
44 | # Image Processing
45 | Pillow
46 | imageio==2.37.0
47 | kornia==0.8.0
48 |
49 | # Utilities
50 | ftfy==6.3.1
51 | regex==2024.11.6
52 | clip-retrieval
53 |
--------------------------------------------------------------------------------
/models/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def RSE(pred, true):
5 | return np.sqrt(np.sum((true - pred) ** 2)) / np.sqrt(np.sum((true - true.mean()) ** 2))
6 |
7 |
8 | def CORR(pred, true):
9 | u = ((true - true.mean(0)) * (pred - pred.mean(0))).sum(0)
10 | d = np.sqrt(((true - true.mean(0)) ** 2 * (pred - pred.mean(0)) ** 2).sum(0))
11 | return (u / d).mean(-1)
12 |
13 |
14 | def MAE(pred, true):
15 | return np.mean(np.abs(pred - true))
16 |
17 |
18 | def MSE(pred, true):
19 | return np.mean((pred - true) ** 2)
20 |
21 |
22 | def RMSE(pred, true):
23 | return np.sqrt(MSE(pred, true))
24 |
25 |
26 | def MAPE(pred, true):
27 | return np.mean(np.abs((pred - true) / true))
28 |
29 |
30 | def MSPE(pred, true):
31 | return np.mean(np.square((pred - true) / true))
32 |
33 |
34 | def metric(pred, true):
35 | mae = MAE(pred, true)
36 | mse = MSE(pred, true)
37 | rmse = RMSE(pred, true)
38 | mape = MAPE(pred, true)
39 | mspe = MSPE(pred, true)
40 |
41 | return mae, mse, rmse, mape, mspe
42 |
--------------------------------------------------------------------------------
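
These are the usual point-forecast metrics; a short illustrative sketch of the aggregate helper (random data, values meaningless):

import numpy as np

from models.utils.metrics import metric  # import path assumes the repo root is on PYTHONPATH

true = np.random.randn(100, 8)
pred = true + 0.1 * np.random.randn(100, 8)
mae, mse, rmse, mape, mspe = metric(pred, true)
print(f"MAE={mae:.4f}  MSE={mse:.4f}  RMSE={rmse:.4f}")
# Caveat: MAPE and MSPE divide by `true`, so they blow up when targets are near zero.
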
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 DongyangLi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: BCI
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - conda-forge
6 | - defaults
7 |
8 | dependencies:
9 | # Python version
10 | - python=3.12
11 |
12 | # Core tools
13 | - pip
14 | - ipykernel
15 | - ipython
16 | - jupyter
17 | - jupyterlab
18 |
19 | # Scientific computing (conda)
20 | - numpy=1.26.4
21 | - pandas=2.3.3
22 | - matplotlib=3.10.7
23 | - scikit-learn=1.6.1
24 | - scipy=1.15.3
25 | - seaborn=0.13.2
26 | - h5py=3.13.0
27 |
28 | # System utilities
29 | - pillow
30 | - pyyaml
31 | - tqdm=4.67.1
32 | - ffmpeg
33 |
34 | # PyTorch (CUDA 12.4)
35 | - pytorch=2.5.0
36 | - torchvision=0.20.0
37 | - torchaudio=2.5.0
38 | - pytorch-cuda=12.4
39 |
40 | # Pip dependencies
41 | - pip:
42 | # Hugging Face Ecosystem
43 | - transformers==4.36.0
44 | - diffusers==0.30.0
45 | - accelerate==1.5.2
46 | - huggingface-hub==0.30.2
47 |
48 | # EEG Processing
49 | - braindecode==0.8.1
50 | - mne==1.9.0
51 |
52 | # CLIP and Vision
53 | - git+https://github.com/openai/CLIP.git
54 | - open-clip-torch
55 |
56 | # Image Generation
57 | - dalle2-pytorch==1.15.6
58 | - pytorch-msssim==1.0.0
59 |
60 | # Deep Learning Utilities
61 | - einops==0.8.1
62 | - info-nce-pytorch==0.1.0
63 | - reformer-pytorch==1.4.4
64 |
65 | # Logging and Tracking
66 | - wandb==0.19.10
67 |
68 | # Image Processing
69 | - imageio==2.37.0
70 | - kornia==0.8.0
71 |
72 | # Utilities
73 | - ftfy==6.3.1
74 | - regex==2024.11.6
75 | - clip-retrieval
76 |
--------------------------------------------------------------------------------
/models/subject_layers/StandardNorm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class Normalize(nn.Module):
6 | def __init__(self, num_features: int, eps=1e-5, affine=False, subtract_last=False, non_norm=False):
7 | """
8 | :param num_features: the number of features or channels
9 | :param eps: a value added for numerical stability
10 | :param affine: if True, RevIN has learnable affine parameters
11 | """
12 | super(Normalize, self).__init__()
13 | self.num_features = num_features
14 | self.eps = eps
15 | self.affine = affine
16 | self.subtract_last = subtract_last
17 | self.non_norm = non_norm
18 | if self.affine:
19 | self._init_params()
20 |
21 | def forward(self, x, mode: str):
22 | if mode == 'norm':
23 | self._get_statistics(x)
24 | x = self._normalize(x)
25 | elif mode == 'denorm':
26 | x = self._denormalize(x)
27 | else:
28 | raise NotImplementedError
29 | return x
30 |
31 | def _init_params(self):
32 | # initialize RevIN params: (C,)
33 | self.affine_weight = nn.Parameter(torch.ones(self.num_features))
34 | self.affine_bias = nn.Parameter(torch.zeros(self.num_features))
35 |
36 | def _get_statistics(self, x):
37 | dim2reduce = tuple(range(1, x.ndim - 1))
38 | if self.subtract_last:
39 | self.last = x[:, -1, :].unsqueeze(1)
40 | else:
41 | self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
42 | self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()
43 |
44 | def _normalize(self, x):
45 | if self.non_norm:
46 | return x
47 | if self.subtract_last:
48 | x = x - self.last
49 | else:
50 | x = x - self.mean
51 | x = x / self.stdev
52 | if self.affine:
53 | x = x * self.affine_weight
54 | x = x + self.affine_bias
55 | return x
56 |
57 | def _denormalize(self, x):
58 | if self.non_norm:
59 | return x
60 | if self.affine:
61 | x = x - self.affine_bias
62 | x = x / (self.affine_weight + self.eps * self.eps)
63 | x = x * self.stdev
64 | if self.subtract_last:
65 | x = x + self.last
66 | else:
67 | x = x + self.mean
68 | return x
69 |
--------------------------------------------------------------------------------
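
Normalize is a RevIN-style instance normalization: 'norm' standardizes each sequence with its own statistics, and 'denorm' restores the original scale from the cached mean/stdev. A minimal round-trip sketch (shapes illustrative):

import torch

from models.subject_layers.StandardNorm import Normalize  # import path assumes the repo root is on PYTHONPATH

x = torch.randn(8, 96, 63)             # [batch, time, channels]
revin = Normalize(num_features=63, affine=True)
x_norm = revin(x, mode='norm')         # per-sample, per-channel standardization
x_back = revin(x_norm, mode='denorm')  # undo it with the cached statistics
print(torch.allclose(x, x_back, atol=1e-4))
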
/models/subject_layers/Conv_Blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class Inception_Block_V1(nn.Module):
6 | def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True):
7 | super(Inception_Block_V1, self).__init__()
8 | self.in_channels = in_channels
9 | self.out_channels = out_channels
10 | self.num_kernels = num_kernels
11 | kernels = []
12 | for i in range(self.num_kernels):
13 | kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=2 * i + 1, padding=i))
14 | self.kernels = nn.ModuleList(kernels)
15 | if init_weight:
16 | self._initialize_weights()
17 |
18 | def _initialize_weights(self):
19 | for m in self.modules():
20 | if isinstance(m, nn.Conv2d):
21 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
22 | if m.bias is not None:
23 | nn.init.constant_(m.bias, 0)
24 |
25 | def forward(self, x):
26 | res_list = []
27 | for i in range(self.num_kernels):
28 | res_list.append(self.kernels[i](x))
29 | res = torch.stack(res_list, dim=-1).mean(-1)
30 | return res
31 |
32 |
33 | class Inception_Block_V2(nn.Module):
34 | def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True):
35 | super(Inception_Block_V2, self).__init__()
36 | self.in_channels = in_channels
37 | self.out_channels = out_channels
38 | self.num_kernels = num_kernels
39 | kernels = []
40 | for i in range(self.num_kernels // 2):
41 | kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=[1, 2 * i + 3], padding=[0, i + 1]))
42 | kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=[2 * i + 3, 1], padding=[i + 1, 0]))
43 | kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=1))
44 | self.kernels = nn.ModuleList(kernels)
45 | if init_weight:
46 | self._initialize_weights()
47 |
48 | def _initialize_weights(self):
49 | for m in self.modules():
50 | if isinstance(m, nn.Conv2d):
51 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
52 | if m.bias is not None:
53 | nn.init.constant_(m.bias, 0)
54 |
55 | def forward(self, x):
56 | res_list = []
57 |         for i in range(len(self.kernels)):  # iterate over however many kernels were built (avoids an index error when num_kernels is odd)
58 | res_list.append(self.kernels[i](x))
59 | res = torch.stack(res_list, dim=-1).mean(-1)
60 | return res
61 |
--------------------------------------------------------------------------------
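
Both Inception blocks run several parallel convolutions whose padding preserves the spatial size, then average the results; a quick shape check (illustrative):

import torch

from models.subject_layers.Conv_Blocks import Inception_Block_V1  # import path assumes the repo root is on PYTHONPATH

block = Inception_Block_V1(in_channels=16, out_channels=32, num_kernels=6)
x = torch.randn(4, 16, 32, 24)   # [batch, channels, height, width]
y = block(x)
print(y.shape)                   # torch.Size([4, 32, 32, 24]) -- spatial dims unchanged, channels remapped
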
/setup.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # ============================================================
3 | # EEG Image Decode - Environment Setup Script
4 | # ============================================================
5 | # This script creates a conda environment with all dependencies
6 | # for reproducing the experiments in the paper.
7 | #
8 | # Usage: . setup.sh
9 | # ============================================================
10 |
11 | set -e
12 |
13 | ENV_NAME="BCI"
14 | PYTHON_VERSION="3.12"
15 |
16 | echo "============================================================"
17 | echo "Creating conda environment: $ENV_NAME with Python $PYTHON_VERSION"
18 | echo "============================================================"
19 |
20 | # Create conda environment
21 | conda create -n $ENV_NAME python=$PYTHON_VERSION -y
22 | source "$(conda info --base)/etc/profile.d/conda.sh" && conda activate $ENV_NAME  # ensure activation also works when run as a script
23 |
24 | echo "Installing base packages via conda..."
25 | conda install numpy matplotlib tqdm scikit-image jupyterlab -y
26 | conda install -c conda-forge accelerate -y
27 |
28 | echo "Installing PyTorch ecosystem (CUDA 12.4)..."
29 | pip install torch==2.5.0 torchvision==0.20.0 torchaudio==2.5.0 --index-url https://download.pytorch.org/whl/cu124
30 |
31 | echo "Installing Hugging Face packages..."
32 | pip install transformers==4.36.0
33 | pip install diffusers==0.30.0
34 | pip install huggingface-hub==0.30.2
35 | pip install accelerate==1.5.2
36 |
37 | echo "Installing CLIP packages..."
38 | pip install git+https://github.com/openai/CLIP.git
39 | pip install open_clip_torch
40 | pip install clip-retrieval
41 |
42 | echo "Installing EEG processing packages..."
43 | pip install braindecode==0.8.1
44 | pip install mne==1.9.0
45 |
46 | echo "Installing image generation packages..."
47 | pip install dalle2-pytorch==1.15.6
48 | pip install pytorch-msssim==1.0.0
49 | pip install kornia==0.8.0
50 |
51 | echo "Installing deep learning utilities..."
52 | pip install einops==0.8.1
53 | pip install info-nce-pytorch==0.1.0
54 | pip install reformer_pytorch==1.4.4
55 |
56 | echo "Installing logging and visualization..."
57 | pip install wandb==0.19.10
58 | pip install seaborn==0.13.2
59 |
60 | echo "Installing other utilities..."
61 | pip install ftfy==6.3.1
62 | pip install regex==2024.11.6
63 | pip install h5py==3.13.0
64 | pip install pandas==2.3.3
65 | pip install imageio==2.37.0
66 | pip install scipy==1.15.3
67 | pip install scikit-learn==1.6.1
68 |
69 | echo "============================================================"
70 | echo "Environment setup complete!"
71 | echo "Activate with: conda activate $ENV_NAME"
72 | echo "============================================================"
73 |
--------------------------------------------------------------------------------
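
After running the script, a quick import check inside the activated BCI environment confirms the pinned stack is usable (this snippet is not part of the repository):

# Sanity check; run as a small script or via `python -c` inside the BCI env.
import torch, torchvision, transformers, diffusers, mne, braindecode

print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("transformers", transformers.__version__, "| diffusers", diffusers.__version__)
print("mne", mne.__version__, "| braindecode", braindecode.__version__)
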
/.gitignore:
--------------------------------------------------------------------------------
1 | # ========================================
2 | # IDE / Editors
3 | # ========================================
4 | .history/
5 | .vscode/
6 | .idea/
7 | *.swp
8 | *.swo
9 | *~
10 |
11 | # ========================================
12 | # Operating system
13 | # ========================================
14 | .DS_Store
15 | Thumbs.db
16 | desktop.ini
17 |
18 | # ========================================
19 | # Python
20 | # ========================================
21 | __pycache__/
22 | *.py[cod]
23 | *$py.class
24 | *.so
25 | .Python
26 | *.egg
27 | *.egg-info/
28 | dist/
29 | build/
30 | eggs/
31 | .eggs/
32 | *.manifest
33 | *.spec
34 |
35 | # Virtual environments
36 | env/
37 | venv/
38 | .venv/
39 | ENV/
40 |
41 | # ========================================
42 | # Jupyter Notebook
43 | # ========================================
44 | .ipynb_checkpoints/
45 | */.ipynb_checkpoints/
46 |
47 | # ========================================
48 | # Deep learning / model weights
49 | # ========================================
50 | *.ckpt
51 | *.pth
52 | *.pt
53 | *.h5
54 | *.pb
55 | *.onnx
56 | *.safetensors
57 | checkpoints/
58 | ckpts/
59 | weights/
60 |
61 | # ========================================
62 | # Experiment outputs / logs
63 | # ========================================
64 | outputs/
65 | old_outputs/
66 | logs/
67 | wandb/
68 | runs/
69 | mlruns/
70 | lightning_logs/
71 | *.log
72 |
73 | # ========================================
74 | # Data files (usually too large to commit)
75 | # ========================================
76 | # To keep a specific file, whitelist it, e.g. !data/important.npy
77 | data/
78 | datasets/
79 | *.npy
80 | *.npz
81 | *.mat
82 | *.pkl
83 | *.pickle
84 | *.hdf5
85 |
86 | # ========================================
87 | # Generated images / results
88 | # ========================================
89 | generated_imgs/
90 | generated_imgs_tensor/
91 | generation_metric_outputs/
92 | tab_generation_metric_outputs/
93 |
94 | # ========================================
95 | # Project-specific directories
96 | # ========================================
97 | # Generation module
98 | Generation/ckpts/
99 | Generation/fintune_ckpts/
100 | Generation/CLIP-dissect/
101 | Generation/centralRepo/
102 | Generation/diffusers/
103 | Generation/generated_imgs/
104 | Generation/generated_imgs_tensor/
105 | Generation/Generation/
106 | Generation/generation_metric_outputs/
107 | Generation/LAVIS/
108 | Generation/models/
109 | Generation/old_outputs/
110 | Generation/outputs/
111 | Generation/tab_generation_metric_outputs/
112 | Generation/THINGS/
113 | Generation/Visualize/
114 | Generation/wandb/
115 |
116 | # Retrieval module
117 | Retrieval/centralRepo/
118 | Retrieval/CLIP-dissect/
119 | Retrieval/conditional_outputs/
120 | Retrieval/LAVIS/
121 | Retrieval/old_outputs/
122 | Retrieval/outputs/
123 | Retrieval/Outputs/
124 | Retrieval/sliding_window_outputs/
125 | Retrieval/Visualize/
126 | Retrieval/wandb/
127 |
128 | # ========================================
129 | # Miscellaneous
130 | # ========================================
131 | *.tmp
132 | *.bak
133 | *.cache
134 | .cache/
135 | tmp/
136 | temp/
137 |
--------------------------------------------------------------------------------
/EEG-preprocessing/preprocessing.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from the Things-EEG2 preprocessing code, with a few differences.
3 | Many thanks to the original authors!
4 | https://www.sciencedirect.com/science/article/pii/S1053811922008758
5 | """
6 |
7 | """Preprocess the raw EEG data: channel selection, epoching, frequency
8 | downsampling, baseline correction, multivariate noise normalization (MVNN),
9 | sorting of the data image conditions and reshaping the data to:
10 | Image conditions × EEG repetitions × EEG channels × EEG time points.
11 | Then, the data of both test and training EEG partitions is saved.
12 |
13 | Parameters
14 | ----------
15 | sub : int
16 | Used subject.
17 | n_ses : int
18 | Number of EEG sessions.
19 | sfreq : int
20 | Downsampling frequency.
21 | mvnn_dim : str
22 |     Whether to compute the MVNN covariance matrices for each time point
23 | ('time') or for each epoch/repetition ('epochs').
24 | project_dir : str
25 | Directory of the project folder.
26 |
27 | """
28 |
29 | import argparse
30 | from preprocessing_utils import epoching
31 | from preprocessing_utils import mvnn
32 | from preprocessing_utils import save_prepr
33 |
34 |
35 | # =============================================================================
36 | # Input arguments
37 | # =============================================================================
38 | parser = argparse.ArgumentParser()
39 | parser.add_argument('--sub', default=10, type=int)
40 | parser.add_argument('--n_ses', default=4, type=int)
41 | parser.add_argument('--sfreq', default=250, type=int)
42 | parser.add_argument('--mvnn_dim', default='epochs', type=str)
43 | parser.add_argument('--project_dir', default='/home/Data/Things-EEG2/', type=str)
44 | args = parser.parse_args()
45 |
46 | print('>>> EEG data preprocessing <<<')
47 | print('\nInput arguments:')
48 | for key, val in vars(args).items():
49 | print('{:16} {}'.format(key, val))
50 |
51 | # Set random seed for reproducible results
52 | seed = 20200220
53 |
54 |
55 | # =============================================================================
56 | # Epoch and sort the data
57 | # =============================================================================
58 | # Channel selection, epoching, baseline correction and frequency downsampling of
59 | # the test and training data partitions.
60 | # Then, the conditions are sorted and the EEG data is reshaped to:
61 | # Image conditions × EEG repetitions × EEG channels × EEG time points
62 | # This step is applied independently to the data of each partition and session.
63 | epoched_test, _, ch_names, times = epoching(args, 'test', seed)
64 | epoched_train, img_conditions_train, _, _ = epoching(args, 'training', seed)
65 |
66 |
67 | # =============================================================================
68 | # Multivariate Noise Normalization
69 | # =============================================================================
70 | # MVNN is applied independently to the data of each session.
71 | whitened_test, whitened_train = mvnn(args, epoched_test, epoched_train)
72 | del epoched_test, epoched_train
73 |
74 |
75 | # =============================================================================
76 | # Merge and save the preprocessed data
77 | # =============================================================================
78 | # In this step the data of all sessions is merged into the shape:
79 | # Image conditions × EEG repetitions × EEG channels × EEG time points
80 | # Then, the preprocessed data of the test and training data partitions is saved.
81 | save_prepr(args, whitened_test, whitened_train, img_conditions_train, ch_names,
82 | times, seed)
83 |
84 |
--------------------------------------------------------------------------------
/models/utils/tools.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import torch
5 | import matplotlib.pyplot as plt
6 | import pandas as pd
7 |
8 | plt.switch_backend('agg')
9 |
10 |
11 | def adjust_learning_rate(optimizer, epoch, args):
12 | # lr = args.learning_rate * (0.2 ** (epoch // 2))
13 | if args.lradj == 'type1':
14 | lr_adjust = {epoch: args.learning_rate * (0.5 ** ((epoch - 1) // 1))}
15 | elif args.lradj == 'type2':
16 | lr_adjust = {
17 | 2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6,
18 | 10: 5e-7, 15: 1e-7, 20: 5e-8
19 | }
20 | if epoch in lr_adjust.keys():
21 | lr = lr_adjust[epoch]
22 | for param_group in optimizer.param_groups:
23 | param_group['lr'] = lr
24 | print('Updating learning rate to {}'.format(lr))
25 |
26 |
27 | class EarlyStopping:
28 | def __init__(self, patience=7, verbose=False, delta=0):
29 | self.patience = patience
30 | self.verbose = verbose
31 | self.counter = 0
32 | self.best_score = None
33 | self.early_stop = False
34 |         self.val_loss_min = np.inf  # np.Inf was removed in NumPy 2.0; np.inf works everywhere
35 | self.delta = delta
36 |
37 | def __call__(self, val_loss, model, path):
38 | score = -val_loss
39 | if self.best_score is None:
40 | self.best_score = score
41 | self.save_checkpoint(val_loss, model, path)
42 | elif score < self.best_score + self.delta:
43 | self.counter += 1
44 | print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
45 | if self.counter >= self.patience:
46 | self.early_stop = True
47 | else:
48 | self.best_score = score
49 | self.save_checkpoint(val_loss, model, path)
50 | self.counter = 0
51 |
52 | def save_checkpoint(self, val_loss, model, path):
53 | if self.verbose:
54 | print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
55 | torch.save(model.state_dict(), path + '/' + 'checkpoint.pth')
56 | self.val_loss_min = val_loss
57 |
58 |
59 | class dotdict(dict):
60 | """dot.notation access to dictionary attributes"""
61 | __getattr__ = dict.get
62 | __setattr__ = dict.__setitem__
63 | __delattr__ = dict.__delitem__
64 |
65 |
66 | class StandardScaler():
67 | def __init__(self, mean, std):
68 | self.mean = mean
69 | self.std = std
70 |
71 | def transform(self, data):
72 | return (data - self.mean) / self.std
73 |
74 | def inverse_transform(self, data):
75 | return (data * self.std) + self.mean
76 |
77 |
78 | def visual(true, preds=None, name='./pic/test.pdf'):
79 | """
80 | Results visualization
81 | """
82 | plt.figure()
83 | plt.plot(true, label='GroundTruth', linewidth=2)
84 | if preds is not None:
85 | plt.plot(preds, label='Prediction', linewidth=2)
86 | plt.legend()
87 | plt.savefig(name, bbox_inches='tight')
88 |
89 |
90 | def adjustment(gt, pred):
91 | anomaly_state = False
92 | for i in range(len(gt)):
93 | if gt[i] == 1 and pred[i] == 1 and not anomaly_state:
94 | anomaly_state = True
95 | for j in range(i, 0, -1):
96 | if gt[j] == 0:
97 | break
98 | else:
99 | if pred[j] == 0:
100 | pred[j] = 1
101 | for j in range(i, len(gt)):
102 | if gt[j] == 0:
103 | break
104 | else:
105 | if pred[j] == 0:
106 | pred[j] = 1
107 | elif gt[i] == 0:
108 | anomaly_state = False
109 | if anomaly_state:
110 | pred[i] = 1
111 | return gt, pred
112 |
113 |
114 | def cal_accuracy(y_pred, y_true):
115 | return np.mean(y_pred == y_true)
116 |
--------------------------------------------------------------------------------
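
EarlyStopping tracks the best validation loss and writes checkpoint.pth into the given directory; a minimal training-loop sketch (the model and the per-epoch validation loss below are placeholders, not repo code):

import os

import torch.nn as nn

from models.utils.tools import EarlyStopping  # import path assumes the repo root is on PYTHONPATH

os.makedirs("./checkpoints", exist_ok=True)   # EarlyStopping expects the checkpoint directory to exist
model = nn.Linear(10, 1)                      # placeholder model
early_stopping = EarlyStopping(patience=5, verbose=True)

for epoch in range(100):
    # ... train for one epoch, then compute a validation loss (placeholder value here) ...
    val_loss = 1.0 / (epoch + 1)
    early_stopping(val_loss, model, path="./checkpoints")
    if early_stopping.early_stop:
        print("Early stopping triggered")
        break
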
/models/subject_layers/Crossformer_EncDec.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from einops import rearrange, repeat
4 | from .SelfAttention_Family import TwoStageAttentionLayer  # relative import: this repo keeps the layer modules in subject_layers/, not layers/
5 |
6 |
7 | class SegMerging(nn.Module):
8 | def __init__(self, d_model, win_size, norm_layer=nn.LayerNorm):
9 | super().__init__()
10 | self.d_model = d_model
11 | self.win_size = win_size
12 | self.linear_trans = nn.Linear(win_size * d_model, d_model)
13 | self.norm = norm_layer(win_size * d_model)
14 |
15 | def forward(self, x):
16 | batch_size, ts_d, seg_num, d_model = x.shape
17 | pad_num = seg_num % self.win_size
18 | if pad_num != 0:
19 | pad_num = self.win_size - pad_num
20 | x = torch.cat((x, x[:, :, -pad_num:, :]), dim=-2)
21 |
22 | seg_to_merge = []
23 | for i in range(self.win_size):
24 | seg_to_merge.append(x[:, :, i::self.win_size, :])
25 | x = torch.cat(seg_to_merge, -1)
26 |
27 | x = self.norm(x)
28 | x = self.linear_trans(x)
29 |
30 | return x
31 |
32 |
33 | class scale_block(nn.Module):
34 | def __init__(self, configs, win_size, d_model, n_heads, d_ff, depth, dropout, \
35 | seg_num=10, factor=10):
36 | super(scale_block, self).__init__()
37 |
38 | if win_size > 1:
39 | self.merge_layer = SegMerging(d_model, win_size, nn.LayerNorm)
40 | else:
41 | self.merge_layer = None
42 |
43 | self.encode_layers = nn.ModuleList()
44 |
45 | for i in range(depth):
46 | self.encode_layers.append(TwoStageAttentionLayer(configs, seg_num, factor, d_model, n_heads, \
47 | d_ff, dropout))
48 |
49 | def forward(self, x, attn_mask=None, tau=None, delta=None):
50 | _, ts_dim, _, _ = x.shape
51 |
52 | if self.merge_layer is not None:
53 | x = self.merge_layer(x)
54 |
55 | for layer in self.encode_layers:
56 | x = layer(x)
57 |
58 | return x, None
59 |
60 |
61 | class Encoder(nn.Module):
62 | def __init__(self, attn_layers):
63 | super(Encoder, self).__init__()
64 | self.encode_blocks = nn.ModuleList(attn_layers)
65 |
66 | def forward(self, x):
67 | encode_x = []
68 | encode_x.append(x)
69 |
70 | for block in self.encode_blocks:
71 | x, attns = block(x)
72 | encode_x.append(x)
73 |
74 | return encode_x, None
75 |
76 |
77 | class DecoderLayer(nn.Module):
78 | def __init__(self, self_attention, cross_attention, seg_len, d_model, d_ff=None, dropout=0.1):
79 | super(DecoderLayer, self).__init__()
80 | self.self_attention = self_attention
81 | self.cross_attention = cross_attention
82 | self.norm1 = nn.LayerNorm(d_model)
83 | self.norm2 = nn.LayerNorm(d_model)
84 | self.dropout = nn.Dropout(dropout)
85 | self.MLP1 = nn.Sequential(nn.Linear(d_model, d_model),
86 | nn.GELU(),
87 | nn.Linear(d_model, d_model))
88 | self.linear_pred = nn.Linear(d_model, seg_len)
89 |
90 | def forward(self, x, cross):
91 | batch = x.shape[0]
92 | x = self.self_attention(x)
93 | x = rearrange(x, 'b ts_d out_seg_num d_model -> (b ts_d) out_seg_num d_model')
94 |
95 | cross = rearrange(cross, 'b ts_d in_seg_num d_model -> (b ts_d) in_seg_num d_model')
96 | tmp, attn = self.cross_attention(x, cross, cross, None, None, None,)
97 | x = x + self.dropout(tmp)
98 | y = x = self.norm1(x)
99 | y = self.MLP1(y)
100 | dec_output = self.norm2(x + y)
101 |
102 | dec_output = rearrange(dec_output, '(b ts_d) seg_dec_num d_model -> b ts_d seg_dec_num d_model', b=batch)
103 | layer_predict = self.linear_pred(dec_output)
104 | layer_predict = rearrange(layer_predict, 'b out_d seg_num seg_len -> b (out_d seg_num) seg_len')
105 |
106 | return dec_output, layer_predict
107 |
108 |
109 | class Decoder(nn.Module):
110 | def __init__(self, layers):
111 | super(Decoder, self).__init__()
112 | self.decode_layers = nn.ModuleList(layers)
113 |
114 |
115 | def forward(self, x, cross):
116 | final_predict = None
117 | i = 0
118 |
119 | ts_d = x.shape[1]
120 | for layer in self.decode_layers:
121 | cross_enc = cross[i]
122 | x, layer_predict = layer(x, cross_enc)
123 | if final_predict is None:
124 | final_predict = layer_predict
125 | else:
126 | final_predict = final_predict + layer_predict
127 | i += 1
128 |
129 | final_predict = rearrange(final_predict, 'b (out_d seg_num) seg_len -> b (seg_num seg_len) out_d', out_d=ts_d)
130 |
131 | return final_predict
132 |
--------------------------------------------------------------------------------
/models/utils/timefeatures.py:
--------------------------------------------------------------------------------
1 | # From: gluonts/src/gluonts/time_feature/_base.py
2 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License").
5 | # You may not use this file except in compliance with the License.
6 | # A copy of the License is located at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # or in the "license" file accompanying this file. This file is distributed
11 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12 | # express or implied. See the License for the specific language governing
13 | # permissions and limitations under the License.
14 |
15 | from typing import List
16 |
17 | import numpy as np
18 | import pandas as pd
19 | from pandas.tseries import offsets
20 | from pandas.tseries.frequencies import to_offset
21 |
22 |
23 | class TimeFeature:
24 | def __init__(self):
25 | pass
26 |
27 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
28 | pass
29 |
30 | def __repr__(self):
31 | return self.__class__.__name__ + "()"
32 |
33 |
34 | class SecondOfMinute(TimeFeature):
35 |     """Second of minute encoded as value between [-0.5, 0.5]"""
36 |
37 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
38 | return index.second / 59.0 - 0.5
39 |
40 |
41 | class MinuteOfHour(TimeFeature):
42 | """Minute of hour encoded as value between [-0.5, 0.5]"""
43 |
44 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
45 | return index.minute / 59.0 - 0.5
46 |
47 |
48 | class HourOfDay(TimeFeature):
49 | """Hour of day encoded as value between [-0.5, 0.5]"""
50 |
51 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
52 | return index.hour / 23.0 - 0.5
53 |
54 |
55 | class DayOfWeek(TimeFeature):
56 |     """Day of week encoded as value between [-0.5, 0.5]"""
57 |
58 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
59 | return index.dayofweek / 6.0 - 0.5
60 |
61 |
62 | class DayOfMonth(TimeFeature):
63 | """Day of month encoded as value between [-0.5, 0.5]"""
64 |
65 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
66 | return (index.day - 1) / 30.0 - 0.5
67 |
68 |
69 | class DayOfYear(TimeFeature):
70 | """Day of year encoded as value between [-0.5, 0.5]"""
71 |
72 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
73 | return (index.dayofyear - 1) / 365.0 - 0.5
74 |
75 |
76 | class MonthOfYear(TimeFeature):
77 | """Month of year encoded as value between [-0.5, 0.5]"""
78 |
79 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
80 | return (index.month - 1) / 11.0 - 0.5
81 |
82 |
83 | class WeekOfYear(TimeFeature):
84 | """Week of year encoded as value between [-0.5, 0.5]"""
85 |
86 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
87 | return (index.isocalendar().week - 1) / 52.0 - 0.5
88 |
89 |
90 | def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
91 | """
92 | Returns a list of time features that will be appropriate for the given frequency string.
93 | Parameters
94 | ----------
95 | freq_str
96 | Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc.
97 | """
98 |
99 | features_by_offsets = {
100 | offsets.YearEnd: [],
101 | offsets.QuarterEnd: [MonthOfYear],
102 | offsets.MonthEnd: [MonthOfYear],
103 | offsets.Week: [DayOfMonth, WeekOfYear],
104 | offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear],
105 | offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear],
106 | offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear],
107 | offsets.Minute: [
108 | MinuteOfHour,
109 | HourOfDay,
110 | DayOfWeek,
111 | DayOfMonth,
112 | DayOfYear,
113 | ],
114 | offsets.Second: [
115 | SecondOfMinute,
116 | MinuteOfHour,
117 | HourOfDay,
118 | DayOfWeek,
119 | DayOfMonth,
120 | DayOfYear,
121 | ],
122 | }
123 |
124 | offset = to_offset(freq_str)
125 |
126 | for offset_type, feature_classes in features_by_offsets.items():
127 | if isinstance(offset, offset_type):
128 | return [cls() for cls in feature_classes]
129 |
130 | supported_freq_msg = f"""
131 | Unsupported frequency {freq_str}
132 | The following frequencies are supported:
133 | Y - yearly
134 | alias: A
135 | M - monthly
136 | W - weekly
137 | D - daily
138 | B - business days
139 | H - hourly
140 | T - minutely
141 | alias: min
142 | S - secondly
143 | """
144 | raise RuntimeError(supported_freq_msg)
145 |
146 |
147 | def time_features(dates, freq='h'):
148 | return np.vstack([feat(dates) for feat in time_features_from_frequency_str(freq)])
149 |
--------------------------------------------------------------------------------
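
time_features() turns a DatetimeIndex into a stack of normalized calendar features selected by the frequency string; a small sketch:

import pandas as pd

from models.utils.timefeatures import time_features  # import path assumes the repo root is on PYTHONPATH

dates = pd.date_range("2024-01-01", periods=96, freq="h")
feats = time_features(dates, freq="h")   # HourOfDay, DayOfWeek, DayOfMonth, DayOfYear
print(feats.shape)                       # (4, 96), each feature scaled to roughly [-0.5, 0.5]
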
/models/subject_layers/Transformer_EncDec.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class ConvLayer(nn.Module):
7 | def __init__(self, c_in):
8 | super(ConvLayer, self).__init__()
9 | self.downConv = nn.Conv1d(in_channels=c_in,
10 | out_channels=c_in,
11 | kernel_size=3,
12 | padding=2,
13 | padding_mode='circular')
14 | self.norm = nn.BatchNorm1d(c_in)
15 | self.activation = nn.ELU()
16 | self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
17 |
18 | def forward(self, x):
19 | x = self.downConv(x.permute(0, 2, 1))
20 | x = self.norm(x)
21 | x = self.activation(x)
22 | x = self.maxPool(x)
23 | x = x.transpose(1, 2)
24 | return x
25 |
26 |
27 | class EncoderLayer(nn.Module):
28 | def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
29 | super(EncoderLayer, self).__init__()
30 | d_ff = d_ff or 4 * d_model
31 | self.attention = attention
32 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
33 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
34 | self.norm1 = nn.LayerNorm(d_model)
35 | self.norm2 = nn.LayerNorm(d_model)
36 | self.dropout = nn.Dropout(dropout)
37 | self.activation = F.relu if activation == "relu" else F.gelu
38 |
39 | def forward(self, x, attn_mask=None, tau=None, delta=None):
40 | new_x, attn = self.attention(
41 | x, x, x,
42 | attn_mask=attn_mask,
43 | tau=tau, delta=delta
44 | )
45 | x = x + self.dropout(new_x)
46 |
47 | y = x = self.norm1(x)
48 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
49 | y = self.dropout(self.conv2(y).transpose(-1, 1))
50 |
51 | return self.norm2(x + y), attn
52 |
53 |
54 | class Encoder(nn.Module):
55 | def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
56 | super(Encoder, self).__init__()
57 | self.attn_layers = nn.ModuleList(attn_layers)
58 | self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
59 | self.norm = norm_layer
60 |
61 | def forward(self, x, attn_mask=None, tau=None, delta=None):
62 | # x [B, L, D]
63 | attns = []
64 | if self.conv_layers is not None:
65 | for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)):
66 | delta = delta if i == 0 else None
67 | x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
68 | x = conv_layer(x)
69 | attns.append(attn)
70 | x, attn = self.attn_layers[-1](x, tau=tau, delta=None)
71 | attns.append(attn)
72 | else:
73 | for attn_layer in self.attn_layers:
74 | x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
75 | attns.append(attn)
76 |
77 | if self.norm is not None:
78 | x = self.norm(x)
79 |
80 | return x, attns
81 |
82 |
83 | class DecoderLayer(nn.Module):
84 | def __init__(self, self_attention, cross_attention, d_model, d_ff=None,
85 | dropout=0.1, activation="relu"):
86 | super(DecoderLayer, self).__init__()
87 | d_ff = d_ff or 4 * d_model
88 | self.self_attention = self_attention
89 | self.cross_attention = cross_attention
90 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
91 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
92 | self.norm1 = nn.LayerNorm(d_model)
93 | self.norm2 = nn.LayerNorm(d_model)
94 | self.norm3 = nn.LayerNorm(d_model)
95 | self.dropout = nn.Dropout(dropout)
96 | self.activation = F.relu if activation == "relu" else F.gelu
97 |
98 | def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
99 | x = x + self.dropout(self.self_attention(
100 | x, x, x,
101 | attn_mask=x_mask,
102 | tau=tau, delta=None
103 | )[0])
104 | x = self.norm1(x)
105 |
106 | x = x + self.dropout(self.cross_attention(
107 | x, cross, cross,
108 | attn_mask=cross_mask,
109 | tau=tau, delta=delta
110 | )[0])
111 |
112 | y = x = self.norm2(x)
113 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
114 | y = self.dropout(self.conv2(y).transpose(-1, 1))
115 |
116 | return self.norm3(x + y)
117 |
118 |
119 | class Decoder(nn.Module):
120 | def __init__(self, layers, norm_layer=None, projection=None):
121 | super(Decoder, self).__init__()
122 | self.layers = nn.ModuleList(layers)
123 | self.norm = norm_layer
124 | self.projection = projection
125 |
126 | def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
127 | for layer in self.layers:
128 | x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta)
129 |
130 | if self.norm is not None:
131 | x = self.norm(x)
132 |
133 | if self.projection is not None:
134 | x = self.projection(x)
135 | return x
136 |
--------------------------------------------------------------------------------
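
Encoder/EncoderLayer expect an attention module to be injected (normally one from SelfAttention_Family.py) and add the conv feed-forward and layer norms around it. A wiring sketch with a stand-in attention module (the DummyAttention stub is an assumption, included only to show the expected call signature and shapes):

import torch
import torch.nn as nn

from models.subject_layers.Transformer_EncDec import Encoder, EncoderLayer  # path assumes the repo root is on PYTHONPATH


class DummyAttention(nn.Module):
    """Stand-in for an attention layer: returns its queries untouched plus a None attention map."""

    def forward(self, queries, keys, values, attn_mask=None, tau=None, delta=None):
        return queries, None


d_model = 64
encoder = Encoder(
    [EncoderLayer(DummyAttention(), d_model, d_ff=128, dropout=0.1, activation="gelu") for _ in range(2)],
    norm_layer=nn.LayerNorm(d_model),
)
x = torch.randn(8, 32, d_model)   # [batch, sequence length, d_model]
out, attns = encoder(x)
print(out.shape)                  # torch.Size([8, 32, 64])
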
/models/loss.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import logging
9 | import torch
10 | import torch.distributed.nn
11 | from torch import distributed as dist, nn as nn
12 | from torch.nn import functional as F
13 |
14 | try:
15 | import horovod.torch as hvd
16 | except ImportError:
17 | hvd = None
18 |
19 |
20 | def gather_features(
21 | image_features,
22 | text_features,
23 | local_loss=False,
24 | gather_with_grad=False,
25 | rank=0,
26 | world_size=1,
27 | use_horovod=False,
28 | ):
29 | if use_horovod:
30 | assert hvd is not None, "Please install horovod"
31 | if gather_with_grad:
32 | all_image_features = hvd.allgather(image_features)
33 | all_text_features = hvd.allgather(text_features)
34 | else:
35 | with torch.no_grad():
36 | all_image_features = hvd.allgather(image_features)
37 | all_text_features = hvd.allgather(text_features)
38 | if not local_loss:
39 | # ensure grads for local rank when all_* features don't have a gradient
40 | gathered_image_features = list(
41 | all_image_features.chunk(world_size, dim=0)
42 | )
43 | gathered_text_features = list(
44 | all_text_features.chunk(world_size, dim=0)
45 | )
46 | gathered_image_features[rank] = image_features
47 | gathered_text_features[rank] = text_features
48 | all_image_features = torch.cat(gathered_image_features, dim=0)
49 | all_text_features = torch.cat(gathered_text_features, dim=0)
50 | else:
51 | # We gather tensors from all gpus
52 | if gather_with_grad:
53 | all_image_features = torch.cat(
54 | torch.distributed.nn.all_gather(image_features), dim=0
55 | )
56 | all_text_features = torch.cat(
57 | torch.distributed.nn.all_gather(text_features), dim=0
58 | )
59 | else:
60 | gathered_image_features = [
61 | torch.zeros_like(image_features) for _ in range(world_size)
62 | ]
63 | gathered_text_features = [
64 | torch.zeros_like(text_features) for _ in range(world_size)
65 | ]
66 | dist.all_gather(gathered_image_features, image_features)
67 | dist.all_gather(gathered_text_features, text_features)
68 | if not local_loss:
69 | # ensure grads for local rank when all_* features don't have a gradient
70 | gathered_image_features[rank] = image_features
71 | gathered_text_features[rank] = text_features
72 | all_image_features = torch.cat(gathered_image_features, dim=0)
73 | all_text_features = torch.cat(gathered_text_features, dim=0)
74 |
75 | return all_image_features, all_text_features
76 |
77 |
78 | class ClipLoss(nn.Module):
79 | def __init__(
80 | self,
81 | local_loss=False,
82 | gather_with_grad=False,
83 | cache_labels=False,
84 | rank=0,
85 | world_size=1,
86 | use_horovod=False,
87 | ):
88 | super().__init__()
89 | self.local_loss = local_loss
90 | self.gather_with_grad = gather_with_grad
91 | self.cache_labels = cache_labels
92 | self.rank = rank
93 | self.world_size = world_size
94 | self.use_horovod = use_horovod
95 |
96 | # cache state
97 | self.prev_num_logits = 0
98 | self.labels = {}
99 |
100 | def forward(self, image_features, text_features, logit_scale):
101 | device = image_features.device
102 | if self.world_size > 1:
103 | all_image_features, all_text_features = gather_features(
104 | image_features,
105 | text_features,
106 | self.local_loss,
107 | self.gather_with_grad,
108 | self.rank,
109 | self.world_size,
110 | self.use_horovod,
111 | )
112 |
113 | if self.local_loss:
114 | logits_per_image = logit_scale * image_features @ all_text_features.T
115 | logits_per_text = logit_scale * text_features @ all_image_features.T
116 | else:
117 | logits_per_image = (
118 | logit_scale * all_image_features @ all_text_features.T
119 | )
120 | logits_per_text = logits_per_image.T
121 | else:
122 | logits_per_image = logit_scale * image_features @ text_features.T
123 | logits_per_text = logit_scale * text_features @ image_features.T
124 |
125 |         # calculate ground-truth labels and cache them if enabled
126 | num_logits = logits_per_image.shape[0]
127 | if self.prev_num_logits != num_logits or device not in self.labels:
128 | labels = torch.arange(num_logits, device=device, dtype=torch.long)
129 | if self.world_size > 1 and self.local_loss:
130 | labels = labels + num_logits * self.rank
131 | if self.cache_labels:
132 | self.labels[device] = labels
133 | self.prev_num_logits = num_logits
134 | else:
135 | labels = self.labels[device]
136 |
137 | total_loss = (
138 | F.cross_entropy(logits_per_image, labels)
139 | + F.cross_entropy(logits_per_text, labels)
140 | ) / 2
141 | return total_loss
142 |
--------------------------------------------------------------------------------
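
ClipLoss is the symmetric InfoNCE objective from open_clip; here the "image" and "text" slots are simply two embedding sets of the same batch size (e.g. EEG and image features). A single-GPU sketch (random embeddings, purely illustrative):

import torch
import torch.nn.functional as F

from models.loss import ClipLoss  # import path assumes the repo root is on PYTHONPATH

loss_fn = ClipLoss()                                   # defaults: single process, no feature gathering
emb_a = F.normalize(torch.randn(32, 1024), dim=-1)     # stand-in for EEG encoder outputs
emb_b = F.normalize(torch.randn(32, 1024), dim=-1)     # stand-in for CLIP image features
logit_scale = torch.tensor(1 / 0.07)                   # exponentiated temperature, as in CLIP's initialization
loss = loss_fn(emb_a, emb_b, logit_scale)
print(loss.item())                                     # cross-entropy over the 32x32 similarity matrix, both directions
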
/Generation/image_adapter.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from functools import partial\n",
10 | "\n",
11 | "from transformers import CLIPVisionModel \n",
12 | "import torch\n",
13 | "from torch import nn\n",
14 | "from torchvision import transforms\n",
15 | "from PIL import Image\n",
16 | "\n",
17 | "import torch\n",
18 | "import torch.nn as nn\n",
19 | "from transformers import CLIPVisionModel\n",
20 | "from torchvision import transforms\n",
21 | "\n",
22 | "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
23 | "\n"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": null,
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "train_pixel_img_feature = torch.load('/root/autodl-tmp/Workspace/EEG_caption/ViT-L-14_features_GIT_train.pt')['img_features']# \n",
33 | "test_pixel_img_feature = torch.load('/root/autodl-tmp/Workspace/EEG_caption/ViT-L-14_features_GIT_test.pt')['img_features']# \n",
34 | "train_img_feature = torch.load('/root/autodl-tmp/Workspace/EEG_caption/ViT-H-14_features_train.pt')['img_features'].unsqueeze(1)# \n",
35 | "test_img_feature = torch.load('/root/autodl-tmp/Workspace/EEG_caption/ViT-H-14_features_test.pt')['img_features'].unsqueeze(1)# \n"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": null,
41 | "metadata": {},
42 | "outputs": [],
43 | "source": [
44 | "test_img_feature.shape"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": null,
50 | "metadata": {},
51 | "outputs": [],
52 | "source": [
53 | "import torch\n",
54 | "import torch.nn as nn\n",
55 | "import torch.optim as optim\n",
56 | "from torch.utils.data import DataLoader, TensorDataset\n",
57 | "from einops.layers.torch import Rearrange, Reduce\n",
58 | "\n",
59 | "# Define the neural network\n",
60 | "class PixelProjector(nn.Sequential):\n",
61 | " def __init__(self, proj_dim=1024):\n",
62 | " super().__init__(\n",
63 | " Rearrange('B C L->B L C'), \n",
64 | " nn.Linear(1, 257),\n",
65 | " nn.LayerNorm(257),\n",
66 | " Rearrange('B L C->B C L'),\n",
67 | " nn.Linear(1024, 1024),\n",
68 | " nn.LayerNorm(proj_dim),\n",
69 | " )\n",
70 | " \n",
71 | " \n",
72 | "\n",
73 | "# Instantiate the model, loss function, and optimizer\n",
74 | "\n",
75 | "model = PixelProjector(proj_dim=1024).to(torch.bfloat16).to(device)\n",
76 | "criterion = nn.MSELoss()\n",
77 | "optimizer = optim.AdamW(model.parameters(), lr=0.001)\n",
78 | "\n",
79 | "# Prepare data loaders\n",
80 | "train_dataset = TensorDataset(train_img_feature, train_pixel_img_feature)\n",
81 | "test_dataset = TensorDataset(test_img_feature, test_pixel_img_feature)\n",
82 | "\n",
83 | "train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)\n",
84 | "test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)\n",
85 | "\n",
86 | "# Training loop\n",
87 | "num_epochs = 30\n",
88 | "for epoch in range(num_epochs):\n",
89 | " model.train()\n",
90 | " running_loss = 0.0\n",
91 | " for inputs, targets in train_loader:\n",
92 | " inputs, targets = inputs.to(torch.bfloat16).to(device), targets.to(torch.bfloat16).to(device)\n",
93 | " optimizer.zero_grad()\n",
94 | " outputs = model(inputs)\n",
95 | " loss = criterion(outputs, targets)\n",
96 | " loss.backward()\n",
97 | " optimizer.step()\n",
98 | " running_loss += loss.item()\n",
99 | " \n",
100 | " print(f\"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader)}\")\n",
101 | "\n",
102 | "# Testing loop\n",
103 | "model.eval()\n",
104 | "test_loss = 0.0\n",
105 | "with torch.no_grad():\n",
106 | " for inputs, targets in test_loader:\n",
107 | " inputs, targets = inputs.to(torch.bfloat16).to(device), targets.to(torch.bfloat16).to(device)\n",
108 | " outputs = model(inputs)\n",
109 | " loss = criterion(outputs, targets)\n",
110 | " test_loss += loss.item()\n",
111 | "\n",
112 | "print(f\"Test Loss: {test_loss/len(test_loader)}\")\n",
113 | "\n",
114 | "# Save the trained model\n",
115 | "torch.save(model.state_dict(), '/root/autodl-tmp/Workspace/EEG_caption/model_weights/PixelProjector_best.bin')\n",
116 | "print(\"Model saved as PixelProjector.bin\")\n"
117 | ]
118 | },
119 | {
120 | "cell_type": "code",
121 | "execution_count": 8,
122 | "metadata": {},
123 | "outputs": [
124 | {
125 | "name": "stdout",
126 | "output_type": "stream",
127 | "text": [
128 | "Model saved as PixelProjector.bin\n"
129 | ]
130 | }
131 | ],
132 | "source": [
133 | "# Save the trained model\n",
134 | "torch.save(model.state_dict(), '/root/autodl-tmp/Workspace/EEG_caption/model_weights/PixelProjector_best.bin')\n",
135 | "print(\"Model saved as PixelProjector.bin\")"
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": null,
141 | "metadata": {},
142 | "outputs": [],
143 | "source": []
144 | }
145 | ],
146 | "metadata": {
147 | "kernelspec": {
148 | "display_name": "BCI",
149 | "language": "python",
150 | "name": "python3"
151 | },
152 | "language_info": {
153 | "codemirror_mode": {
154 | "name": "ipython",
155 | "version": 3
156 | },
157 | "file_extension": ".py",
158 | "mimetype": "text/x-python",
159 | "name": "python",
160 | "nbconvert_exporter": "python",
161 | "pygments_lexer": "ipython3",
162 | "version": "3.10.0"
163 | }
164 | },
165 | "nbformat": 4,
166 | "nbformat_minor": 2
167 | }
168 |
--------------------------------------------------------------------------------
/models/subject_layers/AutoCorrelation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import math
7 | from math import sqrt
8 | import os
9 |
10 |
11 | class AutoCorrelation(nn.Module):
12 | """
13 | AutoCorrelation Mechanism with the following two phases:
14 | (1) period-based dependencies discovery
15 | (2) time delay aggregation
16 | This block can replace the self-attention family mechanism seamlessly.
17 | """
18 |
19 | def __init__(self, mask_flag=True, factor=1, scale=None, attention_dropout=0.1, output_attention=False):
20 | super(AutoCorrelation, self).__init__()
21 | self.factor = factor
22 | self.scale = scale
23 | self.mask_flag = mask_flag
24 | self.output_attention = output_attention
25 | self.dropout = nn.Dropout(attention_dropout)
26 |
27 | def time_delay_agg_training(self, values, corr):
28 | """
29 | SpeedUp version of Autocorrelation (a batch-normalization style design)
30 | This is for the training phase.
31 | """
32 | head = values.shape[1]
33 | channel = values.shape[2]
34 | length = values.shape[3]
35 | # find top k
36 | top_k = int(self.factor * math.log(length))
37 | mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
38 | index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1]
39 | weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1)
40 | # update corr
41 | tmp_corr = torch.softmax(weights, dim=-1)
42 | # aggregation
43 | tmp_values = values
44 | delays_agg = torch.zeros_like(values).float()
45 | for i in range(top_k):
46 | pattern = torch.roll(tmp_values, -int(index[i]), -1)
47 | delays_agg = delays_agg + pattern * \
48 | (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length))
49 | return delays_agg
50 |
51 | def time_delay_agg_inference(self, values, corr):
52 | """
53 | SpeedUp version of Autocorrelation (a batch-normalization style design)
54 | This is for the inference phase.
55 | """
56 | batch = values.shape[0]
57 | head = values.shape[1]
58 | channel = values.shape[2]
59 | length = values.shape[3]
60 | # index init
61 |         init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch, head, channel, 1).to(values.device)  # follow the input device instead of assuming CUDA
62 | # find top k
63 | top_k = int(self.factor * math.log(length))
64 | mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
65 | weights, delay = torch.topk(mean_value, top_k, dim=-1)
66 | # update corr
67 | tmp_corr = torch.softmax(weights, dim=-1)
68 | # aggregation
69 | tmp_values = values.repeat(1, 1, 1, 2)
70 | delays_agg = torch.zeros_like(values).float()
71 | for i in range(top_k):
72 | tmp_delay = init_index + delay[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)
73 | pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
74 | delays_agg = delays_agg + pattern * \
75 | (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length))
76 | return delays_agg
77 |
78 | def time_delay_agg_full(self, values, corr):
79 | """
80 | Standard version of Autocorrelation
81 | """
82 | batch = values.shape[0]
83 | head = values.shape[1]
84 | channel = values.shape[2]
85 | length = values.shape[3]
86 | # index init
87 |         init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch, head, channel, 1).to(values.device)  # follow the input device instead of assuming CUDA
88 | # find top k
89 | top_k = int(self.factor * math.log(length))
90 | weights, delay = torch.topk(corr, top_k, dim=-1)
91 | # update corr
92 | tmp_corr = torch.softmax(weights, dim=-1)
93 | # aggregation
94 | tmp_values = values.repeat(1, 1, 1, 2)
95 | delays_agg = torch.zeros_like(values).float()
96 | for i in range(top_k):
97 | tmp_delay = init_index + delay[..., i].unsqueeze(-1)
98 | pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
99 | delays_agg = delays_agg + pattern * (tmp_corr[..., i].unsqueeze(-1))
100 | return delays_agg
101 |
102 | def forward(self, queries, keys, values, attn_mask):
103 | B, L, H, E = queries.shape
104 | _, S, _, D = values.shape
105 | if L > S:
106 | zeros = torch.zeros_like(queries[:, :(L - S), :]).float()
107 | values = torch.cat([values, zeros], dim=1)
108 | keys = torch.cat([keys, zeros], dim=1)
109 | else:
110 | values = values[:, :L, :, :]
111 | keys = keys[:, :L, :, :]
112 |
113 | # period-based dependencies
114 | q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1)
115 | k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1)
116 | res = q_fft * torch.conj(k_fft)
117 | corr = torch.fft.irfft(res, dim=-1)
118 |
119 | # time delay agg
120 | if self.training:
121 | V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)
122 | else:
123 | V = self.time_delay_agg_inference(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)
124 |
125 | if self.output_attention:
126 | return (V.contiguous(), corr.permute(0, 3, 1, 2))
127 | else:
128 | return (V.contiguous(), None)
129 |
130 |
131 | class AutoCorrelationLayer(nn.Module):
132 | def __init__(self, correlation, d_model, n_heads, d_keys=None,
133 | d_values=None):
134 | super(AutoCorrelationLayer, self).__init__()
135 |
136 | d_keys = d_keys or (d_model // n_heads)
137 | d_values = d_values or (d_model // n_heads)
138 |
139 | self.inner_correlation = correlation
140 | self.query_projection = nn.Linear(d_model, d_keys * n_heads)
141 | self.key_projection = nn.Linear(d_model, d_keys * n_heads)
142 | self.value_projection = nn.Linear(d_model, d_values * n_heads)
143 | self.out_projection = nn.Linear(d_values * n_heads, d_model)
144 | self.n_heads = n_heads
145 |
146 | def forward(self, queries, keys, values, attn_mask):
147 | B, L, _ = queries.shape
148 | _, S, _ = keys.shape
149 | H = self.n_heads
150 |
151 | queries = self.query_projection(queries).view(B, L, H, -1)
152 | keys = self.key_projection(keys).view(B, S, H, -1)
153 | values = self.value_projection(values).view(B, S, H, -1)
154 |
155 | out, attn = self.inner_correlation(
156 | queries,
157 | keys,
158 | values,
159 | attn_mask
160 | )
161 | out = out.view(B, L, -1)
162 |
163 | return self.out_projection(out), attn
164 |
--------------------------------------------------------------------------------
/models/subject_layers/Autoformer_EncDec.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class my_Layernorm(nn.Module):
7 | """
8 | Special designed layernorm for the seasonal part
9 | """
10 |
11 | def __init__(self, channels):
12 | super(my_Layernorm, self).__init__()
13 | self.layernorm = nn.LayerNorm(channels)
14 |
15 | def forward(self, x):
16 | x_hat = self.layernorm(x)
17 | bias = torch.mean(x_hat, dim=1).unsqueeze(1).repeat(1, x.shape[1], 1)
18 | return x_hat - bias
19 |
20 |
21 | class moving_avg(nn.Module):
22 | """
23 | Moving average block to highlight the trend of time series
24 | """
25 |
26 | def __init__(self, kernel_size, stride):
27 | super(moving_avg, self).__init__()
28 | self.kernel_size = kernel_size
29 | self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)
30 |
31 | def forward(self, x):
32 | # padding on the both ends of time series
33 | front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
34 | end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
35 | x = torch.cat([front, x, end], dim=1)
36 | x = self.avg(x.permute(0, 2, 1))
37 | x = x.permute(0, 2, 1)
38 | return x
39 |
40 |
41 | class series_decomp(nn.Module):
42 | """
43 | Series decomposition block
44 | """
45 |
46 | def __init__(self, kernel_size):
47 | super(series_decomp, self).__init__()
48 | self.moving_avg = moving_avg(kernel_size, stride=1)
49 |
50 | def forward(self, x):
51 | moving_mean = self.moving_avg(x)
52 | res = x - moving_mean
53 | return res, moving_mean
54 |
55 |
56 | class series_decomp_multi(nn.Module):
57 | """
58 | Multiple Series decomposition block from FEDformer
59 | """
60 |
61 | def __init__(self, kernel_size):
62 | super(series_decomp_multi, self).__init__()
63 | self.kernel_size = kernel_size
64 | self.series_decomp = nn.ModuleList([series_decomp(kernel) for kernel in kernel_size])  # register decompositions as sub-modules
65 |
66 | def forward(self, x):
67 | moving_mean = []
68 | res = []
69 | for func in self.series_decomp:
70 | sea, moving_avg = func(x)
71 | moving_mean.append(moving_avg)
72 | res.append(sea)
73 |
74 | sea = sum(res) / len(res)
75 | moving_mean = sum(moving_mean) / len(moving_mean)
76 | return sea, moving_mean
77 |
78 |
79 | class EncoderLayer(nn.Module):
80 | """
81 | Autoformer encoder layer with the progressive decomposition architecture
82 | """
83 |
84 | def __init__(self, attention, d_model, d_ff=None, moving_avg=25, dropout=0.1, activation="relu"):
85 | super(EncoderLayer, self).__init__()
86 | d_ff = d_ff or 4 * d_model
87 | self.attention = attention
88 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False)
89 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False)
90 | self.decomp1 = series_decomp(moving_avg)
91 | self.decomp2 = series_decomp(moving_avg)
92 | self.dropout = nn.Dropout(dropout)
93 | self.activation = F.relu if activation == "relu" else F.gelu
94 |
95 | def forward(self, x, attn_mask=None):
96 | new_x, attn = self.attention(
97 | x, x, x,
98 | attn_mask=attn_mask
99 | )
100 | x = x + self.dropout(new_x)
101 | x, _ = self.decomp1(x)
102 | y = x
103 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
104 | y = self.dropout(self.conv2(y).transpose(-1, 1))
105 | res, _ = self.decomp2(x + y)
106 | return res, attn
107 |
108 |
109 | class Encoder(nn.Module):
110 | """
111 | Autoformer encoder
112 | """
113 |
114 | def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
115 | super(Encoder, self).__init__()
116 | self.attn_layers = nn.ModuleList(attn_layers)
117 | self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
118 | self.norm = norm_layer
119 |
120 | def forward(self, x, attn_mask=None):
121 | attns = []
122 | if self.conv_layers is not None:
123 | for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):
124 | x, attn = attn_layer(x, attn_mask=attn_mask)
125 | x = conv_layer(x)
126 | attns.append(attn)
127 | x, attn = self.attn_layers[-1](x)
128 | attns.append(attn)
129 | else:
130 | for attn_layer in self.attn_layers:
131 | x, attn = attn_layer(x, attn_mask=attn_mask)
132 | attns.append(attn)
133 |
134 | if self.norm is not None:
135 | x = self.norm(x)
136 |
137 | return x, attns
138 |
139 |
140 | class DecoderLayer(nn.Module):
141 | """
142 | Autoformer decoder layer with the progressive decomposition architecture
143 | """
144 |
145 | def __init__(self, self_attention, cross_attention, d_model, c_out, d_ff=None,
146 | moving_avg=25, dropout=0.1, activation="relu"):
147 | super(DecoderLayer, self).__init__()
148 | d_ff = d_ff or 4 * d_model
149 | self.self_attention = self_attention
150 | self.cross_attention = cross_attention
151 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False)
152 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False)
153 | self.decomp1 = series_decomp(moving_avg)
154 | self.decomp2 = series_decomp(moving_avg)
155 | self.decomp3 = series_decomp(moving_avg)
156 | self.dropout = nn.Dropout(dropout)
157 | self.projection = nn.Conv1d(in_channels=d_model, out_channels=c_out, kernel_size=3, stride=1, padding=1,
158 | padding_mode='circular', bias=False)
159 | self.activation = F.relu if activation == "relu" else F.gelu
160 |
161 | def forward(self, x, cross, x_mask=None, cross_mask=None):
162 | x = x + self.dropout(self.self_attention(
163 | x, x, x,
164 | attn_mask=x_mask
165 | )[0])
166 | x, trend1 = self.decomp1(x)
167 | x = x + self.dropout(self.cross_attention(
168 | x, cross, cross,
169 | attn_mask=cross_mask
170 | )[0])
171 | x, trend2 = self.decomp2(x)
172 | y = x
173 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
174 | y = self.dropout(self.conv2(y).transpose(-1, 1))
175 | x, trend3 = self.decomp3(x + y)
176 |
177 | residual_trend = trend1 + trend2 + trend3
178 | residual_trend = self.projection(residual_trend.permute(0, 2, 1)).transpose(1, 2)
179 | return x, residual_trend
180 |
181 |
182 | class Decoder(nn.Module):
183 | """
183 | Autoformer decoder
185 | """
186 |
187 | def __init__(self, layers, norm_layer=None, projection=None):
188 | super(Decoder, self).__init__()
189 | self.layers = nn.ModuleList(layers)
190 | self.norm = norm_layer
191 | self.projection = projection
192 |
193 | def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None):
194 | for layer in self.layers:
195 | x, residual_trend = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask)
196 | trend = trend + residual_trend
197 |
198 | if self.norm is not None:
199 | x = self.norm(x)
200 |
201 | if self.projection is not None:
202 | x = self.projection(x)
203 | return x, trend
204 |
--------------------------------------------------------------------------------
/models/subject_layers/FourierCorrelation.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # author=maziqing
3 | # email=maziqing.mzq@alibaba-inc.com
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 |
9 |
10 | def get_frequency_modes(seq_len, modes=64, mode_select_method='random'):
11 | """
12 | get modes on frequency domain:
13 | 'random' means sampling randomly;
14 | 'else' means sampling the lowest modes;
15 | """
16 | modes = min(modes, seq_len // 2)
17 | if mode_select_method == 'random':
18 | index = list(range(0, seq_len // 2))
19 | np.random.shuffle(index)
20 | index = index[:modes]
21 | else:
22 | index = list(range(0, modes))
23 | index.sort()
24 | return index
25 |
26 |
27 | # ########## fourier layer #############
28 | class FourierBlock(nn.Module):
29 | def __init__(self, in_channels, out_channels, seq_len, modes=0, mode_select_method='random'):
30 | super(FourierBlock, self).__init__()
31 | print('fourier enhanced block used!')
32 | """
33 | 1D Fourier block. It performs representation learning on frequency domain,
34 | it does FFT, linear transform, and Inverse FFT.
35 | """
36 | # get modes on frequency domain
37 | self.index = get_frequency_modes(seq_len, modes=modes, mode_select_method=mode_select_method)
38 | print('modes={}, index={}'.format(modes, self.index))
39 |
40 | self.scale = (1 / (in_channels * out_channels))
41 | self.weights1 = nn.Parameter(
42 | self.scale * torch.rand(8, in_channels // 8, out_channels // 8, len(self.index), dtype=torch.float))
43 | self.weights2 = nn.Parameter(
44 | self.scale * torch.rand(8, in_channels // 8, out_channels // 8, len(self.index), dtype=torch.float))
45 |
46 | # Complex multiplication
47 | def compl_mul1d(self, order, x, weights):
48 | x_flag = True
49 | w_flag = True
50 | if not torch.is_complex(x):
51 | x_flag = False
52 | x = torch.complex(x, torch.zeros_like(x).to(x.device))
53 | if not torch.is_complex(weights):
54 | w_flag = False
55 | weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device))
56 | if x_flag or w_flag:
57 | return torch.complex(torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag),
58 | torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real))
59 | else:
60 | return torch.einsum(order, x.real, weights.real)
61 |
62 | def forward(self, q, k, v, mask):
63 | # size = [B, L, H, E]
64 | B, L, H, E = q.shape
65 | x = q.permute(0, 2, 3, 1)
66 | # Compute Fourier coefficients
67 | x_ft = torch.fft.rfft(x, dim=-1)
68 | # Perform Fourier neural operations
69 | out_ft = torch.zeros(B, H, E, L // 2 + 1, device=x.device, dtype=torch.cfloat)
70 | for wi, i in enumerate(self.index):
71 | if i >= x_ft.shape[3] or wi >= out_ft.shape[3]:
72 | continue
73 | out_ft[:, :, :, wi] = self.compl_mul1d("bhi,hio->bho", x_ft[:, :, :, i],
74 | torch.complex(self.weights1, self.weights2)[:, :, :, wi])
75 | # Return to time domain
76 | x = torch.fft.irfft(out_ft, n=x.size(-1))
77 | return (x, None)
78 |
79 |
80 | # ########## Fourier Cross Former ####################
81 | class FourierCrossAttention(nn.Module):
82 | def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes=64, mode_select_method='random',
83 | activation='tanh', policy=0, num_heads=8):
84 | super(FourierCrossAttention, self).__init__()
85 | print(' fourier enhanced cross attention used!')
86 | """
87 | 1D Fourier Cross Attention layer. It does FFT, linear transform, attention mechanism and Inverse FFT.
88 | """
89 | self.activation = activation
90 | self.in_channels = in_channels
91 | self.out_channels = out_channels
92 | # get modes for queries and keys (& values) on frequency domain
93 | self.index_q = get_frequency_modes(seq_len_q, modes=modes, mode_select_method=mode_select_method)
94 | self.index_kv = get_frequency_modes(seq_len_kv, modes=modes, mode_select_method=mode_select_method)
95 |
96 | print('modes_q={}, index_q={}'.format(len(self.index_q), self.index_q))
97 | print('modes_kv={}, index_kv={}'.format(len(self.index_kv), self.index_kv))
98 |
99 | self.scale = (1 / (in_channels * out_channels))
100 | self.weights1 = nn.Parameter(
101 | self.scale * torch.rand(num_heads, in_channels // num_heads, out_channels // num_heads, len(self.index_q), dtype=torch.float))
102 | self.weights2 = nn.Parameter(
103 | self.scale * torch.rand(num_heads, in_channels // num_heads, out_channels // num_heads, len(self.index_q), dtype=torch.float))
104 |
105 | # Complex multiplication
106 | def compl_mul1d(self, order, x, weights):
107 | x_flag = True
108 | w_flag = True
109 | if not torch.is_complex(x):
110 | x_flag = False
111 | x = torch.complex(x, torch.zeros_like(x).to(x.device))
112 | if not torch.is_complex(weights):
113 | w_flag = False
114 | weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device))
115 | if x_flag or w_flag:
116 | return torch.complex(torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag),
117 | torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real))
118 | else:
119 | return torch.einsum(order, x.real, weights.real)
120 |
121 | def forward(self, q, k, v, mask):
122 | # size = [B, L, H, E]
123 | B, L, H, E = q.shape
124 | xq = q.permute(0, 2, 3, 1) # size = [B, H, E, L]
125 | xk = k.permute(0, 2, 3, 1)
126 | xv = v.permute(0, 2, 3, 1)
127 |
128 | # Compute Fourier coefficients
129 | xq_ft_ = torch.zeros(B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat)
130 | xq_ft = torch.fft.rfft(xq, dim=-1)
131 | for i, j in enumerate(self.index_q):
132 | if j >= xq_ft.shape[3]:
133 | continue
134 | xq_ft_[:, :, :, i] = xq_ft[:, :, :, j]
135 | xk_ft_ = torch.zeros(B, H, E, len(self.index_kv), device=xq.device, dtype=torch.cfloat)
136 | xk_ft = torch.fft.rfft(xk, dim=-1)
137 | for i, j in enumerate(self.index_kv):
138 | if j >= xk_ft.shape[3]:
139 | continue
140 | xk_ft_[:, :, :, i] = xk_ft[:, :, :, j]
141 |
142 | # perform attention mechanism on frequency domain
143 | xqk_ft = (self.compl_mul1d("bhex,bhey->bhxy", xq_ft_, xk_ft_))
144 | if self.activation == 'tanh':
145 | xqk_ft = torch.complex(xqk_ft.real.tanh(), xqk_ft.imag.tanh())
146 | elif self.activation == 'softmax':
147 | xqk_ft = torch.softmax(abs(xqk_ft), dim=-1)
148 | xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft))
149 | else:
150 | raise Exception('{} activation function is not implemented'.format(self.activation))
151 | xqkv_ft = self.compl_mul1d("bhxy,bhey->bhex", xqk_ft, xk_ft_)
152 | xqkvw = self.compl_mul1d("bhex,heox->bhox", xqkv_ft, torch.complex(self.weights1, self.weights2))
153 | out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat)
154 | for i, j in enumerate(self.index_q):
155 | if i >= xqkvw.shape[3] or j >= out_ft.shape[3]:
156 | continue
157 | out_ft[:, :, :, j] = xqkvw[:, :, :, i]
158 | # Return to time domain
159 | out = torch.fft.irfft(out_ft / self.in_channels / self.out_channels, n=xq.size(-1))
160 | return (out, None)
161 |
--------------------------------------------------------------------------------
/models/subject_layers/Pyraformer_EncDec.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn.modules.linear import Linear
5 | from .SelfAttention_Family import AttentionLayer, FullAttention
6 | from .Embed import DataEmbedding
7 | import math
8 |
9 |
10 | def get_mask(input_size, window_size, inner_size):
11 | """Get the attention mask of PAM-Naive"""
12 | # Get the size of all layers
13 | all_size = []
14 | all_size.append(input_size)
15 | for i in range(len(window_size)):
16 | layer_size = math.floor(all_size[i] / window_size[i])
17 | all_size.append(layer_size)
18 |
19 | seq_length = sum(all_size)
20 | mask = torch.zeros(seq_length, seq_length)
21 |
22 | # get intra-scale mask
23 | inner_window = inner_size // 2
24 | for layer_idx in range(len(all_size)):
25 | start = sum(all_size[:layer_idx])
26 | for i in range(start, start + all_size[layer_idx]):
27 | left_side = max(i - inner_window, start)
28 | right_side = min(i + inner_window + 1, start + all_size[layer_idx])
29 | mask[i, left_side:right_side] = 1
30 |
31 | # get inter-scale mask
32 | for layer_idx in range(1, len(all_size)):
33 | start = sum(all_size[:layer_idx])
34 | for i in range(start, start + all_size[layer_idx]):
35 | left_side = (start - all_size[layer_idx - 1]) + \
36 | (i - start) * window_size[layer_idx - 1]
37 | if i == (start + all_size[layer_idx] - 1):
38 | right_side = start
39 | else:
40 | right_side = (
41 | start - all_size[layer_idx - 1]) + (i - start + 1) * window_size[layer_idx - 1]
42 | mask[i, left_side:right_side] = 1
43 | mask[left_side:right_side, i] = 1
44 |
45 | mask = (1 - mask).bool()
46 |
47 | return mask, all_size
48 |
49 |
50 | def refer_points(all_sizes, window_size):
51 | """Gather features from PAM's pyramid sequences"""
52 | input_size = all_sizes[0]
53 | indexes = torch.zeros(input_size, len(all_sizes))
54 |
55 | for i in range(input_size):
56 | indexes[i][0] = i
57 | former_index = i
58 | for j in range(1, len(all_sizes)):
59 | start = sum(all_sizes[:j])
60 | inner_layer_idx = former_index - (start - all_sizes[j - 1])
61 | former_index = start + \
62 | min(inner_layer_idx // window_size[j - 1], all_sizes[j] - 1)
63 | indexes[i][j] = former_index
64 |
65 | indexes = indexes.unsqueeze(0).unsqueeze(3)
66 |
67 | return indexes.long()
68 |
69 |
70 | class RegularMask():
71 | def __init__(self, mask):
72 | self._mask = mask.unsqueeze(1)
73 |
74 | @property
75 | def mask(self):
76 | return self._mask
77 |
78 |
79 | class EncoderLayer(nn.Module):
80 | """ Compose with two layers """
81 |
82 | def __init__(self, d_model, d_inner, n_head, dropout=0.1, normalize_before=True):
83 | super(EncoderLayer, self).__init__()
84 |
85 | self.slf_attn = AttentionLayer(
86 | FullAttention(mask_flag=True, factor=0,
87 | attention_dropout=dropout, output_attention=False),
88 | d_model, n_head)
89 | self.pos_ffn = PositionwiseFeedForward(
90 | d_model, d_inner, dropout=dropout, normalize_before=normalize_before)
91 |
92 | def forward(self, enc_input, slf_attn_mask=None):
93 | attn_mask = RegularMask(slf_attn_mask)
94 | enc_output, _ = self.slf_attn(
95 | enc_input, enc_input, enc_input, attn_mask=attn_mask)
96 | enc_output = self.pos_ffn(enc_output)
97 | return enc_output
98 |
99 |
100 | class Encoder(nn.Module):
101 | """ A encoder model with self attention mechanism. """
102 |
103 | def __init__(self, configs, window_size, inner_size):
104 | super().__init__()
105 |
106 | d_bottleneck = configs.d_model//4
107 |
108 | self.mask, self.all_size = get_mask(
109 | configs.seq_len, window_size, inner_size)
110 | self.indexes = refer_points(self.all_size, window_size)
111 | self.layers = nn.ModuleList([
112 | EncoderLayer(configs.d_model, configs.d_ff, configs.n_heads, dropout=configs.dropout,
113 | normalize_before=False) for _ in range(configs.e_layers)
114 | ]) # naive pyramid attention
115 |
116 | self.enc_embedding = DataEmbedding(
117 | configs.enc_in, configs.d_model, configs.dropout)
118 | self.conv_layers = Bottleneck_Construct(
119 | configs.d_model, window_size, d_bottleneck)
120 |
121 | def forward(self, x_enc, x_mark_enc):
122 | seq_enc = self.enc_embedding(x_enc, x_mark_enc)
123 |
124 | mask = self.mask.repeat(len(seq_enc), 1, 1).to(x_enc.device)
125 | seq_enc = self.conv_layers(seq_enc)
126 |
127 | for i in range(len(self.layers)):
128 | seq_enc = self.layers[i](seq_enc, mask)
129 |
130 | indexes = self.indexes.repeat(seq_enc.size(
131 | 0), 1, 1, seq_enc.size(2)).to(seq_enc.device)
132 | indexes = indexes.view(seq_enc.size(0), -1, seq_enc.size(2))
133 | all_enc = torch.gather(seq_enc, 1, indexes)
134 | seq_enc = all_enc.view(seq_enc.size(0), self.all_size[0], -1)
135 |
136 | return seq_enc
137 |
138 |
139 | class ConvLayer(nn.Module):
140 | def __init__(self, c_in, window_size):
141 | super(ConvLayer, self).__init__()
142 | self.downConv = nn.Conv1d(in_channels=c_in,
143 | out_channels=c_in,
144 | kernel_size=window_size,
145 | stride=window_size)
146 | self.norm = nn.BatchNorm1d(c_in)
147 | self.activation = nn.ELU()
148 |
149 | def forward(self, x):
150 | x = self.downConv(x)
151 | x = self.norm(x)
152 | x = self.activation(x)
153 | return x
154 |
155 |
156 | class Bottleneck_Construct(nn.Module):
157 | """Bottleneck convolution CSCM"""
158 |
159 | def __init__(self, d_model, window_size, d_inner):
160 | super(Bottleneck_Construct, self).__init__()
161 | if not isinstance(window_size, list):
162 | self.conv_layers = nn.ModuleList([
163 | ConvLayer(d_inner, window_size),
164 | ConvLayer(d_inner, window_size),
165 | ConvLayer(d_inner, window_size)
166 | ])
167 | else:
168 | self.conv_layers = []
169 | for i in range(len(window_size)):
170 | self.conv_layers.append(ConvLayer(d_inner, window_size[i]))
171 | self.conv_layers = nn.ModuleList(self.conv_layers)
172 | self.up = Linear(d_inner, d_model)
173 | self.down = Linear(d_model, d_inner)
174 | self.norm = nn.LayerNorm(d_model)
175 |
176 | def forward(self, enc_input):
177 | temp_input = self.down(enc_input).permute(0, 2, 1)
178 | all_inputs = []
179 | for i in range(len(self.conv_layers)):
180 | temp_input = self.conv_layers[i](temp_input)
181 | all_inputs.append(temp_input)
182 |
183 | all_inputs = torch.cat(all_inputs, dim=2).transpose(1, 2)
184 | all_inputs = self.up(all_inputs)
185 | all_inputs = torch.cat([enc_input, all_inputs], dim=1)
186 |
187 | all_inputs = self.norm(all_inputs)
188 | return all_inputs
189 |
190 |
191 | class PositionwiseFeedForward(nn.Module):
192 | """ Two-layer position-wise feed-forward neural network. """
193 |
194 | def __init__(self, d_in, d_hid, dropout=0.1, normalize_before=True):
195 | super().__init__()
196 |
197 | self.normalize_before = normalize_before
198 |
199 | self.w_1 = nn.Linear(d_in, d_hid)
200 | self.w_2 = nn.Linear(d_hid, d_in)
201 |
202 | self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
203 | self.dropout = nn.Dropout(dropout)
204 |
205 | def forward(self, x):
206 | residual = x
207 | if self.normalize_before:
208 | x = self.layer_norm(x)
209 |
210 | x = F.gelu(self.w_1(x))
211 | x = self.dropout(x)
212 | x = self.w_2(x)
213 | x = self.dropout(x)
214 | x = x + residual
215 |
216 | if not self.normalize_before:
217 | x = self.layer_norm(x)
218 | return x
219 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Visual Decoding and Reconstruction via EEG Embeddings with Guided Diffusion
4 |
36 | Framework of our proposed method.
37 |
44 | Some examples of using EEG to reconstruct stimulus images.
45 |
46 |
47 | ## News:
48 | - [2024/09/26] Our paper is accepted to **NeurIPS 2024**.
49 | - [2024/09/25] We have updated the [arxiv](https://arxiv.org/abs/2403.07721) paper.
50 | - [2024/08/01] Update scripts for training and inference in different tasks.
51 | - [2024/05/19] Update the dataset loading scripts.
52 | - [2024/03/12] The [arxiv](https://arxiv.org/abs/2403.07721) paper is available.
53 |
54 |
55 |
56 | ## Environment setup
57 |
58 | ### Option 1: Using setup.sh (Recommended)
59 |
60 | Run the setup script to create a conda environment with all dependencies:
61 |
62 | ```bash
63 | . setup.sh
64 | conda activate BCI
65 | ```
66 |
67 | ### Option 2: Using environment.yml
68 |
69 | ```bash
70 | conda env create -f environment.yml
71 | conda activate BCI
72 | ```
73 |
74 | ### Option 3: Using requirements.txt
75 |
76 | ```bash
77 | conda create -n BCI python=3.12 -y
78 | conda activate BCI
79 | pip install torch==2.5.0 torchvision==0.20.0 torchaudio==2.5.0 --index-url https://download.pytorch.org/whl/cu124
80 | pip install -r requirements.txt
81 | ```
82 |
83 |
84 |
85 | ## Quick training and test
86 |
87 | If you want to quickly reproduce the results in the paper, please download the relevant ``preprocessed data`` and ``model weights`` from [Hugging Face](https://huggingface.co/datasets/LidongYang/EEG_Image_decode) first.
88 | #### 1. Image Retrieval
89 | We provide the script to train the EEG encoder and evaluate it during training. In this task, we use the **normalized CLIP embedding** to train the EEG encoder. Please modify your dataset path and run:
90 | ```
91 | cd Retrieval/
92 | python ATMS_retrieval.py --logger True --gpu cuda:0 --output_dir ./outputs/contrast
93 | ```
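For orientation, the sketch below illustrates the kind of symmetric contrastive (InfoNCE) objective used to align EEG embeddings with normalized CLIP image embeddings. It is only a minimal illustration, not the exact loss in `ATMS_retrieval.py`; the function and tensor names are placeholders.

```python
# Minimal sketch of CLIP-style contrastive alignment between EEG and image embeddings.
# Tensor names are placeholders; the actual training code may differ in detail.
import torch
import torch.nn.functional as F

def contrastive_loss(eeg_emb, img_emb, temperature=0.07):
    """eeg_emb, img_emb: (B, D) batches of paired EEG and CLIP image embeddings."""
    eeg_emb = F.normalize(eeg_emb, dim=-1)           # normalize so the dot product equals cosine similarity
    img_emb = F.normalize(img_emb, dim=-1)
    logits = eeg_emb @ img_emb.t() / temperature     # (B, B) similarity matrix
    targets = torch.arange(eeg_emb.size(0), device=eeg_emb.device)
    # Symmetric InfoNCE: EEG->image and image->EEG directions
    return (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets)) / 2
```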
94 | We also provide the script for ``joint subject training``, which trains on all subjects jointly and tests on a specific subject:
95 | ```
96 | cd Retrieval/
97 | python ATMS_retrieval_joint_train.py --joint_train --sub sub-01 True --logger True --gpu cuda:0 --output_dir ./outputs/contrast
98 | ```
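Conceptually, joint subject training pools the training trials of all subjects into one loader while evaluation stays on the chosen subject. The sketch below only illustrates that data split; `load_subject_split` and the tensor shapes are hypothetical stand-ins, not the repo's dataset API.

```python
# Illustrative sketch of the joint-subject data split (hypothetical loader, random data).
import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

def load_subject_split(subject, split):
    # Hypothetical stand-in returning (EEG trial, CLIP embedding) pairs for one subject.
    n = 100 if split == "train" else 20
    return TensorDataset(torch.randn(n, 63, 250), torch.randn(n, 1024))

subjects = [f"sub-{i:02d}" for i in range(1, 11)]
train_set = ConcatDataset([load_subject_split(s, "train") for s in subjects])  # all subjects jointly
test_set = load_subject_split("sub-01", "test")                                # evaluate on one subject

train_loader = DataLoader(train_set, batch_size=256, shuffle=True)
test_loader = DataLoader(test_set, batch_size=256, shuffle=False)
```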
99 |
100 | Additionally, you can replicate the results of other methods (e.g. EEGNetV4) by running:
101 | ```
102 | cd Retrieval/
103 | python contrast_retrieval.py --encoder_type EEGNetv4_Encoder --epochs 30 --batch_size 1024
104 | ```
105 |
106 | #### 2. Image Reconstruction
107 | We provide quick training and inference scripts for the ``clip pipeline`` of visual reconstruction. In this task, we use **the original CLIP embedding** to train the EEG encoder. Please modify your dataset path and run the zero-shot evaluation on the 200-class test set:
108 | ```
109 | # Train and generate eeg features in Subject 8
110 | cd Generation/
111 | python ATMS_reconstruction.py --insubject True --subjects sub-08 --logger True \
112 | --gpu cuda:0 --output_dir ./outputs/contrast
113 | ```
114 |
115 | ```
116 | # Reconstruct images in Subject 8
117 | Generation_metrics_sub8.ipynb
118 | ```
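The zero-shot evaluation on the 200-class test set amounts to nearest-neighbour matching in CLIP space. Below is a minimal sketch assuming precomputed embeddings; it is not the notebook's exact evaluation code.

```python
# Minimal sketch of zero-shot classification over the 200 test classes:
# choose the test image whose CLIP embedding is most similar to the EEG-predicted embedding.
import torch
import torch.nn.functional as F

def zero_shot_accuracy(eeg_emb, class_img_emb, labels):
    """eeg_emb: (N, D) EEG-predicted embeddings; class_img_emb: (200, D); labels: (N,) class indices."""
    sims = F.normalize(eeg_emb, dim=-1) @ F.normalize(class_img_emb, dim=-1).t()  # (N, 200)
    pred = sims.argmax(dim=-1)
    return (pred == labels).float().mean().item()
```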
119 |
120 | We also provide scripts for image reconstruction combined with the ``low-level pipeline``.
121 | ```
122 | cd Generation/
123 |
124 | # step 1: train vae encoder and then generate low level images
125 | train_vae_latent_512_low_level_no_average.py
126 |
127 | # step 2: load low level images and then reconstruct them
128 | 1x1024_reconstruct_sdxl.ipynb
129 | ```
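As a rough illustration of what the low-level pipeline produces, a predicted SDXL-VAE latent can be decoded back into a blurry "low-level" image with the diffusers VAE. This is only a hedged sketch: the random latent stands in for an EEG-predicted one, and whether the scaling factor should be applied depends on how the latents were trained.

```python
# Hedged sketch: decode a 4x64x64 SDXL-VAE latent into a 512x512 low-level image.
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").eval()
latent = torch.randn(1, 4, 64, 64)  # stand-in for an EEG-predicted latent
with torch.no_grad():
    # Undo the scaling factor if the latent lives in the scaled diffusion space (assumption).
    image = vae.decode(latent / vae.config.scaling_factor).sample  # (1, 3, 512, 512), roughly in [-1, 1]
```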
130 |
131 |
132 | We provide scripts for caption generation combined with the ``semantic-level pipeline``.
133 | ```
134 | cd Generation/
135 |
136 | # step 1: train feature adapter
137 | image_adapter.ipynb
138 |
139 | # step 2: get caption from eeg latent
140 | GIT_caption_batch.ipynb
141 |
142 | # step 3: load text prompt and then reconstruct images
143 | 1x1024_reconstruct_sdxl.ipynb
144 | ```
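Step 3 is essentially text-to-image generation with SDXL from the EEG-derived caption. The following is a hedged sketch using the diffusers API; the actual notebook may also condition on the low-level image, and the caption string here is just a placeholder.

```python
# Hedged sketch of step 3: turn an EEG-derived caption into an image with SDXL (GPU assumed).
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

caption = "a brown dog sitting on green grass"   # placeholder for the GIT-generated caption
image = pipe(prompt=caption, num_inference_steps=30).images[0]
image.save("reconstruction.png")
```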
145 |
146 | To evaluate the quality of the reconstructed images, modify the paths of the reconstructed images and the original stimulus images in the notebook and run:
147 | ```
148 | #compute metrics, cited from MindEye
149 | Reconstruction_Metrics_ATM.ipynb
150 | ```
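Among the MindEye-style metrics, the simplest low-level one is pixel-wise correlation (PixCorr). The sketch below is our illustration of that single metric, not the notebook's full metric suite.

```python
# Minimal sketch of PixCorr: Pearson correlation between the flattened reconstruction and stimulus.
import numpy as np

def pixcorr(recon, target):
    """recon, target: arrays of identical shape, e.g. (H, W, 3) scaled to [0, 1]."""
    return float(np.corrcoef(recon.reshape(-1), target.reshape(-1))[0, 1])
```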
151 |
152 |
153 | ## Data availability
154 |
155 | We provide you with the ``preprocessed EEG`` and ``preprocessed MEG`` data used in our paper at [Hugging Face](https://huggingface.co/datasets/LidongYang/EEG_Image_decode), as well as the raw image data.
156 |
157 |
158 | Note that the experimental paradigms of the THINGS-EEG and THINGS-MEG datasets themselves are different, so we will provide images and data for the two datasets separately.
159 |
160 | You can also download the relevant THINGS-EEG and THINGS-MEG datasets from osf.io.
161 |
162 | The raw and preprocessed EEG dataset, the training and test images are available on [osf](https://osf.io/3jk45/).
163 | - ``Raw EEG data:`` `../project_directory/eeg_dataset/raw_data/`.
164 | - ``Preprocessed EEG data:`` `../project_directory/eeg_dataset/preprocessed_data/`.
165 | - ``Training and test images:`` `../project_directory/image_set/`.
166 |
167 |
168 | The raw and preprocessed MEG dataset, the training and test images are available on [OpenNEURO](https://openneuro.org/datasets/ds004212/versions/2.0.0).
169 |
170 |
171 |
172 |
173 |
174 |
175 | ## EEG/MEG preprocessing
176 |
177 |
178 | Modify your path and execute the following code to perform the same preprocessing on the raw data as in our experiment:
179 | ```
180 | cd EEG-preprocessing/
181 | python preprocessing.py
182 | ```
183 |
184 | ```
185 | cd MEG-preprocessing/
186 | pre_possess.ipynb
187 | ```
188 | You can also get the dataset used in this project through this BaiduNetDisk [link](https://pan.baidu.com/s/1-1hgpoi4nereLVqE4ylE_g?pwd=nid5) to run the code.
189 |
190 | ## TODO
191 | - [√] Release retrieval and reconstruction scripts.
192 | - [√] Update training scripts of reconstruction pipeline.
193 | - [ ] Add a validation set to improve the accuracy of performance evaluation.
194 |
195 |
196 |
197 |
198 | ## Acknowledgements
199 |
200 | 1. Thanks to Y. Song et al. for their contributions to dataset preprocessing and the neural network architecture; we refer to their work: "[Decoding Natural Images from EEG for Object Recognition](https://arxiv.org/pdf/2308.13234.pdf)". Yonghao Song, Bingchuan Liu, Xiang Li, Nanlin Shi, Yijun Wang, and Xiaorong Gao.
201 |
202 | 2. We also thank the authors of [SDRecon](https://github.com/yu-takagi/StableDiffusionReconstruction) for providing their code and results. Some parts of the training script are based on [MindEye](https://medarc-ai.github.io/mindeye/) and [MindEye2](https://github.com/MedARC-AI/MindEyeV2). Thanks for these awesome research works.
203 |
204 | 3. Here we cite the THINGS-EEG dataset used in the paper: "[A large and rich EEG dataset for modeling human visual object recognition](https://www.sciencedirect.com/science/article/pii/S1053811922008758?via%3Dihub)".
205 | Alessandro T. Gifford, Kshitij Dwivedi, Gemma Roig, Radoslaw M. Cichy.
206 |
207 |
208 | 4. The THINGS-MEG dataset we used is described in: "[THINGS-data, a multimodal collection of large-scale datasets for investigating object representations in human brain and behavior.](https://elifesciences.org/articles/82580.pdf)". Hebart, Martin N., Oliver Contier, Lina Teichmann, Adam H. Rockter, Charles Y. Zheng, Alexis Kidder, Anna Corriveau, Maryam Vaziri-Pashkam, and Chris I. Baker.
209 |
210 |
211 |
212 |
213 | ## Citation
214 |
215 | ```bibtex
216 | @inproceedings{li2024visual,
217 | author = {Li, Dongyang and Wei, Chen and Li, Shiying and Zou, Jiachen and Liu, Quanying},
218 | booktitle = {Advances in Neural Information Processing Systems},
219 | editor = {A. Globerson and L. Mackey and D. Belgrave and A. Fan and U. Paquet and J. Tomczak and C. Zhang},
220 | pages = {102822--102864},
221 | publisher = {Curran Associates, Inc.},
222 | title = {Visual Decoding and Reconstruction via EEG Embeddings with Guided Diffusion},
223 | url = {https://proceedings.neurips.cc/paper_files/paper/2024/file/ba5f1233efa77787ff9ec015877dbd1f-Paper-Conference.pdf},
224 | volume = {37},
225 | year = {2024}
226 | }
227 |
228 |
229 | @article{li2024visual,
230 | title={Visual Decoding and Reconstruction via EEG Embeddings with Guided Diffusion},
231 | author={Li, Dongyang and Wei, Chen and Li, Shiying and Zou, Jiachen and Liu, Quanying},
232 | journal={arXiv preprint arXiv:2403.07721},
233 | year={2024}
234 | }
235 | ```
236 |
--------------------------------------------------------------------------------
/models/subject_layers/Embed.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn.utils import weight_norm
5 | import math
6 |
7 |
8 | class PositionalEmbedding(nn.Module):
9 | def __init__(self, d_model, max_len=5000):
10 | super(PositionalEmbedding, self).__init__()
11 | # Compute the positional encodings once in log space.
12 | pe = torch.zeros(max_len, d_model).float()
13 | pe.requires_grad = False
14 |
15 | position = torch.arange(0, max_len).float().unsqueeze(1)
16 | div_term = (torch.arange(0, d_model, 2).float()
17 | * -(math.log(10000.0) / d_model)).exp()
18 |
19 | pe[:, 0::2] = torch.sin(position * div_term)
20 | pe[:, 1::2] = torch.cos(position * div_term)
21 |
22 | pe = pe.unsqueeze(0)
23 | self.register_buffer('pe', pe)
24 |
25 | def forward(self, x):
26 | return self.pe[:, :x.size(1)]
27 |
28 |
29 | class TokenEmbedding(nn.Module):
30 | def __init__(self, c_in, d_model):
31 | super(TokenEmbedding, self).__init__()
32 | padding = 1 if torch.__version__ >= '1.5.0' else 2
33 | self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
34 | kernel_size=3, padding=padding, padding_mode='circular', bias=False)
35 | for m in self.modules():
36 | if isinstance(m, nn.Conv1d):
37 | nn.init.kaiming_normal_(
38 | m.weight, mode='fan_in', nonlinearity='leaky_relu')
39 |
40 | def forward(self, x):
41 | x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
42 | return x
43 |
44 |
45 | class FixedEmbedding(nn.Module):
46 | def __init__(self, c_in, d_model):
47 | super(FixedEmbedding, self).__init__()
48 |
49 | w = torch.zeros(c_in, d_model).float()
50 | w.requires_grad = False
51 |
52 | position = torch.arange(0, c_in).float().unsqueeze(1)
53 | div_term = (torch.arange(0, d_model, 2).float()
54 | * -(math.log(10000.0) / d_model)).exp()
55 |
56 | w[:, 0::2] = torch.sin(position * div_term)
57 | w[:, 1::2] = torch.cos(position * div_term)
58 |
59 | self.emb = nn.Embedding(c_in, d_model)
60 | self.emb.weight = nn.Parameter(w, requires_grad=False)
61 |
62 | def forward(self, x):
63 | return self.emb(x).detach()
64 |
65 |
66 | class TemporalEmbedding(nn.Module):
67 | def __init__(self, d_model, embed_type='fixed', freq='h'):
68 | super(TemporalEmbedding, self).__init__()
69 |
70 | minute_size = 4
71 | hour_size = 24
72 | weekday_size = 7
73 | day_size = 32
74 | month_size = 13
75 |
76 | Embed = FixedEmbedding if embed_type == 'fixed' else nn.Embedding
77 | if freq == 't':
78 | self.minute_embed = Embed(minute_size, d_model)
79 | self.hour_embed = Embed(hour_size, d_model)
80 | self.weekday_embed = Embed(weekday_size, d_model)
81 | self.day_embed = Embed(day_size, d_model)
82 | self.month_embed = Embed(month_size, d_model)
83 |
84 | def forward(self, x):
85 | x = x.long()
86 | minute_x = self.minute_embed(x[:, :, 4]) if hasattr(
87 | self, 'minute_embed') else 0.
88 | hour_x = self.hour_embed(x[:, :, 3])
89 | weekday_x = self.weekday_embed(x[:, :, 2])
90 | day_x = self.day_embed(x[:, :, 1])
91 | month_x = self.month_embed(x[:, :, 0])
92 |
93 | return hour_x + weekday_x + day_x + month_x + minute_x
94 |
95 |
96 | class TimeFeatureEmbedding(nn.Module):
97 | def __init__(self, d_model, embed_type='timeF', freq='h'):
98 | super(TimeFeatureEmbedding, self).__init__()
99 |
100 | freq_map = {'h': 4, 't': 5, 's': 6,
101 | 'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3}
102 | d_inp = freq_map[freq]
103 | self.embed = nn.Linear(d_inp, d_model, bias=False)
104 |
105 | def forward(self, x):
106 | return self.embed(x)
107 |
108 |
109 | class SubjectEmbedding(nn.Module):
110 | def __init__(self, num_subjects, d_model):
111 | super(SubjectEmbedding, self).__init__()
112 | self.subject_embedding = nn.Embedding(num_subjects, d_model)
113 | self.shared_embedding = nn.Parameter(torch.randn(1, d_model)) # Shared token for unknown subjects
114 | self.mask_embedding = nn.Parameter(torch.randn(1, d_model)) # Mask token embedding
115 |
116 | def forward(self, subject_ids):
117 | if subject_ids[0] is None or torch.any(subject_ids >= self.subject_embedding.num_embeddings):
118 | batch_size = subject_ids.size(0)
119 | return self.shared_embedding.expand(batch_size, 1, -1)
120 | else:
121 | return self.subject_embedding(subject_ids).unsqueeze(1)
122 |
123 |
124 | # class DataEmbedding(nn.Module):
125 | # def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1, num_subjects=None):
126 | # super(DataEmbedding, self).__init__()
127 | # self.value_embedding = nn.Linear(c_in, d_model)
128 | # self.position_embedding = PositionalEmbedding(d_model=d_model)
129 | # self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
130 | # self.dropout = nn.Dropout(p=dropout)
131 | # self.subject_embedding = SubjectEmbedding(num_subjects, d_model) if num_subjects is not None else None
132 | # self.mask_token = nn.Parameter(torch.randn(1, d_model)) # Mask token embedding
133 |
134 | # def forward(self, x, x_mark, subject_ids=None, mask=None):
135 | # if x_mark is None:
136 | # x = self.value_embedding(x)
137 | # else:
138 | # x = self.value_embedding(x) + self.temporal_embedding(x_mark) + self.position_embedding(x)
139 |
140 | # if mask is not None:
141 | # x = x * (~mask.bool()) + self.mask_token * mask.float()
142 |
143 | # if self.subject_embedding is not None:
144 | # subject_emb = self.subject_embedding(subject_ids) # (batch_size, 1, d_model)
145 | # x = torch.cat([subject_emb, x], dim=1) # Concatenate along sequence dimension (batch_size, seq_len + 1, d_model)
146 |
147 | # return self.dropout(x)
148 |
149 | class DataEmbedding(nn.Module):
150 | def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1, joint_train=False, num_subjects=None):
151 | super(DataEmbedding, self).__init__()
152 | if joint_train and num_subjects is not None:
153 | self.value_embedding = nn.ModuleDict({
154 | str(subject_id): nn.Linear(c_in, d_model) for subject_id in range(num_subjects)
155 | })
156 | else:
157 | self.value_embedding = nn.Linear(c_in, d_model) # If no subjects specified, use a single value embedding
158 |
159 | self.position_embedding = PositionalEmbedding(d_model=d_model)
160 | self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
161 | self.dropout = nn.Dropout(p=dropout)
162 | self.subject_embedding = SubjectEmbedding(num_subjects, d_model) if num_subjects is not None else None
163 | self.mask_token = nn.Parameter(torch.randn(1, d_model)) # Mask token embedding
164 | self.joint_train = joint_train
165 |
166 | def forward(self, x, x_mark, subject_ids=None, mask=None):
167 | if self.joint_train:
168 | # Use subject-specific value embedding for each subject
169 | x = torch.stack([self.value_embedding[str(subject_id.item())](x[i]) for i, subject_id in enumerate(subject_ids)])
170 | else:
171 | x = self.value_embedding(x)
172 |
173 | if x_mark is not None:
174 | x = x + self.temporal_embedding(x_mark) + self.position_embedding(x)
175 |
176 | if mask is not None:
177 | x = x * (~mask.bool()) + self.mask_token * mask.float()
178 |
179 | if self.subject_embedding is not None:
180 | subject_emb = self.subject_embedding(subject_ids) # (batch_size, 1, d_model)
181 | x = torch.cat([subject_emb, x], dim=1) # Concatenate along sequence dimension (batch_size, seq_len + 1, d_model)
182 |
183 | return self.dropout(x)
184 |
185 |
186 |
187 |
188 | class DataEmbedding_inverted(nn.Module):
189 | def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
190 | super(DataEmbedding_inverted, self).__init__()
191 | self.value_embedding = nn.Linear(c_in, d_model)
192 | self.dropout = nn.Dropout(p=dropout)
193 |
194 | def forward(self, x, x_mark):
195 | x = x.permute(0, 2, 1)
196 | # x: [Batch Variate Time]
197 | if x_mark is None:
198 | x = self.value_embedding(x)
199 | else:
200 | x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1))
201 | # x: [Batch Variate d_model]
202 | return self.dropout(x)
203 |
204 |
205 | class DataEmbedding_wo_pos(nn.Module):
206 | def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
207 | super(DataEmbedding_wo_pos, self).__init__()
208 |
209 | self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
210 | self.position_embedding = PositionalEmbedding(d_model=d_model)
211 | self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type,
212 | freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding(
213 | d_model=d_model, embed_type=embed_type, freq=freq)
214 | self.dropout = nn.Dropout(p=dropout)
215 |
216 | def forward(self, x, x_mark):
217 | if x_mark is None:
218 | x = self.value_embedding(x) + self.position_embedding(x)
219 | else:
220 | x = self.value_embedding(x) + self.temporal_embedding(x_mark)
221 | return self.dropout(x)
222 |
223 |
224 | class PatchEmbedding(nn.Module):
225 | def __init__(self, d_model, patch_len, stride, padding, dropout):
226 | super(PatchEmbedding, self).__init__()
227 | # Patching
228 | self.patch_len = patch_len
229 | self.stride = stride
230 | self.padding_patch_layer = nn.ReplicationPad1d((0, padding))
231 |
232 | # Backbone, Input encoding: projection of feature vectors onto a d-dim vector space
233 | self.value_embedding = nn.Linear(patch_len, d_model, bias=False)
234 |
235 | # Positional embedding
236 | self.position_embedding = PositionalEmbedding(d_model)
237 |
238 | # Residual dropout
239 | self.dropout = nn.Dropout(dropout)
240 |
241 | def forward(self, x):
242 | # do patching
243 | n_vars = x.shape[1]
244 | x = self.padding_patch_layer(x)
245 | x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
246 | x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))
247 | # Input encoding
248 | x = self.value_embedding(x) + self.position_embedding(x)
249 | return self.dropout(x), n_vars
250 |
--------------------------------------------------------------------------------
/EEG-preprocessing/preprocessing_utils.py:
--------------------------------------------------------------------------------
1 | def epoching(args, data_part, seed):
2 | """This function first converts the EEG data to MNE raw format, and
3 | performs channel selection, epoching, baseline correction and frequency
4 | downsampling. Then, it sorts the EEG data of each session according to the
5 | image conditions.
6 |
7 | Parameters
8 | ----------
9 | args : Namespace
10 | Input arguments.
11 | data_part : str
12 | 'test' or 'training' data partitions.
13 | seed : int
14 | Random seed.
15 |
16 | Returns
17 | -------
18 | epoched_data : list of float
19 | Epoched EEG data.
20 | img_conditions : list of int
21 | Unique image conditions of the epoched and sorted EEG data.
22 | ch_names : list of str
23 | EEG channel names.
24 | times : float
25 | EEG time points.
26 |
27 | """
28 |
29 | import os
30 | import mne
31 | import numpy as np
32 | from sklearn.utils import shuffle
33 |
34 | chan_order = ['Fp1', 'Fp2', 'AF7', 'AF3', 'AFz', 'AF4', 'AF8', 'F7', 'F5', 'F3',
35 | 'F1', 'F2', 'F4', 'F6', 'F8', 'FT9', 'FT7', 'FC5', 'FC3', 'FC1',
36 | 'FCz', 'FC2', 'FC4', 'FC6', 'FT8', 'FT10', 'T7', 'C5', 'C3', 'C1',
37 | 'Cz', 'C2', 'C4', 'C6', 'T8', 'TP9', 'TP7', 'CP5', 'CP3', 'CP1',
38 | 'CPz', 'CP2', 'CP4', 'CP6', 'TP8', 'TP10', 'P7', 'P5', 'P3', 'P1',
39 | 'Pz', 'P2', 'P4', 'P6', 'P8', 'PO7', 'PO3', 'POz', 'PO4', 'PO8',
40 | 'O1', 'Oz', 'O2']
41 |
42 | ### Loop across data collection sessions ###
43 | epoched_data = []
44 | img_conditions = []
45 | for s in range(args.n_ses):
46 |
47 | ### Load the EEG data and convert it to MNE raw format ###
48 | eeg_dir = os.path.join('Raw_data', 'sub-'+
49 | format(args.sub,'02'), 'ses-'+format(s+1,'02'), 'raw_eeg_'+
50 | data_part+'.npy')
51 | eeg_data = np.load(os.path.join(args.project_dir, eeg_dir),
52 | allow_pickle=True).item()
53 | ch_names = eeg_data['ch_names']
54 | sfreq = eeg_data['sfreq']
55 | ch_types = eeg_data['ch_types']
56 | eeg_data = eeg_data['raw_eeg_data']
57 | # Convert to MNE raw format
58 | info = mne.create_info(ch_names, sfreq, ch_types)
59 | raw = mne.io.RawArray(eeg_data, info)
60 | del eeg_data
61 |
62 | ### Get events, drop unused channels and reject target trials ###
63 | events = mne.find_events(raw, stim_channel='stim')
64 | # # Select only occipital (O) and posterior (P) channels
65 | # chan_idx = np.asarray(mne.pick_channels_regexp(raw.info['ch_names'],
66 | # '^O *|^P *'))
67 | # new_chans = [raw.info['ch_names'][c] for c in chan_idx]
68 | # raw.pick_channels(new_chans)
69 | # * choose all channels
70 | raw.pick_channels(chan_order, ordered=True)
71 | # Reject the target trials (event 99999)
72 | idx_target = np.where(events[:,2] == 99999)[0]
73 | events = np.delete(events, idx_target, 0)
74 | ### Epoching, baseline correction and resampling ###
75 | # * [0, 1.0]
76 | epochs = mne.Epochs(raw, events, tmin=-.2, tmax=1.0, baseline=(None,0),
77 | preload=True)
78 | # epochs = mne.Epochs(raw, events, tmin=-.2, tmax=.8, baseline=(None,0),
79 | # preload=True)
80 | del raw
81 | # Resampling
82 | if args.sfreq < 1000:
83 | epochs.resample(args.sfreq)
84 | ch_names = epochs.info['ch_names']
85 | times = epochs.times
86 |
87 | ### Sort the data ###
88 | data = epochs.get_data()
89 | events = epochs.events[:,2]
90 | img_cond = np.unique(events)
91 | del epochs
92 | # Select only a maximum number of EEG repetitions
93 | if data_part == 'test':
94 | max_rep = 20
95 | else:
96 | max_rep = 2
97 | # Sorted data matrix of shape:
98 | # Image conditions × EEG repetitions × EEG channels × EEG time points
99 | sorted_data = np.zeros((len(img_cond),max_rep,data.shape[1],
100 | data.shape[2]))
101 | for i in range(len(img_cond)):
102 | # Find the indices of the selected image condition
103 | idx = np.where(events == img_cond[i])[0]
104 | # Randomly select only the max number of EEG repetitions
105 | idx = shuffle(idx, random_state=seed, n_samples=max_rep)
106 | sorted_data[i] = data[idx]
107 | del data
108 | epoched_data.append(sorted_data[:, :, :, 50:])
109 | img_conditions.append(img_cond)
110 | del sorted_data
111 |
112 | ### Output ###
113 | return epoched_data, img_conditions, ch_names, times
114 |
115 |
116 | def mvnn(args, epoched_test, epoched_train):
117 | """Compute the covariance matrices of the EEG data (calculated for each
118 | time-point or epoch/repetitions of each image condition), and then average
119 | them across image conditions and data partitions. The inverse of the
120 | resulting averaged covariance matrix is used to whiten the EEG data
121 | (independently for each session).
122 |
123 | z-score standardization also performs well
124 |
125 | Parameters
126 | ----------
127 | args : Namespace
128 | Input arguments.
129 | epoched_test : list of floats
130 | Epoched test EEG data.
131 | epoched_train : list of floats
132 | Epoched training EEG data.
133 |
134 | Returns
135 | -------
136 | whitened_test : list of float
137 | Whitened test EEG data.
138 | whitened_train : list of float
139 | Whitened training EEG data.
140 |
141 | """
142 |
143 | import numpy as np
144 | from tqdm import tqdm
145 | from sklearn.discriminant_analysis import _cov
146 | import scipy
147 |
148 | ### Loop across data collection sessions ###
149 | whitened_test = []
150 | whitened_train = []
151 | for s in range(args.n_ses):
152 | session_data = [epoched_test[s], epoched_train[s]]
153 |
154 | ### Compute the covariance matrices ###
155 | # Data partitions covariance matrix of shape:
156 | # Data partitions × EEG channels × EEG channels
157 | sigma_part = np.empty((len(session_data),session_data[0].shape[2],
158 | session_data[0].shape[2]))
159 | for p in range(sigma_part.shape[0]):
160 | # Image conditions covariance matrix of shape:
161 | # Image conditions × EEG channels × EEG channels
162 | sigma_cond = np.empty((session_data[p].shape[0],
163 | session_data[0].shape[2],session_data[0].shape[2]))
164 | for i in tqdm(range(session_data[p].shape[0])):
165 | cond_data = session_data[p][i]
166 | # Compute covariance matrices at each time point, and then
167 | # average across time points
168 | if args.mvnn_dim == "time":
169 | sigma_cond[i] = np.mean([_cov(cond_data[:,:,t],
170 | shrinkage='auto') for t in range(cond_data.shape[2])],
171 | axis=0)
172 | # Compute covariance matrices at each epoch (EEG repetition),
173 | # and then average across epochs/repetitions
174 | elif args.mvnn_dim == "epochs":
175 | sigma_cond[i] = np.mean([_cov(np.transpose(cond_data[e]),
176 | shrinkage='auto') for e in range(cond_data.shape[0])],
177 | axis=0)
178 | # Average the covariance matrices across image conditions
179 | sigma_part[p] = sigma_cond.mean(axis=0)
180 | # # Average the covariance matrices across image partitions
181 | # sigma_tot = sigma_part.mean(axis=0)
182 | # ? It seems unfair to use the test data for MVNN, so we only use the training data
183 | sigma_tot = sigma_part[1]
184 | # Compute the inverse of the covariance matrix
185 | sigma_inv = scipy.linalg.fractional_matrix_power(sigma_tot, -0.5)
186 |
187 | ### Whiten the data ###
188 | whitened_test.append(np.reshape((np.reshape(session_data[0], (-1,
189 | session_data[0].shape[2],session_data[0].shape[3])).swapaxes(1, 2)
190 | @ sigma_inv).swapaxes(1, 2), session_data[0].shape))
191 | whitened_train.append(np.reshape((np.reshape(session_data[1], (-1,
192 | session_data[1].shape[2],session_data[1].shape[3])).swapaxes(1, 2)
193 | @ sigma_inv).swapaxes(1, 2), session_data[1].shape))
194 |
195 | ### Output ###
196 | return whitened_test, whitened_train
197 |
198 |
199 | def save_prepr(args, whitened_test, whitened_train, img_conditions_train,
200 | ch_names, times, seed):
201 | """Merge the EEG data of all sessions together, shuffle the EEG repetitions
202 | across sessions and reshape the data to the format:
203 | Image conditions × EEG repetitions × EEG channels × EEG time points.
204 | Then, the data of both test and training EEG partitions is saved.
205 |
206 | Parameters
207 | ----------
208 | args : Namespace
209 | Input arguments.
210 | whitened_test : list of float
211 | Whitened test EEG data.
212 | whitened_train : list of float
213 | Whitened training EEG data.
214 | img_conditions_train : list of int
215 | Unique image conditions of the epoched and sorted train EEG data.
216 | ch_names : list of str
217 | EEG channel names.
218 | times : float
219 | EEG time points.
220 | seed : int
221 | Random seed.
222 |
223 | """
224 |
225 | import numpy as np
226 | from sklearn.utils import shuffle
227 | import os
228 | import pickle
229 |
230 | ### Merge and save the test data ###
231 | for s in range(args.n_ses):
232 | if s == 0:
233 | merged_test = whitened_test[s]
234 | else:
235 | merged_test = np.append(merged_test, whitened_test[s], 1)
236 | del whitened_test
237 | # Shuffle the repetitions of different sessions
238 | idx = shuffle(np.arange(0, merged_test.shape[1]), random_state=seed)
239 | merged_test = merged_test[:,idx]
240 | # Insert the data into a dictionary
241 | test_dict = {
242 | 'preprocessed_eeg_data': merged_test,
243 | 'ch_names': ch_names,
244 | 'times': times
245 | }
246 | del merged_test
247 | # Saving directories
248 | save_dir = os.path.join(args.project_dir,
249 | 'Preprocessed_data_250Hz', 'sub-'+format(args.sub,'02'))
250 | file_name_test = 'preprocessed_eeg_test.npy'
251 | file_name_train = 'preprocessed_eeg_training.npy'
252 | # Create the directory if not existing and save the data
253 | if os.path.isdir(save_dir) == False:
254 | os.makedirs(save_dir)
255 | # np.save(os.path.join(save_dir, file_name_test), test_dict)
256 | save_pic = open(os.path.join(save_dir, file_name_test), 'wb')
257 | pickle.dump(test_dict, save_pic, protocol=4)
258 | save_pic.close()
259 | del test_dict
260 |
261 | ### Merge and save the training data ###
262 | for s in range(args.n_ses):
263 | if s == 0:
264 | white_data = whitened_train[s]
265 | img_cond = img_conditions_train[s]
266 | else:
267 | white_data = np.append(white_data, whitened_train[s], 0)
268 | img_cond = np.append(img_cond, img_conditions_train[s], 0)
269 | del whitened_train, img_conditions_train
270 | # Data matrix of shape:
271 | # Image conditions × EEG repetitions × EEG channels × EEG time points
272 | merged_train = np.zeros((len(np.unique(img_cond)), white_data.shape[1]*2,
273 | white_data.shape[2],white_data.shape[3]))
274 | for i in range(len(np.unique(img_cond))):
275 | # Find the indices of the selected category
276 | idx = np.where(img_cond == i+1)[0]
277 | for r in range(len(idx)):
278 | if r == 0:
279 | ordered_data = white_data[idx[r]]
280 | else:
281 | ordered_data = np.append(ordered_data, white_data[idx[r]], 0)
282 | merged_train[i] = ordered_data
283 | # Shuffle the repetitions of different sessions
284 | idx = shuffle(np.arange(0, merged_train.shape[1]), random_state=seed)
285 | merged_train = merged_train[:,idx]
286 | # Insert the data into a dictionary
287 | train_dict = {
288 | 'preprocessed_eeg_data': merged_train,
289 | 'ch_names': ch_names,
290 | 'times': times
291 | }
292 | del merged_train
293 | # Create the directory if not existing and save the data
294 | if os.path.isdir(save_dir) == False:
295 | os.makedirs(save_dir)
296 | # np.save(os.path.join(save_dir, file_name_train),
297 | # train_dict)
298 | save_pic = open(os.path.join(save_dir, file_name_train), 'wb')
299 | pickle.dump(train_dict, save_pic, protocol=4)
300 | save_pic.close()
301 | del train_dict
302 |
--------------------------------------------------------------------------------
/models/util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 | import torch
4 | import os
5 | import sys
6 | import time
7 | from torch import inf
8 | import wandb
9 |
10 | class NativeScaler:
11 | state_dict_key = "amp_scaler"
12 |
13 | def __init__(self):
14 | self._scaler = torch.cuda.amp.GradScaler()
15 |
16 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
17 | self._scaler.scale(loss).backward(create_graph=create_graph)
18 | if update_grad:
19 | if clip_grad is not None:
20 | assert parameters is not None
21 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
22 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
23 | else:
24 | self._scaler.unscale_(optimizer)
25 | norm = get_grad_norm_(parameters)
26 | self._scaler.step(optimizer)
27 | self._scaler.update()
28 | else:
29 | norm = None
30 | return norm
31 |
32 | def state_dict(self):
33 | return self._scaler.state_dict()
34 |
35 | def load_state_dict(self, state_dict):
36 | self._scaler.load_state_dict(state_dict)
37 |
38 |
39 |
40 | def get_grad_norm_(parameters, norm_type: float = 2.0):
41 | if isinstance(parameters, torch.Tensor):
42 | parameters = [parameters]
43 | parameters = [p for p in parameters if p.grad is not None]
44 | norm_type = float(norm_type)
45 | if len(parameters) == 0:
46 | return torch.tensor(0.)
47 | device = parameters[0].grad.device
48 | if norm_type == inf:
49 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
50 | else:
51 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
52 | return total_norm
53 |
54 | def train_one_epoch(model, data_loader, optimizer, device, epoch,
55 | loss_scaler, log_writer=None, config=None, start_time=None, model_without_ddp=None,
56 | img_feature_extractor=None, preprocess=None):
57 | model.train(True)
58 | optimizer.zero_grad()
59 | total_loss = []
60 | total_cor = []
61 | accum_iter = config.accum_iter
62 | for data_iter_step, data_dict in enumerate(data_loader):
63 |
64 | # we use a per iteration (instead of per epoch) lr scheduler
65 | # print(data_iter_step)
66 | # print(len(data_loader))
67 |
68 | if data_iter_step % accum_iter == 0:
69 | adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, config)
70 | samples = data_dict['eeg']
71 |
72 | img_features = None
73 | valid_idx = None
74 | if img_feature_extractor is not None:
75 | images = data_dict['image']
76 | valid_idx = torch.nonzero(images.sum(dim=(1,2,3)) != 0).squeeze(1)
77 | img_feature_extractor.eval()
78 | with torch.no_grad():
79 | img_features = img_feature_extractor(preprocess(images[valid_idx]).to(device))['layer2']
80 | samples = samples.to(device)
81 | # img_features = img_features.to(device)
82 |
83 | optimizer.zero_grad()
84 | with torch.cuda.amp.autocast(enabled=True):
85 | loss, pred, _ = model(samples, img_features, valid_idx=valid_idx, mask_ratio=config.mask_ratio)
86 | # loss.backward()
87 | # norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.clip_grad)
88 | # optimizer.step()
89 |
90 | loss_value = loss.item()
91 |
92 | if not math.isfinite(loss_value):
93 | print(f"Loss is {loss_value}, stopping training at step {data_iter_step} epoch {epoch}")
94 | sys.exit(1)
95 |
96 | # loss /= accum_iter
97 | loss_scaler(loss, optimizer, parameters=model.parameters(), clip_grad=config.clip_grad)
98 |
99 | # if (data_iter_step + 1) % accum_iter == 0:
100 |         # compute the correlation between reconstruction and input
101 | pred = pred.to('cpu').detach()
102 | samples = samples.to('cpu').detach()
103 | # pred = pred.transpose(1,2) #model_without_ddp.unpatchify(pred)
104 | pred = model_without_ddp.unpatchify(pred)
105 |
106 | cor = torch.mean(torch.tensor([torch.corrcoef(torch.cat([p[0].unsqueeze(0), s[0].unsqueeze(0)],axis=0))[0,1] for p, s in zip(pred, samples)])).item()
107 | optimizer.zero_grad()
108 |
109 | total_loss.append(loss_value)
110 | total_cor.append(cor)
111 | if device == torch.device('cuda:0'):
112 | lr = optimizer.param_groups[0]["lr"]
113 | print('train_loss_step:', np.mean(total_loss), 'lr:', lr, 'cor', np.mean(total_cor))
114 |
115 | if log_writer is not None:
116 | lr = optimizer.param_groups[0]["lr"]
117 | log_writer.log('train_loss_step', np.mean(total_loss), step=epoch)
118 | log_writer.log('lr', lr, step=epoch)
119 | log_writer.log('cor', np.mean(total_cor), step=epoch)
120 | if start_time is not None:
121 | log_writer.log('time (min)', (time.time() - start_time)/60.0, step=epoch)
122 | if config.local_rank == 0:
123 | print(f'[Epoch {epoch}] loss: {np.mean(total_loss)}')
124 |
125 | return np.mean(total_cor)
126 |
127 | def get_1d_sincos_pos_embed(embed_dim, length, cls_token=False):
128 | """
129 |     length: int, the number of positions to encode
130 |     return:
131 |     pos_embed: [length, embed_dim] or [1+length, embed_dim] (w/ or w/o cls_token)
132 | """
133 | grid_l = np.arange(length, dtype=float)
134 |
135 | grid_l = grid_l.reshape([1, length])
136 | pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid_l)
137 | if cls_token:
138 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
139 | return pos_embed
140 |
141 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
142 | """
143 | embed_dim: output dimension for each position
144 | pos: a list of positions to be encoded: size (M,)
145 | out: (M, D)
146 | """
147 | assert embed_dim % 2 == 0
148 | omega = np.arange(embed_dim // 2, dtype=float)
149 | omega /= embed_dim / 2.
150 | omega = 1. / 10000**omega # (D/2,)
151 |
152 | pos = pos.reshape(-1) # (M,)
153 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
154 |
155 | emb_sin = np.sin(out) # (M, D/2)
156 | emb_cos = np.cos(out) # (M, D/2)
157 |
158 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
159 | return emb
160 |
161 | def interpolate_pos_embed(model, checkpoint_model):
162 | if 'pos_embed' in checkpoint_model:
163 | pos_embed_checkpoint = checkpoint_model['pos_embed']
164 | embedding_size = pos_embed_checkpoint.shape[-1]
165 | num_patches = model.patch_embed.num_patches
166 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches # cls token
167 |         # sequence length of the checkpoint position embedding
168 | orig_size = int(pos_embed_checkpoint.shape[-2] - num_extra_tokens)
169 |         # sequence length of the new position embedding
170 | new_size = int(num_patches)
171 | # class_token and dist_token are kept unchanged
172 | if orig_size != new_size:
173 | print("Position interpolate from %d to %d" % (orig_size, new_size))
174 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
175 | # only the position tokens are interpolated
176 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
177 | pos_tokens = pos_tokens.reshape(-1, orig_size, embedding_size).permute(0, 2, 1)
178 | pos_tokens = torch.nn.functional.interpolate(
179 | pos_tokens, size=(new_size))
180 | pos_tokens = pos_tokens.permute(0, 2, 1)
181 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
182 | checkpoint_model['pos_embed'] = new_pos_embed
183 |
184 |
185 | def adjust_learning_rate(optimizer, epoch, config):
186 | """Decay the learning rate with half-cycle cosine after warmup"""
187 | if epoch < config.warmup_epochs:
188 | lr = config.lr * epoch / config.warmup_epochs
189 | else:
190 | lr = config.min_lr + (config.lr - config.min_lr) * 0.5 * \
191 | (1. + math.cos(math.pi * (epoch - config.warmup_epochs) / (config.num_epoch - config.warmup_epochs)))
192 | for param_group in optimizer.param_groups:
193 | if "lr_scale" in param_group:
194 | param_group["lr"] = lr * param_group["lr_scale"]
195 | else:
196 | param_group["lr"] = lr
197 | return lr
198 |
199 |
200 |
201 |
202 | def load_model(config, model, checkpoint_path ):
203 | checkpoint = torch.load(checkpoint_path, map_location='cpu')
204 | model.load_state_dict(checkpoint['model'])
205 | print(f'Model loaded with {checkpoint_path}')
206 |
207 |
208 | def patchify(imgs, patch_size):
209 | """
210 | imgs: (N, 1, num_voxels)
211 | x: (N, L, patch_size)
212 | """
213 | p = patch_size
214 | assert imgs.ndim == 3 and imgs.shape[2] % p == 0
215 |
216 | h = imgs.shape[2] // p
217 | x = imgs.reshape(shape=(imgs.shape[0], h, p))
218 | return x
219 |
220 | def unpatchify(x, patch_size):
221 | """
222 | x: (N, L, patch_size)
223 | imgs: (N, 1, num_voxels)
224 | """
225 | p = patch_size
226 | h = x.shape[1]
227 |
228 | imgs = x.reshape(shape=(x.shape[0], 1, h * p))
229 | return imgs
230 |
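# Round-trip sketch for patchify/unpatchify above (num_voxels must be divisible
# by patch_size; the sizes here are illustrative only):
def _example_patchify_roundtrip():
    imgs = torch.randn(4, 1, 440)               # (N, 1, num_voxels)
    patches = patchify(imgs, patch_size=4)      # (N, 110, 4)
    recon = unpatchify(patches, patch_size=4)   # (N, 1, 440)
    assert torch.equal(recon, imgs)
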
231 | class wandb_logger:
232 | def __init__(self, config):
233 | try:
234 | wandb.init(
235 | # Set the project where this run will be logged
236 | project=config['project'],
237 | name=config['name'],
238 | config=config,
239 | entity=config['entity'],
240 | )
241 | except:
242 | wandb.init(
243 | # Set the project where this run will be logged
244 | project=config.project,
245 | name=config.name,
246 | config=config,
247 | entity=config.entity,
248 | )
249 |
250 | self.config = config
251 | self.step = None
252 |
253 | def log(self, data, step=None):
254 | if step is None:
255 | wandb.log(data)
256 | else:
257 | wandb.log(data, step=step)
258 | self.step = step
259 |
260 | def watch_model(self, *args, **kwargs):
261 | wandb.watch(*args, **kwargs)
262 |
263 | def log_image(self, figs):
264 | if self.step is None:
265 | wandb.log(figs)
266 | else:
267 | wandb.log(figs, step=self.step)
268 |
269 | def finish(self):
270 | wandb.finish(quiet=True)
271 |
272 | def load(self, net):
273 | path = os.path.join(self.config['path_data'], self.config['path_ckpt'], self.config['file_ckpt'])
274 | net.load_state_dict(torch.load(path))
275 | print(f'load {path}')
276 |
277 | def save(self, net, file_name=None):
278 | path_ckpt = os.path.join(self.config['path_data'], self.config['path_ckpt'])
279 | if not os.path.exists(path_ckpt):
280 | os.makedirs(path_ckpt)
281 | print(f'{path_ckpt} created!')
282 |
283 | path = os.path.join(path_ckpt, file_name)
284 | torch.save(net.state_dict(), path)
285 |
286 | def watch(self, model, log):
287 | wandb.watch(model, log)
--------------------------------------------------------------------------------
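adjust_learning_rate above applies linear warmup followed by a half-cycle cosine decay and is called once per iteration with a fractional epoch inside train_one_epoch. A small sanity-check sketch with illustrative hyper-parameters, assuming the repository root is on the Python path (util.py also imports wandb at module level, so it needs to be installed):

    import types
    import torch
    from models.util import adjust_learning_rate

    cfg = types.SimpleNamespace(lr=1e-3, min_lr=1e-5, warmup_epochs=5, num_epoch=50)
    opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=cfg.lr)

    lrs = [adjust_learning_rate(opt, epoch, cfg) for epoch in range(cfg.num_epoch)]
    # lr ramps linearly to 1e-3 over the first 5 epochs, then decays smoothly towards 1e-5
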
/models/subject_layers/ETSformer_EncDec.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.fft as fft
5 | from einops import rearrange, reduce, repeat
6 | import math, random
7 | from scipy.fftpack import next_fast_len
8 |
9 |
10 | class Transform:
11 | def __init__(self, sigma):
12 | self.sigma = sigma
13 |
14 | @torch.no_grad()
15 | def transform(self, x):
16 | return self.jitter(self.shift(self.scale(x)))
17 |
18 | def jitter(self, x):
19 | return x + (torch.randn(x.shape).to(x.device) * self.sigma)
20 |
21 | def scale(self, x):
22 | return x * (torch.randn(x.size(-1)).to(x.device) * self.sigma + 1)
23 |
24 | def shift(self, x):
25 | return x + (torch.randn(x.size(-1)).to(x.device) * self.sigma)
26 |
27 |
28 | def conv1d_fft(f, g, dim=-1):
29 | N = f.size(dim)
30 | M = g.size(dim)
31 |
32 | fast_len = next_fast_len(N + M - 1)
33 |
34 | F_f = fft.rfft(f, fast_len, dim=dim)
35 | F_g = fft.rfft(g, fast_len, dim=dim)
36 |
37 | F_fg = F_f * F_g.conj()
38 | out = fft.irfft(F_fg, fast_len, dim=dim)
39 | out = out.roll((-1,), dims=(dim,))
40 | idx = torch.as_tensor(range(fast_len - N, fast_len)).to(out.device)
41 | out = out.index_select(dim, idx)
42 |
43 | return out
44 |
45 |
46 | class ExponentialSmoothing(nn.Module):
47 |
48 | def __init__(self, dim, nhead, dropout=0.1, aux=False):
49 | super().__init__()
50 | self._smoothing_weight = nn.Parameter(torch.randn(nhead, 1))
51 | self.v0 = nn.Parameter(torch.randn(1, 1, nhead, dim))
52 | self.dropout = nn.Dropout(dropout)
53 | if aux:
54 | self.aux_dropout = nn.Dropout(dropout)
55 |
56 | def forward(self, values, aux_values=None):
57 | b, t, h, d = values.shape
58 |
59 | init_weight, weight = self.get_exponential_weight(t)
60 | output = conv1d_fft(self.dropout(values), weight, dim=1)
61 | output = init_weight * self.v0 + output
62 |
63 | if aux_values is not None:
64 | aux_weight = weight / (1 - self.weight) * self.weight
65 | aux_output = conv1d_fft(self.aux_dropout(aux_values), aux_weight)
66 | output = output + aux_output
67 |
68 | return output
69 |
70 | def get_exponential_weight(self, T):
71 | # Generate array [0, 1, ..., T-1]
72 | powers = torch.arange(T, dtype=torch.float, device=self.weight.device)
73 |
74 |         # (1 - \alpha) * \alpha^t, for all t = T-1, T-2, ..., 0
75 | weight = (1 - self.weight) * (self.weight ** torch.flip(powers, dims=(0,)))
76 |
77 | # \alpha^t for all t = 1, 2, ..., T
78 | init_weight = self.weight ** (powers + 1)
79 |
80 | return rearrange(init_weight, 'h t -> 1 t h 1'), \
81 | rearrange(weight, 'h t -> 1 t h 1')
82 |
83 | @property
84 | def weight(self):
85 | return torch.sigmoid(self._smoothing_weight)
86 |
87 |
88 | class Feedforward(nn.Module):
89 | def __init__(self, d_model, dim_feedforward, dropout=0.1, activation='sigmoid'):
90 | # Implementation of Feedforward model
91 | super().__init__()
92 | self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False)
93 | self.dropout1 = nn.Dropout(dropout)
94 | self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False)
95 | self.dropout2 = nn.Dropout(dropout)
96 | self.activation = getattr(F, activation)
97 |
98 | def forward(self, x):
99 | x = self.linear2(self.dropout1(self.activation(self.linear1(x))))
100 | return self.dropout2(x)
101 |
102 |
103 | class GrowthLayer(nn.Module):
104 |
105 | def __init__(self, d_model, nhead, d_head=None, dropout=0.1):
106 | super().__init__()
107 | self.d_head = d_head or (d_model // nhead)
108 | self.d_model = d_model
109 | self.nhead = nhead
110 |
111 | self.z0 = nn.Parameter(torch.randn(self.nhead, self.d_head))
112 | self.in_proj = nn.Linear(self.d_model, self.d_head * self.nhead)
113 | self.es = ExponentialSmoothing(self.d_head, self.nhead, dropout=dropout)
114 | self.out_proj = nn.Linear(self.d_head * self.nhead, self.d_model)
115 |
116 | assert self.d_head * self.nhead == self.d_model, "d_model must be divisible by nhead"
117 |
118 | def forward(self, inputs):
119 | """
120 | :param inputs: shape: (batch, seq_len, dim)
121 | :return: shape: (batch, seq_len, dim)
122 | """
123 | b, t, d = inputs.shape
124 | values = self.in_proj(inputs).view(b, t, self.nhead, -1)
125 | values = torch.cat([repeat(self.z0, 'h d -> b 1 h d', b=b), values], dim=1)
126 | values = values[:, 1:] - values[:, :-1]
127 | out = self.es(values)
128 | out = torch.cat([repeat(self.es.v0, '1 1 h d -> b 1 h d', b=b), out], dim=1)
129 | out = rearrange(out, 'b t h d -> b t (h d)')
130 | return self.out_proj(out)
131 |
132 |
133 | class FourierLayer(nn.Module):
134 |
135 | def __init__(self, d_model, pred_len, k=None, low_freq=1):
136 | super().__init__()
137 | self.d_model = d_model
138 | self.pred_len = pred_len
139 | self.k = k
140 | self.low_freq = low_freq
141 |
142 | def forward(self, x):
143 | """x: (b, t, d)"""
144 | b, t, d = x.shape
145 | x_freq = fft.rfft(x, dim=1)
146 |
147 | if t % 2 == 0:
148 | x_freq = x_freq[:, self.low_freq:-1]
149 | f = fft.rfftfreq(t)[self.low_freq:-1]
150 | else:
151 | x_freq = x_freq[:, self.low_freq:]
152 | f = fft.rfftfreq(t)[self.low_freq:]
153 |
154 | x_freq, index_tuple = self.topk_freq(x_freq)
155 | f = repeat(f, 'f -> b f d', b=x_freq.size(0), d=x_freq.size(2))
156 | f = rearrange(f[index_tuple], 'b f d -> b f () d').to(x_freq.device)
157 |
158 | return self.extrapolate(x_freq, f, t)
159 |
160 | def extrapolate(self, x_freq, f, t):
161 | x_freq = torch.cat([x_freq, x_freq.conj()], dim=1)
162 | f = torch.cat([f, -f], dim=1)
163 | t_val = rearrange(torch.arange(t + self.pred_len, dtype=torch.float),
164 | 't -> () () t ()').to(x_freq.device)
165 |
166 | amp = rearrange(x_freq.abs() / t, 'b f d -> b f () d')
167 | phase = rearrange(x_freq.angle(), 'b f d -> b f () d')
168 |
169 | x_time = amp * torch.cos(2 * math.pi * f * t_val + phase)
170 |
171 | return reduce(x_time, 'b f t d -> b t d', 'sum')
172 |
173 | def topk_freq(self, x_freq):
174 | values, indices = torch.topk(x_freq.abs(), self.k, dim=1, largest=True, sorted=True)
175 | mesh_a, mesh_b = torch.meshgrid(torch.arange(x_freq.size(0)), torch.arange(x_freq.size(2)))
176 | index_tuple = (mesh_a.unsqueeze(1), indices, mesh_b.unsqueeze(1))
177 | x_freq = x_freq[index_tuple]
178 |
179 | return x_freq, index_tuple
180 |
181 |
182 | class LevelLayer(nn.Module):
183 |
184 | def __init__(self, d_model, c_out, dropout=0.1):
185 | super().__init__()
186 | self.d_model = d_model
187 | self.c_out = c_out
188 |
189 | self.es = ExponentialSmoothing(1, self.c_out, dropout=dropout, aux=True)
190 | self.growth_pred = nn.Linear(self.d_model, self.c_out)
191 | self.season_pred = nn.Linear(self.d_model, self.c_out)
192 |
193 | def forward(self, level, growth, season):
194 | b, t, _ = level.shape
195 | growth = self.growth_pred(growth).view(b, t, self.c_out, 1)
196 | season = self.season_pred(season).view(b, t, self.c_out, 1)
197 | growth = growth.view(b, t, self.c_out, 1)
198 | season = season.view(b, t, self.c_out, 1)
199 | level = level.view(b, t, self.c_out, 1)
200 | out = self.es(level - season, aux_values=growth)
201 | out = rearrange(out, 'b t h d -> b t (h d)')
202 | return out
203 |
204 |
205 | class EncoderLayer(nn.Module):
206 |
207 | def __init__(self, d_model, nhead, c_out, seq_len, pred_len, k, dim_feedforward=None, dropout=0.1,
208 | activation='sigmoid', layer_norm_eps=1e-5):
209 | super().__init__()
210 | self.d_model = d_model
211 | self.nhead = nhead
212 | self.c_out = c_out
213 | self.seq_len = seq_len
214 | self.pred_len = pred_len
215 | dim_feedforward = dim_feedforward or 4 * d_model
216 | self.dim_feedforward = dim_feedforward
217 |
218 | self.growth_layer = GrowthLayer(d_model, nhead, dropout=dropout)
219 | self.seasonal_layer = FourierLayer(d_model, pred_len, k=k)
220 | self.level_layer = LevelLayer(d_model, c_out, dropout=dropout)
221 |
222 | # Implementation of Feedforward model
223 | self.ff = Feedforward(d_model, dim_feedforward, dropout=dropout, activation=activation)
224 | self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
225 | self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
226 |
227 | self.dropout1 = nn.Dropout(dropout)
228 | self.dropout2 = nn.Dropout(dropout)
229 |
230 | def forward(self, res, level, attn_mask=None):
231 | season = self._season_block(res)
232 | res = res - season[:, :-self.pred_len]
233 | growth = self._growth_block(res)
234 | res = self.norm1(res - growth[:, 1:])
235 | res = self.norm2(res + self.ff(res))
236 |
237 | level = self.level_layer(level, growth[:, :-1], season[:, :-self.pred_len])
238 | return res, level, growth, season
239 |
240 | def _growth_block(self, x):
241 | x = self.growth_layer(x)
242 | return self.dropout1(x)
243 |
244 | def _season_block(self, x):
245 | x = self.seasonal_layer(x)
246 | return self.dropout2(x)
247 |
248 |
249 | class Encoder(nn.Module):
250 |
251 | def __init__(self, layers):
252 | super().__init__()
253 | self.layers = nn.ModuleList(layers)
254 |
255 | def forward(self, res, level, attn_mask=None):
256 | growths = []
257 | seasons = []
258 | for layer in self.layers:
259 | res, level, growth, season = layer(res, level, attn_mask=None)
260 | growths.append(growth)
261 | seasons.append(season)
262 |
263 | return level, growths, seasons
264 |
265 |
266 | class DampingLayer(nn.Module):
267 |
268 | def __init__(self, pred_len, nhead, dropout=0.1):
269 | super().__init__()
270 | self.pred_len = pred_len
271 | self.nhead = nhead
272 | self._damping_factor = nn.Parameter(torch.randn(1, nhead))
273 | self.dropout = nn.Dropout(dropout)
274 |
275 | def forward(self, x):
276 | x = repeat(x, 'b 1 d -> b t d', t=self.pred_len)
277 | b, t, d = x.shape
278 |
279 | powers = torch.arange(self.pred_len).to(self._damping_factor.device) + 1
280 | powers = powers.view(self.pred_len, 1)
281 | damping_factors = self.damping_factor ** powers
282 | damping_factors = damping_factors.cumsum(dim=0)
283 | x = x.view(b, t, self.nhead, -1)
284 | x = self.dropout(x) * damping_factors.unsqueeze(-1)
285 | return x.view(b, t, d)
286 |
287 | @property
288 | def damping_factor(self):
289 | return torch.sigmoid(self._damping_factor)
290 |
291 |
292 | class DecoderLayer(nn.Module):
293 |
294 | def __init__(self, d_model, nhead, c_out, pred_len, dropout=0.1):
295 | super().__init__()
296 | self.d_model = d_model
297 | self.nhead = nhead
298 | self.c_out = c_out
299 | self.pred_len = pred_len
300 |
301 | self.growth_damping = DampingLayer(pred_len, nhead, dropout=dropout)
302 | self.dropout1 = nn.Dropout(dropout)
303 |
304 | def forward(self, growth, season):
305 | growth_horizon = self.growth_damping(growth[:, -1:])
306 | growth_horizon = self.dropout1(growth_horizon)
307 |
308 | seasonal_horizon = season[:, -self.pred_len:]
309 | return growth_horizon, seasonal_horizon
310 |
311 |
312 | class Decoder(nn.Module):
313 |
314 | def __init__(self, layers):
315 | super().__init__()
316 | self.d_model = layers[0].d_model
317 | self.c_out = layers[0].c_out
318 | self.pred_len = layers[0].pred_len
319 | self.nhead = layers[0].nhead
320 |
321 | self.layers = nn.ModuleList(layers)
322 | self.pred = nn.Linear(self.d_model, self.c_out)
323 |
324 | def forward(self, growths, seasons):
325 | growth_repr = []
326 | season_repr = []
327 |
328 | for idx, layer in enumerate(self.layers):
329 | growth_horizon, season_horizon = layer(growths[idx], seasons[idx])
330 | growth_repr.append(growth_horizon)
331 | season_repr.append(season_horizon)
332 | growth_repr = sum(growth_repr)
333 | season_repr = sum(season_repr)
334 | return self.pred(growth_repr), self.pred(season_repr)
335 |
--------------------------------------------------------------------------------
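The core building block of the ETSformer layers above is ExponentialSmoothing, which smooths a sequence along the time axis using an FFT-based convolution (conv1d_fft) with per-head learnable smoothing factors. A minimal shape check, assuming the repository root is on the Python path:

    import torch
    from models.subject_layers.ETSformer_EncDec import ExponentialSmoothing

    es = ExponentialSmoothing(dim=8, nhead=4, dropout=0.0)
    values = torch.randn(2, 96, 4, 8)   # (batch, time, heads, head_dim)
    out = es(values)
    print(out.shape)                    # torch.Size([2, 96, 4, 8])
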
/models/subject_layers/SelfAttention_Family.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from math import sqrt
5 | from models.utils.masking import TriangularCausalMask, ProbMask
6 | from reformer_pytorch import LSHSelfAttention
7 | from einops import rearrange, repeat
8 |
9 |
10 | class DSAttention(nn.Module):
11 | '''De-stationary Attention'''
12 |
13 | def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
14 | super(DSAttention, self).__init__()
15 | self.scale = scale
16 | self.mask_flag = mask_flag
17 | self.output_attention = output_attention
18 | self.dropout = nn.Dropout(attention_dropout)
19 |
20 | def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
21 | B, L, H, E = queries.shape
22 | _, S, _, D = values.shape
23 | scale = self.scale or 1. / sqrt(E)
24 |
25 | tau = 1.0 if tau is None else tau.unsqueeze(
26 | 1).unsqueeze(1) # B x 1 x 1 x 1
27 | delta = 0.0 if delta is None else delta.unsqueeze(
28 | 1).unsqueeze(1) # B x 1 x 1 x S
29 |
30 | # De-stationary Attention, rescaling pre-softmax score with learned de-stationary factors
31 | scores = torch.einsum("blhe,bshe->bhls", queries, keys) * tau + delta
32 |
33 | if self.mask_flag:
34 | if attn_mask is None:
35 | attn_mask = TriangularCausalMask(B, L, device=queries.device)
36 |
37 | scores.masked_fill_(attn_mask.mask, -np.inf)
38 |
39 | A = self.dropout(torch.softmax(scale * scores, dim=-1))
40 | V = torch.einsum("bhls,bshd->blhd", A, values)
41 |
42 | if self.output_attention:
43 | return V.contiguous(), A
44 | else:
45 | return V.contiguous(), None
46 |
47 |
48 | class FullAttention(nn.Module):
49 | def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
50 | super(FullAttention, self).__init__()
51 | self.scale = scale
52 | self.mask_flag = mask_flag
53 | self.output_attention = output_attention
54 | self.dropout = nn.Dropout(attention_dropout)
55 |
56 | def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
57 | B, L, H, E = queries.shape
58 | _, S, _, D = values.shape
59 | scale = self.scale or 1. / sqrt(E)
60 |
61 | scores = torch.einsum("blhe,bshe->bhls", queries, keys)
62 |
63 | if self.mask_flag:
64 | if attn_mask is None:
65 | attn_mask = TriangularCausalMask(B, L, device=queries.device)
66 |
67 | scores.masked_fill_(attn_mask.mask, -np.inf)
68 |
69 | A = self.dropout(torch.softmax(scale * scores, dim=-1))
70 | V = torch.einsum("bhls,bshd->blhd", A, values)
71 |
72 | if self.output_attention:
73 | return V.contiguous(), A
74 | else:
75 | return V.contiguous(), None
76 |
77 |
78 | class ProbAttention(nn.Module):
79 | def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
80 | super(ProbAttention, self).__init__()
81 | self.factor = factor
82 | self.scale = scale
83 | self.mask_flag = mask_flag
84 | self.output_attention = output_attention
85 | self.dropout = nn.Dropout(attention_dropout)
86 |
87 | def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q)
88 | # Q [B, H, L, D]
89 | B, H, L_K, E = K.shape
90 | _, _, L_Q, _ = Q.shape
91 |
92 | # calculate the sampled Q_K
93 | K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
94 | # real U = U_part(factor*ln(L_k))*L_q
95 | index_sample = torch.randint(L_K, (L_Q, sample_k))
96 | K_sample = K_expand[:, :, torch.arange(
97 | L_Q).unsqueeze(1), index_sample, :]
98 | Q_K_sample = torch.matmul(
99 | Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()
100 |
101 |         # find the Top_k query with sparsity measurement
102 | M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
103 | M_top = M.topk(n_top, sorted=False)[1]
104 |
105 | # use the reduced Q to calculate Q_K
106 | Q_reduce = Q[torch.arange(B)[:, None, None],
107 | torch.arange(H)[None, :, None],
108 | M_top, :] # factor*ln(L_q)
109 | Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k
110 |
111 | return Q_K, M_top
112 |
113 | def _get_initial_context(self, V, L_Q):
114 | B, H, L_V, D = V.shape
115 | if not self.mask_flag:
116 | # V_sum = V.sum(dim=-2)
117 | V_sum = V.mean(dim=-2)
118 | contex = V_sum.unsqueeze(-2).expand(B, H,
119 | L_Q, V_sum.shape[-1]).clone()
120 | else: # use mask
121 | # requires that L_Q == L_V, i.e. for self-attention only
122 | assert (L_Q == L_V)
123 | contex = V.cumsum(dim=-2)
124 | return contex
125 |
126 | def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
127 | B, H, L_V, D = V.shape
128 |
129 | if self.mask_flag:
130 | attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
131 | scores.masked_fill_(attn_mask.mask, -np.inf)
132 |
133 | attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores)
134 |
135 | context_in[torch.arange(B)[:, None, None],
136 | torch.arange(H)[None, :, None],
137 | index, :] = torch.matmul(attn, V).type_as(context_in)
138 | if self.output_attention:
139 | attns = (torch.ones([B, H, L_V, L_V]) /
140 | L_V).type_as(attn).to(attn.device)
141 | attns[torch.arange(B)[:, None, None], torch.arange(H)[
142 | None, :, None], index, :] = attn
143 | return context_in, attns
144 | else:
145 | return context_in, None
146 |
147 | def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
148 | B, L_Q, H, D = queries.shape
149 | _, L_K, _, _ = keys.shape
150 |
151 | queries = queries.transpose(2, 1)
152 | keys = keys.transpose(2, 1)
153 | values = values.transpose(2, 1)
154 |
155 | U_part = self.factor * \
156 | np.ceil(np.log(L_K)).astype('int').item() # c*ln(L_k)
157 | u = self.factor * \
158 | np.ceil(np.log(L_Q)).astype('int').item() # c*ln(L_q)
159 |
160 | U_part = U_part if U_part < L_K else L_K
161 | u = u if u < L_Q else L_Q
162 |
163 | scores_top, index = self._prob_QK(
164 | queries, keys, sample_k=U_part, n_top=u)
165 |
166 | # add scale factor
167 | scale = self.scale or 1. / sqrt(D)
168 | if scale is not None:
169 | scores_top = scores_top * scale
170 | # get the context
171 | context = self._get_initial_context(values, L_Q)
172 | # update the context with selected top_k queries
173 | context, attn = self._update_context(
174 | context, values, scores_top, index, L_Q, attn_mask)
175 |
176 | return context.contiguous(), attn
177 |
178 |
179 | class AttentionLayer(nn.Module):
180 | def __init__(self, attention, d_model, n_heads, d_keys=None,
181 | d_values=None):
182 | super(AttentionLayer, self).__init__()
183 |
184 | d_keys = d_keys or (d_model // n_heads)
185 | d_values = d_values or (d_model // n_heads)
186 |
187 | self.inner_attention = attention
188 | self.query_projection = nn.Linear(d_model, d_keys * n_heads)
189 | self.key_projection = nn.Linear(d_model, d_keys * n_heads)
190 | self.value_projection = nn.Linear(d_model, d_values * n_heads)
191 | self.out_projection = nn.Linear(d_values * n_heads, d_model)
192 | self.n_heads = n_heads
193 |
194 | def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
195 | B, L, _ = queries.shape
196 | _, S, _ = keys.shape
197 | H = self.n_heads
198 |
199 | queries = self.query_projection(queries).view(B, L, H, -1)
200 | keys = self.key_projection(keys).view(B, S, H, -1)
201 | values = self.value_projection(values).view(B, S, H, -1)
202 |
203 | out, attn = self.inner_attention(
204 | queries,
205 | keys,
206 | values,
207 | attn_mask,
208 | tau=tau,
209 | delta=delta
210 | )
211 | out = out.view(B, L, -1)
212 |
213 | return self.out_projection(out), attn
214 |
215 |
216 | class ReformerLayer(nn.Module):
217 | def __init__(self, attention, d_model, n_heads, d_keys=None,
218 | d_values=None, causal=False, bucket_size=4, n_hashes=4):
219 | super().__init__()
220 | self.bucket_size = bucket_size
221 | self.attn = LSHSelfAttention(
222 | dim=d_model,
223 | heads=n_heads,
224 | bucket_size=bucket_size,
225 | n_hashes=n_hashes,
226 | causal=causal
227 | )
228 |
229 | def fit_length(self, queries):
230 | # inside reformer: assert N % (bucket_size * 2) == 0
231 | B, N, C = queries.shape
232 | if N % (self.bucket_size * 2) == 0:
233 | return queries
234 | else:
235 | # fill the time series
236 | fill_len = (self.bucket_size * 2) - (N % (self.bucket_size * 2))
237 | return torch.cat([queries, torch.zeros([B, fill_len, C]).to(queries.device)], dim=1)
238 |
239 | def forward(self, queries, keys, values, attn_mask, tau, delta):
240 |         # in Reformer: default queries=keys
241 | B, N, C = queries.shape
242 | queries = self.attn(self.fit_length(queries))[:, :N, :]
243 | return queries, None
244 |
245 |
246 | class TwoStageAttentionLayer(nn.Module):
247 | '''
248 | The Two Stage Attention (TSA) Layer
249 | input/output shape: [batch_size, Data_dim(D), Seg_num(L), d_model]
250 | '''
251 |
252 | def __init__(self, configs,
253 | seg_num, factor, d_model, n_heads, d_ff=None, dropout=0.1):
254 | super(TwoStageAttentionLayer, self).__init__()
255 | d_ff = d_ff or 4 * d_model
256 | self.time_attention = AttentionLayer(FullAttention(False, configs.factor, attention_dropout=configs.dropout,
257 | output_attention=configs.output_attention), d_model, n_heads)
258 | self.dim_sender = AttentionLayer(FullAttention(False, configs.factor, attention_dropout=configs.dropout,
259 | output_attention=configs.output_attention), d_model, n_heads)
260 | self.dim_receiver = AttentionLayer(FullAttention(False, configs.factor, attention_dropout=configs.dropout,
261 | output_attention=configs.output_attention), d_model, n_heads)
262 | self.router = nn.Parameter(torch.randn(seg_num, factor, d_model))
263 |
264 | self.dropout = nn.Dropout(dropout)
265 |
266 | self.norm1 = nn.LayerNorm(d_model)
267 | self.norm2 = nn.LayerNorm(d_model)
268 | self.norm3 = nn.LayerNorm(d_model)
269 | self.norm4 = nn.LayerNorm(d_model)
270 |
271 | self.MLP1 = nn.Sequential(nn.Linear(d_model, d_ff),
272 | nn.GELU(),
273 | nn.Linear(d_ff, d_model))
274 | self.MLP2 = nn.Sequential(nn.Linear(d_model, d_ff),
275 | nn.GELU(),
276 | nn.Linear(d_ff, d_model))
277 |
278 | def forward(self, x, attn_mask=None, tau=None, delta=None):
279 | # Cross Time Stage: Directly apply MSA to each dimension
280 | batch = x.shape[0]
281 | time_in = rearrange(x, 'b ts_d seg_num d_model -> (b ts_d) seg_num d_model')
282 | time_enc, attn = self.time_attention(
283 | time_in, time_in, time_in, attn_mask=None, tau=None, delta=None
284 | )
285 | dim_in = time_in + self.dropout(time_enc)
286 | dim_in = self.norm1(dim_in)
287 | dim_in = dim_in + self.dropout(self.MLP1(dim_in))
288 | dim_in = self.norm2(dim_in)
289 |
290 | # Cross Dimension Stage: use a small set of learnable vectors to aggregate and distribute messages to build the D-to-D connection
291 | dim_send = rearrange(dim_in, '(b ts_d) seg_num d_model -> (b seg_num) ts_d d_model', b=batch)
292 | batch_router = repeat(self.router, 'seg_num factor d_model -> (repeat seg_num) factor d_model', repeat=batch)
293 | dim_buffer, attn = self.dim_sender(batch_router, dim_send, dim_send, attn_mask=None, tau=None, delta=None)
294 | dim_receive, attn = self.dim_receiver(dim_send, dim_buffer, dim_buffer, attn_mask=None, tau=None, delta=None)
295 | dim_enc = dim_send + self.dropout(dim_receive)
296 | dim_enc = self.norm3(dim_enc)
297 | dim_enc = dim_enc + self.dropout(self.MLP2(dim_enc))
298 | dim_enc = self.norm4(dim_enc)
299 |
300 | final_out = rearrange(dim_enc, '(b seg_num) ts_d d_model -> b ts_d seg_num d_model', b=batch)
301 |
302 | return final_out
303 |
--------------------------------------------------------------------------------
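All attention variants in this file share the AttentionLayer wrapper, which projects (batch, length, d_model) inputs into multiple heads, runs the inner attention, and projects back. A minimal sketch with FullAttention and no causal mask, assuming the repository root is on the Python path (note the module also imports reformer_pytorch at import time):

    import torch
    from models.subject_layers.SelfAttention_Family import AttentionLayer, FullAttention

    layer = AttentionLayer(FullAttention(mask_flag=False), d_model=64, n_heads=4)
    x = torch.randn(2, 10, 64)                   # (batch, seq_len, d_model)
    out, _ = layer(x, x, x, attn_mask=None)      # self-attention
    print(out.shape)                             # torch.Size([2, 10, 64])
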
/Generation/diffusion_prior.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch import nn
4 | import torch.nn.functional as F
5 | import torch.optim as optim
6 | from tqdm import tqdm
7 |
8 | from diffusers.models.embeddings import Timesteps, TimestepEmbedding
9 | from torch.utils.data import Dataset
10 |
11 |
12 | class DiffusionPrior(nn.Module):
13 |
14 | def __init__(
15 | self,
16 | embed_dim=1024,
17 | cond_dim=42,
18 | hidden_dim=1024,
19 | layers_per_block=4,
20 | time_embed_dim=512,
21 | act_fn=nn.SiLU,
22 | dropout=0.0,
23 | ):
24 | super().__init__()
25 |
26 | self.embed_dim = embed_dim
27 |
28 | # 1. time embedding
29 | self.time_proj = Timesteps(time_embed_dim, True, 0)
30 | self.time_embedding = TimestepEmbedding(
31 | time_embed_dim,
32 | hidden_dim,
33 | )
34 |
35 | # 2. conditional embedding
36 | self.cond_embedding = nn.Linear(cond_dim, hidden_dim)
37 |
38 | # 3. prior mlp
39 |
40 | # 3.1 input
41 | self.input_layer = nn.Sequential(
42 | nn.Linear(embed_dim, hidden_dim),
43 | nn.LayerNorm(hidden_dim),
44 | act_fn(),
45 | )
46 |
47 | # 3.2 hidden
48 | self.hidden_layers = nn.ModuleList(
49 | [
50 | nn.Sequential(
51 | nn.Linear(hidden_dim, hidden_dim),
52 | nn.LayerNorm(hidden_dim),
53 | act_fn(),
54 | nn.Dropout(dropout),
55 | )
56 | for _ in range(layers_per_block)
57 | ]
58 | )
59 |
60 | # 3.3 output
61 | self.output_layer = nn.Linear(hidden_dim, embed_dim)
62 |
63 |
64 | def forward(self, x, t, c=None):
65 | # x (batch_size, embed_dim)
66 | # t (batch_size, )
67 | # c (batch_size, cond_dim)
68 |
69 | # 1. time embedding
70 | t = self.time_proj(t) # (batch_size, time_embed_dim)
71 | t = self.time_embedding(t) # (batch_size, hidden_dim)
72 |
73 | # 2. conditional embedding
74 | c = self.cond_embedding(c) if c is not None else 0 # (batch_size, hidden_dim)
75 |
76 | # 3. prior mlp
77 |
78 | # 3.1 input
79 | x = self.input_layer(x)
80 |
81 | # 3.2 hidden
82 | for layer in self.hidden_layers:
83 | x = x + t + c
84 | x = layer(x) + x
85 |
86 | # 3.3 output
87 | x = self.output_layer(x)
88 |
89 | return x
90 |
91 |
92 | class DiffusionPriorUNet(nn.Module):
93 |
94 | def __init__(
95 | self,
96 | embed_dim=1024,
97 | cond_dim=42,
98 | hidden_dim=[1024, 512, 256, 128, 64],
99 | time_embed_dim=512,
100 | act_fn=nn.SiLU,
101 | dropout=0.0,
102 | ):
103 | super().__init__()
104 |
105 | self.embed_dim = embed_dim
106 | self.cond_dim = cond_dim
107 | self.hidden_dim = hidden_dim
108 |
109 | # 1. time embedding
110 | self.time_proj = Timesteps(time_embed_dim, True, 0)
111 |
112 | # 2. conditional embedding
113 |         # to 3.2, 3.3
114 |
115 | # 3. prior mlp
116 |
117 | # 3.1 input
118 | self.input_layer = nn.Sequential(
119 | nn.Linear(embed_dim, hidden_dim[0]),
120 | nn.LayerNorm(hidden_dim[0]),
121 | act_fn(),
122 | )
123 |
124 | # 3.2 hidden encoder
125 | self.num_layers = len(hidden_dim)
126 | self.encode_time_embedding = nn.ModuleList(
127 | [TimestepEmbedding(
128 | time_embed_dim,
129 | hidden_dim[i],
130 | ) for i in range(self.num_layers-1)]
131 | ) # d_0, ..., d_{n-1}
132 | self.encode_cond_embedding = nn.ModuleList(
133 | [nn.Linear(cond_dim, hidden_dim[i]) for i in range(self.num_layers-1)]
134 | )
135 | self.encode_layers = nn.ModuleList(
136 | [nn.Sequential(
137 | nn.Linear(hidden_dim[i], hidden_dim[i+1]),
138 | nn.LayerNorm(hidden_dim[i+1]),
139 | act_fn(),
140 | nn.Dropout(dropout),
141 | ) for i in range(self.num_layers-1)]
142 | )
143 |
144 | # 3.3 hidden decoder
145 | self.decode_time_embedding = nn.ModuleList(
146 | [TimestepEmbedding(
147 | time_embed_dim,
148 | hidden_dim[i],
149 | ) for i in range(self.num_layers-1,0,-1)]
150 | ) # d_{n}, ..., d_1
151 | self.decode_cond_embedding = nn.ModuleList(
152 | [nn.Linear(cond_dim, hidden_dim[i]) for i in range(self.num_layers-1,0,-1)]
153 | )
154 | self.decode_layers = nn.ModuleList(
155 | [nn.Sequential(
156 | nn.Linear(hidden_dim[i], hidden_dim[i-1]),
157 | nn.LayerNorm(hidden_dim[i-1]),
158 | act_fn(),
159 | nn.Dropout(dropout),
160 | ) for i in range(self.num_layers-1,0,-1)]
161 | )
162 |
163 | # 3.4 output
164 | self.output_layer = nn.Linear(hidden_dim[0], embed_dim)
165 |
166 |
167 | def forward(self, x, t, c=None):
168 | # x (batch_size, embed_dim)
169 | # t (batch_size, )
170 | # c (batch_size, cond_dim)
171 |
172 | # 1. time embedding
173 | t = self.time_proj(t) # (batch_size, time_embed_dim)
174 |
175 | # 2. conditional embedding
176 | # to 3.2, 3.3
177 |
178 | # 3. prior mlp
179 |
180 | # 3.1 input
181 | x = self.input_layer(x)
182 |
183 | # 3.2 hidden encoder
184 | hidden_activations = []
185 | for i in range(self.num_layers-1):
186 | hidden_activations.append(x)
187 | t_emb = self.encode_time_embedding[i](t)
188 | c_emb = self.encode_cond_embedding[i](c) if c is not None else 0
189 | x = x + t_emb + c_emb
190 | x = self.encode_layers[i](x)
191 |
192 | # 3.3 hidden decoder
193 | for i in range(self.num_layers-1):
194 | t_emb = self.decode_time_embedding[i](t)
195 | c_emb = self.decode_cond_embedding[i](c) if c is not None else 0
196 | x = x + t_emb + c_emb
197 | x = self.decode_layers[i](x)
198 | x += hidden_activations[-1-i]
199 |
200 | # 3.4 output
201 | x = self.output_layer(x)
202 |
203 | return x
204 |
205 |
206 | class EmbeddingDataset(Dataset):
207 |
208 | def __init__(self, c_embeddings, h_embeddings):
209 | self.c_embeddings = c_embeddings
210 | self.h_embeddings = h_embeddings
211 |
212 | def __len__(self):
213 | return len(self.c_embeddings)
214 |
215 | def __getitem__(self, idx):
216 | return {
217 | "c_embedding": self.c_embeddings[idx],
218 | "h_embedding": self.h_embeddings[idx]
219 | }
220 |
221 | class EmbeddingDatasetVICE(Dataset):
222 | def __init__(self, path_data):
223 | image_features_dict = torch.load(os.path.join(path_data, 'openclip_emb/image_features.pt'))
224 | self.embedding_vise = torch.load(os.path.join(path_data, 'variables/embedding_vise.pt'))
225 | self.image_features = image_features_dict['image_features']
226 | self.labels = image_features_dict['labels']
227 | self.label2index = image_features_dict['l2i']
228 |
229 | def __len__(self):
230 | return len(self.image_features)
231 |
232 | def __getitem__(self, idx):
233 | idx_c = self.label2index[self.labels[idx]]
234 | return {
235 | "c_embedding": self.embedding_vise[idx_c],
236 | "h_embedding": self.image_features[idx]
237 | }
238 |
239 |
240 | # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise
241 | def add_noise_with_sigma(
242 | self,
243 | original_samples: torch.FloatTensor,
244 | noise: torch.FloatTensor,
245 | timesteps: torch.FloatTensor,
246 | ) -> torch.FloatTensor:
247 | # Make sure sigmas and timesteps have the same device and dtype as original_samples
248 | sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
249 | if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
250 | # mps does not support float64
251 | schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
252 | timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
253 | else:
254 | schedule_timesteps = self.timesteps.to(original_samples.device)
255 | timesteps = timesteps.to(original_samples.device)
256 |
257 | step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
258 |
259 | sigma = sigmas[step_indices].flatten()
260 | while len(sigma.shape) < len(original_samples.shape):
261 | sigma = sigma.unsqueeze(-1)
262 |
263 | noisy_samples = original_samples + noise * sigma
264 | return noisy_samples, sigma
265 |
266 |
267 | # diffusion pipe
268 | class Pipe:
269 |
270 | def __init__(self, diffusion_prior=None, scheduler=None, device='cuda'):
271 | self.diffusion_prior = diffusion_prior.to(device)
272 |
273 | if scheduler is None:
274 | from diffusers.schedulers import DDPMScheduler
275 | self.scheduler = DDPMScheduler()
276 | # self.scheduler.add_noise_with_sigma = add_noise_with_sigma.__get__(self.scheduler)
277 | else:
278 | self.scheduler = scheduler
279 |
280 | self.device = device
281 |
282 | def train(self, dataloader, num_epochs=10, learning_rate=1e-4):
283 | self.diffusion_prior.train()
284 | device = self.device
285 | criterion = nn.MSELoss(reduction='none')
286 | optimizer = optim.Adam(self.diffusion_prior.parameters(), lr=learning_rate)
287 | from diffusers.optimization import get_cosine_schedule_with_warmup
288 | lr_scheduler = get_cosine_schedule_with_warmup(
289 | optimizer=optimizer,
290 | num_warmup_steps=500,
291 | num_training_steps=(len(dataloader) * num_epochs),
292 | )
293 |
294 | num_train_timesteps = self.scheduler.config.num_train_timesteps
295 |
296 | for epoch in range(num_epochs):
297 | loss_sum = 0
298 | for batch in dataloader:
299 | c_embeds = batch['c_embedding'].to(device) if 'c_embedding' in batch.keys() else None
300 | h_embeds = batch['h_embedding'].to(device)
301 | N = h_embeds.shape[0]
302 |
303 |                 # 1. randomly replace c_embeds with None (classifier-free guidance dropout)
304 | if torch.rand(1) < 0.1:
305 | c_embeds = None
306 |
307 | # 2. Generate noisy embeddings as input
308 | noise = torch.randn_like(h_embeds)
309 |
310 | # 3. sample timestep
311 | timesteps = torch.randint(0, num_train_timesteps, (N,), device=device)
312 |
313 | # 4. add noise to h_embedding
314 | perturbed_h_embeds = self.scheduler.add_noise(
315 | h_embeds,
316 | noise,
317 | timesteps
318 | ) # (batch_size, embed_dim), (batch_size, )
319 |
320 | # 5. predict noise
321 | noise_pre = self.diffusion_prior(perturbed_h_embeds, timesteps, c_embeds)
322 |
323 |                 # 6. plain MSE between predicted and true noise (no sigma weighting applied here)
324 |                 loss = criterion(noise_pre, noise) # (batch_size, embed_dim) with reduction='none'
325 | loss = (loss).mean()
326 |
327 | # 7. update parameters
328 | optimizer.zero_grad()
329 | loss.backward()
330 | torch.nn.utils.clip_grad_norm_(self.diffusion_prior.parameters(), 1.0)
331 |                 optimizer.step()
332 |                 lr_scheduler.step()
333 |
334 | loss_sum += loss.item()
335 |
336 | loss_epoch = loss_sum / len(dataloader)
337 | print(f'epoch: {epoch}, loss: {loss_epoch}')
338 | # lr_scheduler.step(loss)
339 |
340 | def generate(
341 | self,
342 | c_embeds=None,
343 | num_inference_steps=50,
344 | timesteps=None,
345 | guidance_scale=5.0,
346 | generator=None
347 | ):
348 | # c_embeds (batch_size, cond_dim)
349 | self.diffusion_prior.eval()
350 | N = c_embeds.shape[0] if c_embeds is not None else 1
351 |
352 | # 1. Prepare timesteps
353 | from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import retrieve_timesteps
354 | timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, self.device, timesteps)
355 |
356 | # 2. Prepare c_embeds
357 | if c_embeds is not None:
358 | c_embeds = c_embeds.to(self.device)
359 |
360 | # 3. Prepare noise
361 | h_t = torch.randn(N, self.diffusion_prior.embed_dim, generator=generator, device=self.device)
362 |
363 | # 4. denoising loop
364 | for _, t in tqdm(enumerate(timesteps)):
365 | t = torch.ones(h_t.shape[0], dtype=torch.float, device=self.device) * t
366 | # 4.1 noise prediction
367 | if guidance_scale == 0 or c_embeds is None:
368 | noise_pred = self.diffusion_prior(h_t, t)
369 | else:
370 | noise_pred_cond = self.diffusion_prior(h_t, t, c_embeds)
371 | noise_pred_uncond = self.diffusion_prior(h_t, t)
372 | # perform classifier-free guidance
373 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
374 |
375 | # 4.2 compute the previous noisy sample h_t -> h_{t-1}
376 | h_t = self.scheduler.step(noise_pred, t.long().item(), h_t, generator=generator).prev_sample
377 |
378 | return h_t
379 |
380 | if __name__ == '__main__':
381 | import os
382 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
383 | # 1. test prior
384 | prior = DiffusionPriorUNet(cond_dim=1024)
385 | x = torch.randn(2, 1024)
386 | t = torch.randint(0, 1000, (2,))
387 | c = torch.randn(2, 1024)
388 | y = prior(x, t, c)
389 | print(y.shape)
390 |
391 |
392 |
393 |
--------------------------------------------------------------------------------
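Putting the pieces of this file together: EmbeddingDataset pairs condition embeddings (e.g. EEG features) with target embeddings (e.g. CLIP image features), and Pipe trains DiffusionPriorUNet to denoise the targets conditioned on them, then samples new targets with classifier-free guidance. A minimal end-to-end sketch with random stand-in embeddings (sizes are illustrative; generate() additionally requires a diffusers version that provides retrieve_timesteps):

    import torch
    from torch.utils.data import DataLoader
    from Generation.diffusion_prior import DiffusionPriorUNet, EmbeddingDataset, Pipe

    eeg_embeds = torch.randn(256, 1024)   # stand-in for EEG/condition embeddings
    img_embeds = torch.randn(256, 1024)   # stand-in for CLIP image embeddings

    loader = DataLoader(EmbeddingDataset(eeg_embeds, img_embeds), batch_size=64, shuffle=True)
    pipe = Pipe(DiffusionPriorUNet(cond_dim=1024), device='cpu')  # 'cpu' only for this sketch
    pipe.train(loader, num_epochs=1, learning_rate=1e-4)

    samples = pipe.generate(c_embeds=eeg_embeds[:2], num_inference_steps=10, guidance_scale=5.0)
    print(samples.shape)                  # torch.Size([2, 1024])
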
/Retrieval/eegdatasets_joint_subjects.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 | import numpy as np
4 | import os
5 | import clip
6 | from torch.nn import functional as F
7 | import torch.nn as nn
8 | from torchvision import transforms
9 | from PIL import Image
10 | import requests
11 |
12 | # import os
13 | # proxy = 'http://10.16.35.10:13390'
14 | # os.environ['http_proxy'] = proxy
15 | # os.environ['https_proxy'] = proxy
16 |
17 | device = "cuda:0" if torch.cuda.is_available() else "cpu"
18 | # vlmodel, preprocess = clip.load("ViT-B/32", device=device)
19 | model_type = 'ViT-H-14'
20 | import open_clip
21 | vlmodel, preprocess_train, feature_extractor = open_clip.create_model_and_transforms(
22 | model_type, pretrained='laion2b_s32b_b79k', precision='fp32', device=device)
23 |
24 | import json
25 |
26 | # Load the configuration from the JSON file
27 | config_path = "data_config.json"
28 | with open(config_path, "r") as config_file:
29 | config = json.load(config_file)
30 |
31 | # Access the paths from the config
32 | data_path = config["data_path"]
33 | img_directory_training = config["img_directory_training"]
34 | img_directory_test = config["img_directory_test"]
35 |
36 |
37 | class EEGDataset():
38 | """
39 | subjects = ['sub-01', 'sub-02', 'sub-05', 'sub-04', 'sub-03', 'sub-06', 'sub-07', 'sub-08', 'sub-09', 'sub-10']
40 | """
41 | def __init__(self, data_path, adap_subject=None, subjects=None, train=True, time_window=[0, 1.0], classes=None, pictures=None):
42 | self.data_path = data_path
43 | self.train = train
44 | self.subject_list = os.listdir(data_path)
45 | self.subjects = self.subject_list if subjects is None else subjects
46 | self.n_sub = len(self.subjects)
47 | self.time_window = time_window
48 | self.n_cls = 1654 if train else 200
49 | self.classes = classes
50 | self.pictures = pictures
51 | self.adap_subject = adap_subject # Save this parameter
52 |
53 |         # Assert that at least one requested subject exists in the data directory
54 | assert any(sub in self.subject_list for sub in self.subjects)
55 |
56 | self.data, self.labels, self.text, self.img = self.load_data()
57 |
58 | self.data = self.extract_eeg(self.data, time_window)
59 |
60 | if self.classes is None and self.pictures is None:
61 | # Try to load the saved features if they exist
62 | features_filename = os.path.join(f'{model_type}_features_train.pt') if self.train else os.path.join(f'{model_type}_features_test.pt')
63 |
64 | if os.path.exists(features_filename):
65 | saved_features = torch.load(features_filename)
66 | self.text_features = saved_features['text_features']
67 | self.img_features = saved_features['img_features']
68 | else:
69 | self.text_features = self.Textencoder(self.text)
70 | self.img_features = self.ImageEncoder(self.img)
71 | torch.save({
72 | 'text_features': self.text_features.cpu(),
73 | 'img_features': self.img_features.cpu(),
74 | }, features_filename)
75 | else:
76 | self.text_features = self.Textencoder(self.text)
77 | self.img_features = self.ImageEncoder(self.img)
78 |
79 | def load_data(self):
80 | data_list = []
81 | label_list = []
82 | texts = []
83 | images = []
84 |
85 | if self.train:
86 | directory = img_directory_training
87 | else:
88 | directory = img_directory_test
89 | # Get all directories in the path
90 | dirnames = [d for d in os.listdir(directory) if os.path.isdir(os.path.join(directory, d))]
91 | dirnames.sort()
92 |
93 | if self.classes is not None:
94 | dirnames = [dirnames[i] for i in self.classes]
95 |
96 | for dir in dirnames:
97 | # Try to find the first occurrence of '_'
98 | try:
99 | idx = dir.index('_')
100 | description = dir[idx+1:] # Get all content after the first '_'
101 | except ValueError:
102 | print(f"Skipped: {dir} due to no '_' found.")
103 | continue
104 |
105 | new_description = f"This picture is {description}"
106 | texts.append(new_description)
107 |
108 | if self.train:
109 | img_directory = img_directory_training # Replace with your new address
110 | else:
111 | img_directory = img_directory_test
112 |
113 | all_folders = [d for d in os.listdir(img_directory) if os.path.isdir(os.path.join(img_directory, d))]
114 | all_folders.sort() # Ensure the order of folders
115 |
116 | if self.classes is not None and self.pictures is not None:
117 | images = [] # Initialize images list
118 | for i in range(len(self.classes)):
119 | class_idx = self.classes[i]
120 | pic_idx = self.pictures[i]
121 | if class_idx < len(all_folders):
122 | folder = all_folders[class_idx]
123 | folder_path = os.path.join(img_directory, folder)
124 | all_images = [img for img in os.listdir(folder_path) if img.lower().endswith(('.png', '.jpg', '.jpeg'))]
125 | all_images.sort()
126 | if pic_idx < len(all_images):
127 | images.append(os.path.join(folder_path, all_images[pic_idx]))
128 | elif self.classes is not None and self.pictures is None:
129 | images = [] # Initialize images list
130 | for i in range(len(self.classes)):
131 | class_idx = self.classes[i]
132 | if class_idx < len(all_folders):
133 | folder = all_folders[class_idx]
134 | folder_path = os.path.join(img_directory, folder)
135 | all_images = [img for img in os.listdir(folder_path) if img.lower().endswith(('.png', '.jpg', '.jpeg'))]
136 | all_images.sort()
137 | images.extend(os.path.join(folder_path, img) for img in all_images)
138 | elif self.classes is None:
139 | images = [] # Initialize images list
140 | for folder in all_folders:
141 | folder_path = os.path.join(img_directory, folder)
142 | all_images = [img for img in os.listdir(folder_path) if img.lower().endswith(('.png', '.jpg', '.jpeg'))]
143 | all_images.sort()
144 | images.extend(os.path.join(folder_path, img) for img in all_images)
145 | else:
146 | # Handle other cases, such as mismatched lengths of self.classes and self.pictures
147 | print("Error")
148 |
149 | print("self.subjects", self.subjects)
150 | print("adap_subject", self.adap_subject)
151 | for subject in self.subjects:
152 | if self.train:
153 | # if subject == self.adap_subject: # Skip the excluded subject
154 | # continue
155 | # print("subject:", subject)
156 | file_name = 'preprocessed_eeg_training.npy'
157 |
158 | file_path = os.path.join(self.data_path, subject, file_name)
159 | data = np.load(file_path, allow_pickle=True)
160 |
161 | preprocessed_eeg_data = torch.from_numpy(data['preprocessed_eeg_data']).float().detach()
162 | times = torch.from_numpy(data['times']).detach()[50:]
163 | ch_names = data['ch_names'] # Keep as a Python list or encode appropriately
164 |
165 | n_classes = 1654 # Each class contains 10 images
166 | samples_per_class = 10 # Each class has ten samples
167 |
168 | if self.classes is not None and self.pictures is not None:
169 | for c, p in zip(self.classes, self.pictures):
170 | start_index = c * 1 + p
171 | if start_index < len(preprocessed_eeg_data): # Ensure index is within range
172 | preprocessed_eeg_data_class = preprocessed_eeg_data[start_index: start_index+1] # Select only one sample
173 | labels = torch.full((1,), c, dtype=torch.long).detach() # Add class label
174 | data_list.append(preprocessed_eeg_data_class)
175 | label_list.append(labels) # Add labels to the label list
176 |
177 | elif self.classes is not None and self.pictures is None:
178 | for c in self.classes:
179 | start_index = c * samples_per_class
180 | preprocessed_eeg_data_class = preprocessed_eeg_data[start_index: start_index+samples_per_class]
181 | labels = torch.full((samples_per_class,), c, dtype=torch.long).detach() # Add class label
182 | data_list.append(preprocessed_eeg_data_class)
183 | label_list.append(labels)
184 |
185 | else:
186 | for i in range(n_classes):
187 | start_index = i * samples_per_class
188 | preprocessed_eeg_data_class = preprocessed_eeg_data[start_index: start_index+samples_per_class]
189 | labels = torch.full((samples_per_class,), i, dtype=torch.long).detach() # Add class label
190 | data_list.append(preprocessed_eeg_data_class)
191 | label_list.append(labels)
192 |
193 |
194 | else:
195 |             if subject == self.adap_subject or self.adap_subject is None:
196 | file_name = 'preprocessed_eeg_test.npy'
197 | file_path = os.path.join(self.data_path, subject, file_name)
198 | data = np.load(file_path, allow_pickle=True)
199 | preprocessed_eeg_data = torch.from_numpy(data['preprocessed_eeg_data']).float().detach()
200 | times = torch.from_numpy(data['times']).detach()[50:]
201 | ch_names = data['ch_names'] # Keep as a Python list or encode appropriately
202 | n_classes = 200 # Each class contains 1 image
203 |
204 | samples_per_class = 1 # Each class has one sample
205 |
206 | for i in range(n_classes):
207 | if self.classes is not None and i not in self.classes: # Skip if class not in the specified list
208 | continue
209 | start_index = i * samples_per_class # Update start_index for each class
210 | preprocessed_eeg_data_class = preprocessed_eeg_data[start_index:start_index+samples_per_class]
211 | labels = torch.full((samples_per_class,), i, dtype=torch.long).detach() # Add class labels
212 | preprocessed_eeg_data_class = torch.mean(preprocessed_eeg_data_class.squeeze(0), 0)
213 | data_list.append(preprocessed_eeg_data_class)
214 | label_list.append(labels) # Add labels to the label list
215 | else:
216 | continue
217 | # Data list: (subjects * classes) * (10 * 4 * 17 * 100)
218 | # Data tensor: (subjects * classes * 10 * 4) * 17 * 100
219 | if self.train:
220 | data_tensor = torch.cat(data_list, dim=0).view(-1, *data_list[0].shape[2:])
221 | print("data_tensor", data_tensor.shape)
222 | else:
223 | data_tensor = torch.cat(data_list, dim=0).view(-1, *data_list[0].shape)
224 | label_tensor = torch.cat(label_list, dim=0)
225 | if self.train:
226 | # Label tensor: (subjects * classes * 10 * 4)
227 | label_tensor = label_tensor.repeat_interleave(4)
228 | if self.classes is not None:
229 | unique_values = list(label_tensor.numpy())
230 | lis = []
231 | for i in unique_values:
232 | if i not in lis:
233 | lis.append(i)
234 | unique_values = torch.tensor(lis)
235 | mapping = {val.item(): index for index, val in enumerate(unique_values)}
236 | label_tensor = torch.tensor([mapping[val.item()] for val in label_tensor], dtype=torch.long)
237 |
238 | else:
239 | pass
240 |
241 | self.times = times
242 | self.ch_names = ch_names
243 |
244 | print(f"Data tensor shape: {data_tensor.shape}, label tensor shape: {label_tensor.shape}, text length: {len(texts)}, image length: {len(images)}")
245 |
246 | return data_tensor, label_tensor, texts, images
247 |
248 | def extract_eeg(self, eeg_data, time_window):
249 |
250 | start, end = time_window
251 |
252 | # Get the indices of the times within the specified window
253 | indices = (self.times >= start) & (self.times <= end)
254 | # Use these indices to select the corresponding data
255 | extracted_data = eeg_data[..., indices]
256 | # print(f"extracted_data shape: {extracted_data.shape}")
257 |
258 | return extracted_data
259 |
260 | def Textencoder(self, text):
261 | # Use the preprocessor to convert text to the model's input format
262 | text_inputs = torch.cat([clip.tokenize(t) for t in text]).to(device)
263 |
264 | # Use the CLIP model to encode text
265 | with torch.no_grad():
266 | text_features = vlmodel.encode_text(text_inputs)
267 |
268 | text_features = F.normalize(text_features, dim=-1).detach()
269 |
270 | return text_features
271 |
272 | def ImageEncoder(self, images):
273 | batch_size = 20 # Set to an appropriate value
274 | image_features_list = []
275 |
276 | for i in range(0, len(images), batch_size):
277 | batch_images = images[i:i + batch_size]
278 | image_inputs = torch.stack([preprocess_train(Image.open(img).convert("RGB")) for img in batch_images]).to(device)
279 |
280 | with torch.no_grad():
281 | batch_image_features = vlmodel.encode_image(image_inputs)
282 | batch_image_features /= batch_image_features.norm(dim=-1, keepdim=True)
283 |
284 | image_features_list.append(batch_image_features)
285 |
286 | image_features = torch.cat(image_features_list, dim=0)
287 |
288 | return image_features
289 |
290 | def __getitem__(self, index):
291 | # Get the data and label corresponding to "index"
292 | # index: (subjects * classes * 10 * 4)
293 | x = self.data[index]
294 | label = self.labels[index]
295 |
296 | if self.pictures is None:
297 | if self.classes is None:
298 | index_n_sub_train = self.n_cls * 10 * 4
299 | index_n_sub_test = self.n_cls * 1 * 80
300 | else:
301 | index_n_sub_test = len(self.classes)* 1 * 80
302 | index_n_sub_train = len(self.classes)* 10 * 4
303 | # text_index: classes
304 | if self.train:
305 | text_index = (index % index_n_sub_train) // (10 * 4)
306 | else:
307 | text_index = (index % index_n_sub_test)
308 | # img_index: classes * 10
309 | if self.train:
310 | img_index = (index % index_n_sub_train) // (4)
311 | else:
312 | img_index = (index % index_n_sub_test)
313 | else:
314 | if self.classes is None:
315 | index_n_sub_train = self.n_cls * 1 * 4
316 | index_n_sub_test = self.n_cls * 1 * 80
317 | else:
318 | index_n_sub_test = len(self.classes)* 1 * 80
319 | index_n_sub_train = len(self.classes)* 1 * 4
320 | # text_index: classes
321 | if self.train:
322 | text_index = (index % index_n_sub_train) // (1 * 4)
323 | else:
324 | text_index = (index % index_n_sub_test)
325 | # img_index: classes * 10
326 | if self.train:
327 | img_index = (index % index_n_sub_train) // (4)
328 | else:
329 | img_index = (index % index_n_sub_test)
330 |
331 | text = self.text[text_index]
332 | img = self.img[img_index]
333 |
334 | text_features = self.text_features[text_index]
335 | img_features = self.img_features[img_index]
336 |
337 | return x, label, text, text_features, img, img_features
338 |
339 | def __len__(self):
340 | return self.data.shape[0] # or self.labels.shape[0] which should be the same
341 |
342 | if __name__ == "__main__":
343 | # Instantiate the dataset and dataloader
344 | # data_path = "/home/ldy/Workspace/THINGS/EEG/osfstorage-archive" # Replace with the path to your data
345 |     # data_path comes from data_config.json, loaded at module level above
346 | train_dataset = EEGDataset(data_path, subjects=['sub-01'], train=True)
347 | test_dataset = EEGDataset(data_path, subjects=['sub-01'], train=False)
348 | # train_dataset = EEGDataset(data_path, adap_subject='sub-01', train=True)
349 | # test_dataset = EEGDataset(data_path, adap_subject='sub-01', train=False)
350 | # train_dataset = EEGDataset(data_path, train=True)
351 | # test_dataset = EEGDataset(data_path, train=False)
352 | # Training EEG data shape: torch.Size([16540, 4, 17, 100]) [Number of training images, repetition count, channels, EEG time points]
353 | # Testing EEG data shape: torch.Size([200, 80, 17, 100])
354 |     # 1 second 'times': array([-0.2 , -0.19, -0.18, ... , 0.76, 0.77, 0.78, 0.79])
355 | # 17 channels 'ch_names': ['Pz', 'P3', 'P7', 'O1', 'Oz', 'O2', 'P4', 'P8', 'P1', 'P5', 'PO7', 'PO3', 'POz', 'PO4', 'PO8', 'P6', 'P2']
356 | # 100 Hz
357 | train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
358 | test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)
359 |
360 | i = 80*1-1
361 | x, label, text, text_features, img, img_features = test_dataset[i]
362 | print(f"Index {i}, Label: {label}, text: {text}")
363 | Image.open(img)
364 |
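365 | # Index layout sketch for the flattened training set (illustrative only; it mirrors the
366 | # arithmetic in __getitem__ above for the default case of 1654 classes x 10 images
367 | # x 4 repetitions per subject):
368 | #   per_subject = 1654 * 10 * 4
369 | #   within      = index % per_subject
370 | #   text_index  = within // (10 * 4)   # class (caption) of the trial
371 | #   img_index   = within // 4          # class * 10 + exemplar within the class
372 | #   repetition  = within % 4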
--------------------------------------------------------------------------------
/Retrieval/eegdatasets_leaveone.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 | import numpy as np
4 | import os
5 | import clip
6 | from torch.nn import functional as F
7 | import torch.nn as nn
8 | from torchvision import transforms
9 | from PIL import Image
10 | import requests
11 |
12 | import os
13 | # Note: Set http_proxy/https_proxy environment variables if needed
14 | cuda_device_count = torch.cuda.device_count()
15 | print(f"CUDA devices available: {cuda_device_count}")
16 | device = "cuda:0" if torch.cuda.is_available() else "cpu"
17 | # vlmodel, preprocess = clip.load("ViT-B/32", device=device)
18 | model_type = 'ViT-H-14'
19 | import open_clip
20 | vlmodel, preprocess_train, feature_extractor = open_clip.create_model_and_transforms(
21 |     model_type, pretrained='laion2b_s32b_b79k', precision='fp32', device=device)
22 |
23 | import json
24 |
25 | # Load the configuration from the JSON file
26 | config_path = "data_config.json"
27 | with open(config_path, "r") as config_file:
28 | config = json.load(config_file)
29 |
30 | # Access the paths from the config
31 | data_path = config["data_path"]
32 | img_directory_training = config["img_directory_training"]
33 | img_directory_test = config["img_directory_test"]
34 |
35 |
36 | class EEGDataset(Dataset):
37 | """
38 | subjects = ['sub-01', 'sub-02', 'sub-05', 'sub-04', 'sub-03', 'sub-06', 'sub-07', 'sub-08', 'sub-09', 'sub-10']
39 | """
40 |     def __init__(self, data_path, exclude_subject=None, subjects=None, train=True, time_window=[0, 1.0], classes=None, pictures=None, val_size=None):
41 | self.data_path = data_path
42 | self.train = train
43 | self.subject_list = os.listdir(data_path)
44 | self.subjects = self.subject_list if subjects is None else subjects
45 | self.n_sub = len(self.subjects)
46 | self.time_window = time_window
47 | self.n_cls = 1654 if train else 200
48 | self.classes = classes
49 | self.pictures = pictures
50 | self.exclude_subject = exclude_subject
51 | self.val_size = val_size
52 |         # Ensure at least one requested subject exists in the data directory
53 | assert any(sub in self.subject_list for sub in self.subjects)
54 |
55 | self.data, self.labels, self.text, self.img = self.load_data()
56 |
57 | self.data = self.extract_eeg(self.data, time_window)
58 |
59 |
60 | if self.classes is None and self.pictures is None:
61 | # Try to load the saved features if they exist
62 |             features_filename = f'{model_type}_features_train.pt' if self.train else f'{model_type}_features_test.pt'
63 |
64 |             if os.path.exists(features_filename):
65 | saved_features = torch.load(features_filename)
66 | self.text_features = saved_features['text_features']
67 | self.img_features = saved_features['img_features']
68 | else:
69 | self.text_features = self.Textencoder(self.text)
70 | self.img_features = self.ImageEncoder(self.img)
71 | torch.save({
72 | 'text_features': self.text_features.cpu(),
73 | 'img_features': self.img_features.cpu(),
74 | }, features_filename)
75 | else:
76 | self.text_features = self.Textencoder(self.text)
77 | self.img_features = self.ImageEncoder(self.img)
78 |
79 | def load_data(self):
80 | data_list = []
81 | label_list = []
82 | texts = []
83 | images = []
84 |
85 | if self.train:
86 | directory = img_directory_training
87 | else:
88 | directory = img_directory_test
89 |
90 | dirnames = [d for d in os.listdir(directory) if os.path.isdir(os.path.join(directory, d))]
91 | dirnames.sort()
92 |
93 | if self.classes is not None:
94 | dirnames = [dirnames[i] for i in self.classes]
95 |
96 |         for dirname in dirnames:
97 |
98 |             try:
99 |                 idx = dirname.index('_')
100 |                 description = dirname[idx+1:]
101 |             except ValueError:
102 |                 print(f"Skipped {dirname}: no '_' found in folder name.")
103 | continue
104 |
105 | new_description = f"This picture is {description}"
106 | texts.append(new_description)
107 |
108 | if self.train:
109 | img_directory = img_directory_training
110 | else:
111 | img_directory = img_directory_test
112 |
113 | all_folders = [d for d in os.listdir(img_directory) if os.path.isdir(os.path.join(img_directory, d))]
114 | all_folders.sort()
115 |
116 | if self.classes is not None and self.pictures is not None:
117 | images = []
118 | for i in range(len(self.classes)):
119 | class_idx = self.classes[i]
120 | pic_idx = self.pictures[i]
121 | if class_idx < len(all_folders):
122 | folder = all_folders[class_idx]
123 | folder_path = os.path.join(img_directory, folder)
124 | all_images = [img for img in os.listdir(folder_path) if img.lower().endswith(('.png', '.jpg', '.jpeg'))]
125 | all_images.sort()
126 | if pic_idx < len(all_images):
127 | images.append(os.path.join(folder_path, all_images[pic_idx]))
128 | elif self.classes is not None and self.pictures is None:
129 | images = []
130 | for i in range(len(self.classes)):
131 | class_idx = self.classes[i]
132 | if class_idx < len(all_folders):
133 | folder = all_folders[class_idx]
134 | folder_path = os.path.join(img_directory, folder)
135 | all_images = [img for img in os.listdir(folder_path) if img.lower().endswith(('.png', '.jpg', '.jpeg'))]
136 | all_images.sort()
137 | images.extend(os.path.join(folder_path, img) for img in all_images)
138 | elif self.classes is None:
139 | images = []
140 | for folder in all_folders:
141 | folder_path = os.path.join(img_directory, folder)
142 | all_images = [img for img in os.listdir(folder_path) if img.lower().endswith(('.png', '.jpg', '.jpeg'))]
143 | all_images.sort()
144 | images.extend(os.path.join(folder_path, img) for img in all_images)
145 | else:
146 |
147 |             print("Error: unexpected combination of 'classes' and 'pictures' arguments.")
148 |
149 | print("self.subjects", self.subjects)
150 | print("exclude_subject", self.exclude_subject)
151 | for subject in self.subjects:
152 | if self.train:
153 | if subject == self.exclude_subject:
154 | continue
155 | # print("subject:", subject)
156 | file_name = 'preprocessed_eeg_training.npy'
157 |
158 | file_path = os.path.join(self.data_path, subject, file_name)
159 | data = np.load(file_path, allow_pickle=True)
160 |
161 | preprocessed_eeg_data = torch.from_numpy(data['preprocessed_eeg_data']).float().detach()
162 | times = torch.from_numpy(data['times']).detach()[50:]
163 | ch_names = data['ch_names']
164 |
165 | n_classes = 1654
166 | samples_per_class = 10
167 |
168 | if self.classes is not None and self.pictures is not None:
169 | for c, p in zip(self.classes, self.pictures):
170 | start_index = c * 1 + p
171 | if start_index < len(preprocessed_eeg_data):
172 | preprocessed_eeg_data_class = preprocessed_eeg_data[start_index: start_index+1]
173 | labels = torch.full((1,), c, dtype=torch.long).detach()
174 | data_list.append(preprocessed_eeg_data_class)
175 | label_list.append(labels)
176 |
177 | elif self.classes is not None and self.pictures is None:
178 | for c in self.classes:
179 | start_index = c * samples_per_class
180 | preprocessed_eeg_data_class = preprocessed_eeg_data[start_index: start_index+samples_per_class]
181 | labels = torch.full((samples_per_class,), c, dtype=torch.long).detach()
182 | data_list.append(preprocessed_eeg_data_class)
183 | label_list.append(labels)
184 |
185 | else:
186 | for i in range(n_classes):
187 | start_index = i * samples_per_class
188 | # if self.exclude_subject==None:
189 | # preprocessed_eeg_data_class = preprocessed_eeg_data[start_index: start_index+samples_per_class]
190 | # else:
191 | preprocessed_eeg_data_class = preprocessed_eeg_data[start_index: start_index+samples_per_class]
192 | # print("preprocessed_eeg_data_class", preprocessed_eeg_data_class.shape)
193 | # preprocessed_eeg_data_class = torch.mean(preprocessed_eeg_data_class, 1)
194 | # preprocessed_eeg_data_class = torch.mean(preprocessed_eeg_data_class, 0)
195 | # print("preprocessed_eeg_data_class", preprocessed_eeg_data_class.shape)
196 | labels = torch.full((samples_per_class,), i, dtype=torch.long).detach()
197 | data_list.append(preprocessed_eeg_data_class)
198 | label_list.append(labels)
199 |
200 |
201 | else:
202 |                 if subject == self.exclude_subject or self.exclude_subject is None:
203 | file_name = 'preprocessed_eeg_test.npy'
204 | file_path = os.path.join(self.data_path, subject, file_name)
205 | data = np.load(file_path, allow_pickle=True)
206 | preprocessed_eeg_data = torch.from_numpy(data['preprocessed_eeg_data']).float().detach()
207 | times = torch.from_numpy(data['times']).detach()[50:]
208 | ch_names = data['ch_names']
209 |                     n_classes = 200  # Each test class contains 1 image
210 |
211 | samples_per_class = 1
212 |
213 | for i in range(n_classes):
214 | if self.classes is not None and i not in self.classes: # If we've defined specific classes and the current class is not in the list, skip
215 | continue
216 | start_index = i * samples_per_class # Update start_index for each class
217 | preprocessed_eeg_data_class = preprocessed_eeg_data[start_index:start_index+samples_per_class]
218 | # print("preprocessed_eeg_data_class", preprocessed_eeg_data_class.shape)
219 | labels = torch.full((samples_per_class,), i, dtype=torch.long).detach() # Add class labels
220 | preprocessed_eeg_data_class = torch.mean(preprocessed_eeg_data_class.squeeze(0), 0)
221 | # print("preprocessed_eeg_data_class", preprocessed_eeg_data_class.shape)
222 | data_list.append(preprocessed_eeg_data_class)
223 | label_list.append(labels) # Add labels to the label list
224 | else:
225 | continue
226 | # datalist: (subjects * classes) * (10 * 4 * 17 * 100)
227 | # data_tensor: (subjects * classes * 10 * 4) * 17 * 100
228 | # data_list = np.mean(data_list, )
229 | # print("data_list", len(data_list))
230 | if self.train:
231 | # print("data_list", *data_list[0].shape[1:])
232 | data_tensor = torch.cat(data_list, dim=0).view(-1, *data_list[0].shape[2:])
233 | # data_tensor = torch.cat(data_list, dim=0).view(-1, *data_list[0].shape[1:])
234 | # data_tensor = torch.cat(data_list, dim=0).view(-1, *data_list[0].shape)
235 | # print("label_tensor", label_tensor.shape)
236 | print("data_tensor", data_tensor.shape)
237 | else:
238 | data_tensor = torch.cat(data_list, dim=0).view(-1, *data_list[0].shape)
239 | # label_tensor = torch.cat(label_list, dim=0)
240 | # print("label_tensor", label_tensor.shape)
241 | # data_tensor = torch.cat(data_list, dim=0).view(-1, *data_list[0].shape[2:])
242 | # print("data_tensor", data_tensor.shape)
243 | # label_list: (subjects * classes) * 10
244 | # label_tensor: (subjects * classes * 10)
245 | # print("label_tensor = torch.cat(label_list, dim=0)")
246 | # print(label_list)
247 | label_tensor = torch.cat(label_list, dim=0)
248 | # label_tensor = torch.cat(label_list, dim=0)
249 | # print(label_tensor[:300])
250 | if self.train:
251 | # label_tensor: (subjects * classes * 10 * 4)
252 | label_tensor = label_tensor.repeat_interleave(4)
253 | if self.classes is not None:
254 | unique_values = list(label_tensor.numpy())
255 | lis = []
256 | for i in unique_values:
257 | if i not in lis:
258 | lis.append(i)
259 | unique_values = torch.tensor(lis)
260 | mapping = {val.item(): index for index, val in enumerate(unique_values)}
261 | label_tensor = torch.tensor([mapping[val.item()] for val in label_tensor], dtype=torch.long)
262 |
263 | else:
264 | # label_tensor = label_tensor.repeat_interleave(80)
265 | # if self.classes is not None:
266 | # unique_values = torch.unique(label_tensor, sorted=False)
267 |
268 | # mapping = {val.item(): index for index, val in enumerate(torch.flip(unique_values, [0]))}
269 | # label_tensor = torch.tensor([mapping[val.item()] for val in label_tensor], dtype=torch.long)
270 | pass
271 |
272 |
273 | self.times = times
274 | self.ch_names = ch_names
275 |
276 | print(f"Data tensor shape: {data_tensor.shape}, label tensor shape: {label_tensor.shape}, text length: {len(texts)}, image length: {len(images)}")
277 |
278 | return data_tensor, label_tensor, texts, images
279 |
280 | def extract_eeg(self, eeg_data, time_window):
281 |
282 | start, end = time_window
283 |
284 | # Get the indices of the times within the specified window
285 | indices = (self.times >= start) & (self.times <= end)
286 | # print("self.times", self.times.shape)
287 | # print("indices", indices)
288 | # print("indices", indices.shape)
289 | # print("eeg_data", eeg_data.shape)
290 | # Use these indices to select the corresponding data
291 | extracted_data = eeg_data[..., indices]
292 | # print(f"extracted_data shape: {extracted_data.shape}")
293 |
294 | return extracted_data
295 |
296 | def Textencoder(self, text):
297 |
298 | text_inputs = torch.cat([clip.tokenize(t) for t in text]).to(device)
299 | # print("text_inputs", text_inputs)
300 |
301 | with torch.no_grad():
302 | text_features = vlmodel.encode_text(text_inputs)
303 |
304 | text_features = F.normalize(text_features, dim=-1).detach()
305 |
306 | return text_features
307 |
308 |     def ImageEncoder(self, images):
309 |         batch_size = 20  # encode images in mini-batches to keep GPU memory use bounded
310 | image_features_list = []
311 |
312 | for i in range(0, len(images), batch_size):
313 | batch_images = images[i:i + batch_size]
314 | image_inputs = torch.stack([preprocess_train(Image.open(img).convert("RGB")) for img in batch_images]).to(device)
315 |
316 | with torch.no_grad():
317 | batch_image_features = vlmodel.encode_image(image_inputs)
318 | batch_image_features /= batch_image_features.norm(dim=-1, keepdim=True)
319 |
320 | image_features_list.append(batch_image_features)
321 |
322 | image_features = torch.cat(image_features_list, dim=0)
323 |
324 | return image_features
325 |
326 | def __getitem__(self, index):
327 | # Get the data and label corresponding to "index"
328 | # index: (subjects * classes * 10 * 4)
329 | x = self.data[index]
330 | label = self.labels[index]
331 |
332 | if self.pictures is None:
333 | if self.classes is None:
334 | index_n_sub_train = self.n_cls * 10 * 4
335 | index_n_sub_test = self.n_cls * 1 * 80
336 | else:
337 | index_n_sub_test = len(self.classes)* 1 * 80
338 | index_n_sub_train = len(self.classes)* 10 * 4
339 | # text_index: classes
340 | if self.train:
341 | text_index = (index % index_n_sub_train) // (10 * 4)
342 | else:
343 | text_index = (index % index_n_sub_test)
344 | # img_index: classes * 10
345 | if self.train:
346 | img_index = (index % index_n_sub_train) // (4)
347 | else:
348 | img_index = (index % index_n_sub_test)
349 | else:
350 | if self.classes is None:
351 | index_n_sub_train = self.n_cls * 1 * 4
352 | index_n_sub_test = self.n_cls * 1 * 80
353 | else:
354 | index_n_sub_test = len(self.classes)* 1 * 80
355 | index_n_sub_train = len(self.classes)* 1 * 4
356 | # text_index: classes
357 | if self.train:
358 | text_index = (index % index_n_sub_train) // (1 * 4)
359 | else:
360 | text_index = (index % index_n_sub_test)
361 | # img_index: classes * 10
362 | if self.train:
363 | img_index = (index % index_n_sub_train) // (4)
364 | else:
365 | img_index = (index % index_n_sub_test)
366 | # print("text_index", text_index)
367 | # print("self.text", self.text)
368 | # print("self.text", len(self.text))
369 | text = self.text[text_index]
370 | img = self.img[img_index]
371 |
372 | text_features = self.text_features[text_index]
373 | img_features = self.img_features[img_index]
374 |
375 | return x, label, text, text_features, img, img_features
376 |
377 | def __len__(self):
378 | return self.data.shape[0] # or self.labels.shape[0] which should be the same
379 |
380 | if __name__ == "__main__":
381 | # Instantiate the dataset and dataloader
382 | # data_path = "/home/ldy/Workspace/THINGS/EEG/osfstorage-archive" # Replace with the path to your data
383 |     data_path = data_path  # path loaded from data_config.json above
384 |     train_dataset = EEGDataset(data_path, subjects=['sub-01'], train=True)
385 |     test_dataset = EEGDataset(data_path, subjects=['sub-01'], train=False)
386 | # train_dataset = EEGDataset(data_path, exclude_subject = 'sub-01', train=True)
387 | # test_dataset = EEGDataset(data_path, exclude_subject = 'sub-01', train=False)
388 | # train_dataset = EEGDataset(data_path, train=True)
389 | # test_dataset = EEGDataset(data_path, train=False)
390 |
391 |
392 |
393 |
394 | # 100 Hz
395 | train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
396 | test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)
397 |
398 | i = 80*1-1
399 | x, label, text, text_features, img, img_features = test_dataset[i]
400 | print(f"Index {i}, Label: {label}, text: {text}")
401 | Image.open(img)
402 |
403 |
404 |
405 |
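406 | # Leave-one-subject-out usage sketch (illustrative; the subject IDs and loop below
407 | # are assumptions, not part of this file):
408 | # for held_out in ['sub-01', 'sub-02', 'sub-03']:
409 | #     train_set = EEGDataset(data_path, exclude_subject=held_out, train=True)   # all subjects except held_out
410 | #     test_set  = EEGDataset(data_path, exclude_subject=held_out, train=False)  # held_out subject only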
--------------------------------------------------------------------------------
/Generation/eegdatasets_leaveone.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 | import numpy as np
4 | import os
5 | import clip
6 | from torch.nn import functional as F
7 | import torch.nn as nn
8 | from torchvision import transforms
9 | from PIL import Image
10 | import requests
11 |
12 | import os
13 | # Note: Set http_proxy/https_proxy environment variables if needed
14 | cuda_device_count = torch.cuda.device_count()
15 | print(f"CUDA devices available: {cuda_device_count}")
16 | device = "cuda:0" if torch.cuda.is_available() else "cpu"
17 | # vlmodel, preprocess = clip.load("ViT-B/32", device=device)
18 | model_type = 'ViT-H-14'
19 | import open_clip
20 | vlmodel, preprocess_train, feature_extractor = open_clip.create_model_and_transforms(
21 |     model_type, pretrained='laion2b_s32b_b79k', precision='fp32', device=device)
22 |
23 | import json
24 |
25 | # Load the configuration from the JSON file
26 | config_path = "data_config.json"
27 | with open(config_path, "r") as config_file:
28 | config = json.load(config_file)
29 |
30 | # Access the paths from the config
31 | data_path = config["data_path"]
32 | img_directory_training = config["img_directory_training"]
33 | img_directory_test = config["img_directory_test"]
34 |
35 |
36 | class EEGDataset(Dataset):
37 | """
38 | subjects = ['sub-01', 'sub-02', 'sub-05', 'sub-04', 'sub-03', 'sub-06', 'sub-07', 'sub-08', 'sub-09', 'sub-10']
39 | """
40 |     def __init__(self, data_path, exclude_subject=None, subjects=None, train=True, time_window=[0, 1.0], classes=None, pictures=None, val_size=None):
41 | self.data_path = data_path
42 | self.train = train
43 | self.subject_list = os.listdir(data_path)
44 | self.subjects = self.subject_list if subjects is None else subjects
45 | self.n_sub = len(self.subjects)
46 | self.time_window = time_window
47 | self.n_cls = 1654 if train else 200
48 | self.classes = classes
49 | self.pictures = pictures
50 | self.exclude_subject = exclude_subject
51 | self.val_size = val_size
52 |         # Ensure at least one requested subject exists in the data directory
53 | assert any(sub in self.subject_list for sub in self.subjects)
54 |
55 | self.data, self.labels, self.text, self.img = self.load_data()
56 |
57 | self.data = self.extract_eeg(self.data, time_window)
58 |
59 |
60 | if self.classes is None and self.pictures is None:
61 | # Try to load the saved features if they exist
62 |             features_filename = f'{model_type}_features_train.pt' if self.train else f'{model_type}_features_test.pt'
63 |
64 |             if os.path.exists(features_filename):
65 | saved_features = torch.load(features_filename)
66 | self.text_features = saved_features['text_features']
67 | self.img_features = saved_features['img_features']
68 | else:
69 | self.text_features = self.Textencoder(self.text)
70 | self.img_features = self.ImageEncoder(self.img)
71 | torch.save({
72 | 'text_features': self.text_features.cpu(),
73 | 'img_features': self.img_features.cpu(),
74 | }, features_filename)
75 | else:
76 | self.text_features = self.Textencoder(self.text)
77 | self.img_features = self.ImageEncoder(self.img)
78 |
79 | def load_data(self):
80 | data_list = []
81 | label_list = []
82 | texts = []
83 | images = []
84 |
85 | if self.train:
86 | directory = img_directory_training
87 | else:
88 | directory = img_directory_test
89 |
90 | dirnames = [d for d in os.listdir(directory) if os.path.isdir(os.path.join(directory, d))]
91 | dirnames.sort()
92 |
93 | if self.classes is not None:
94 | dirnames = [dirnames[i] for i in self.classes]
95 |
96 |         for dirname in dirnames:
97 |
98 |             try:
99 |                 idx = dirname.index('_')
100 |                 description = dirname[idx+1:]
101 |             except ValueError:
102 |                 print(f"Skipped {dirname}: no '_' found in folder name.")
103 | continue
104 |
105 | new_description = f"This picture is {description}"
106 | texts.append(new_description)
107 |
108 | if self.train:
109 | img_directory = img_directory_training
110 | else:
111 | img_directory = img_directory_test
112 |
113 | all_folders = [d for d in os.listdir(img_directory) if os.path.isdir(os.path.join(img_directory, d))]
114 | all_folders.sort()
115 |
116 | if self.classes is not None and self.pictures is not None:
117 | images = []
118 | for i in range(len(self.classes)):
119 | class_idx = self.classes[i]
120 | pic_idx = self.pictures[i]
121 | if class_idx < len(all_folders):
122 | folder = all_folders[class_idx]
123 | folder_path = os.path.join(img_directory, folder)
124 | all_images = [img for img in os.listdir(folder_path) if img.lower().endswith(('.png', '.jpg', '.jpeg'))]
125 | all_images.sort()
126 | if pic_idx < len(all_images):
127 | images.append(os.path.join(folder_path, all_images[pic_idx]))
128 | elif self.classes is not None and self.pictures is None:
129 | images = []
130 | for i in range(len(self.classes)):
131 | class_idx = self.classes[i]
132 | if class_idx < len(all_folders):
133 | folder = all_folders[class_idx]
134 | folder_path = os.path.join(img_directory, folder)
135 | all_images = [img for img in os.listdir(folder_path) if img.lower().endswith(('.png', '.jpg', '.jpeg'))]
136 | all_images.sort()
137 | images.extend(os.path.join(folder_path, img) for img in all_images)
138 | elif self.classes is None:
139 | images = []
140 | for folder in all_folders:
141 | folder_path = os.path.join(img_directory, folder)
142 | all_images = [img for img in os.listdir(folder_path) if img.lower().endswith(('.png', '.jpg', '.jpeg'))]
143 | all_images.sort()
144 | images.extend(os.path.join(folder_path, img) for img in all_images)
145 | else:
146 |
147 |             print("Error: unexpected combination of 'classes' and 'pictures' arguments.")
148 |
149 | print("self.subjects", self.subjects)
150 | print("exclude_subject", self.exclude_subject)
151 | for subject in self.subjects:
152 | if self.train:
153 | if subject == self.exclude_subject:
154 | continue
155 | # print("subject:", subject)
156 | file_name = 'preprocessed_eeg_training.npy'
157 |
158 | file_path = os.path.join(self.data_path, subject, file_name)
159 | data = np.load(file_path, allow_pickle=True)
160 |
161 | preprocessed_eeg_data = torch.from_numpy(data['preprocessed_eeg_data']).float().detach()
162 | times = torch.from_numpy(data['times']).detach()[50:]
163 | ch_names = data['ch_names']
164 |
165 | n_classes = 1654
166 | samples_per_class = 10
167 |
168 | if self.classes is not None and self.pictures is not None:
169 | for c, p in zip(self.classes, self.pictures):
170 | start_index = c * 1 + p
171 | if start_index < len(preprocessed_eeg_data):
172 | preprocessed_eeg_data_class = preprocessed_eeg_data[start_index: start_index+1]
173 | labels = torch.full((1,), c, dtype=torch.long).detach()
174 | data_list.append(preprocessed_eeg_data_class)
175 | label_list.append(labels)
176 |
177 | elif self.classes is not None and self.pictures is None:
178 | for c in self.classes:
179 | start_index = c * samples_per_class
180 | preprocessed_eeg_data_class = preprocessed_eeg_data[start_index: start_index+samples_per_class]
181 | labels = torch.full((samples_per_class,), c, dtype=torch.long).detach()
182 | data_list.append(preprocessed_eeg_data_class)
183 | label_list.append(labels)
184 |
185 | else:
186 | for i in range(n_classes):
187 | start_index = i * samples_per_class
188 | # if self.exclude_subject==None:
189 | # preprocessed_eeg_data_class = preprocessed_eeg_data[start_index: start_index+samples_per_class]
190 | # else:
191 | preprocessed_eeg_data_class = preprocessed_eeg_data[start_index: start_index+samples_per_class]
192 | # print("preprocessed_eeg_data_class", preprocessed_eeg_data_class.shape)
193 | # preprocessed_eeg_data_class = torch.mean(preprocessed_eeg_data_class, 1)
194 | # preprocessed_eeg_data_class = torch.mean(preprocessed_eeg_data_class, 0)
195 | # print("preprocessed_eeg_data_class", preprocessed_eeg_data_class.shape)
196 | labels = torch.full((samples_per_class,), i, dtype=torch.long).detach()
197 | data_list.append(preprocessed_eeg_data_class)
198 | label_list.append(labels)
199 |
200 |
201 | else:
202 |                 if subject == self.exclude_subject or self.exclude_subject is None:
203 | file_name = 'preprocessed_eeg_test.npy'
204 | file_path = os.path.join(self.data_path, subject, file_name)
205 | data = np.load(file_path, allow_pickle=True)
206 | preprocessed_eeg_data = torch.from_numpy(data['preprocessed_eeg_data']).float().detach()
207 | times = torch.from_numpy(data['times']).detach()[50:]
208 | ch_names = data['ch_names']
209 |                     n_classes = 200  # Each test class contains 1 image
210 |
211 | samples_per_class = 1
212 |
213 | for i in range(n_classes):
214 | if self.classes is not None and i not in self.classes: # If we've defined specific classes and the current class is not in the list, skip
215 | continue
216 | start_index = i * samples_per_class # Update start_index for each class
217 | preprocessed_eeg_data_class = preprocessed_eeg_data[start_index:start_index+samples_per_class]
218 | # print("preprocessed_eeg_data_class", preprocessed_eeg_data_class.shape)
219 | labels = torch.full((samples_per_class,), i, dtype=torch.long).detach() # Add class labels
220 | preprocessed_eeg_data_class = torch.mean(preprocessed_eeg_data_class.squeeze(0), 0)
221 | # print("preprocessed_eeg_data_class", preprocessed_eeg_data_class.shape)
222 | data_list.append(preprocessed_eeg_data_class)
223 | label_list.append(labels) # Add labels to the label list
224 | else:
225 | continue
226 | # datalist: (subjects * classes) * (10 * 4 * 17 * 100)
227 | # data_tensor: (subjects * classes * 10 * 4) * 17 * 100
228 | # data_list = np.mean(data_list, )
229 | # print("data_list", len(data_list))
230 | if self.train:
231 | # print("data_list", *data_list[0].shape[1:])
232 | data_tensor = torch.cat(data_list, dim=0).view(-1, *data_list[0].shape[2:])
233 | # data_tensor = torch.cat(data_list, dim=0).view(-1, *data_list[0].shape[1:])
234 | # data_tensor = torch.cat(data_list, dim=0).view(-1, *data_list[0].shape)
235 | # print("label_tensor", label_tensor.shape)
236 | print("data_tensor", data_tensor.shape)
237 | else:
238 | data_tensor = torch.cat(data_list, dim=0).view(-1, *data_list[0].shape)
239 | # label_tensor = torch.cat(label_list, dim=0)
240 | # print("label_tensor", label_tensor.shape)
241 | # data_tensor = torch.cat(data_list, dim=0).view(-1, *data_list[0].shape[2:])
242 | # print("data_tensor", data_tensor.shape)
243 | # label_list: (subjects * classes) * 10
244 | # label_tensor: (subjects * classes * 10)
245 | # print("label_tensor = torch.cat(label_list, dim=0)")
246 | # print(label_list)
247 | label_tensor = torch.cat(label_list, dim=0)
248 | # label_tensor = torch.cat(label_list, dim=0)
249 | # print(label_tensor[:300])
250 | if self.train:
251 | # label_tensor: (subjects * classes * 10 * 4)
252 | label_tensor = label_tensor.repeat_interleave(4)
253 | if self.classes is not None:
254 | unique_values = list(label_tensor.numpy())
255 | lis = []
256 | for i in unique_values:
257 | if i not in lis:
258 | lis.append(i)
259 | unique_values = torch.tensor(lis)
260 | mapping = {val.item(): index for index, val in enumerate(unique_values)}
261 | label_tensor = torch.tensor([mapping[val.item()] for val in label_tensor], dtype=torch.long)
262 |
263 | else:
264 | # label_tensor = label_tensor.repeat_interleave(80)
265 | # if self.classes is not None:
266 | # unique_values = torch.unique(label_tensor, sorted=False)
267 |
268 | # mapping = {val.item(): index for index, val in enumerate(torch.flip(unique_values, [0]))}
269 | # label_tensor = torch.tensor([mapping[val.item()] for val in label_tensor], dtype=torch.long)
270 | pass
271 |
272 |
273 | self.times = times
274 | self.ch_names = ch_names
275 |
276 | print(f"Data tensor shape: {data_tensor.shape}, label tensor shape: {label_tensor.shape}, text length: {len(texts)}, image length: {len(images)}")
277 |
278 | return data_tensor, label_tensor, texts, images
279 |
280 | def extract_eeg(self, eeg_data, time_window):
281 |
282 | start, end = time_window
283 |
284 | # Get the indices of the times within the specified window
285 | indices = (self.times >= start) & (self.times <= end)
286 | # print("self.times", self.times.shape)
287 | # print("indices", indices)
288 | # print("indices", indices.shape)
289 | # print("eeg_data", eeg_data.shape)
290 | # Use these indices to select the corresponding data
291 | extracted_data = eeg_data[..., indices]
292 | # print(f"extracted_data shape: {extracted_data.shape}")
293 |
294 | return extracted_data
295 |
296 | def Textencoder(self, text):
297 |
298 | text_inputs = torch.cat([clip.tokenize(t) for t in text]).to(device)
299 | # print("text_inputs", text_inputs)
300 |
301 | with torch.no_grad():
302 | text_features = vlmodel.encode_text(text_inputs)
303 |
304 | text_features = F.normalize(text_features, dim=-1).detach()
305 |
306 | return text_features
307 |
308 |     def ImageEncoder(self, images):
309 |         batch_size = 20  # encode images in mini-batches to keep GPU memory use bounded
310 | image_features_list = []
311 |
312 | for i in range(0, len(images), batch_size):
313 | batch_images = images[i:i + batch_size]
314 | image_inputs = torch.stack([preprocess_train(Image.open(img).convert("RGB")) for img in batch_images]).to(device)
315 |
316 | with torch.no_grad():
317 | batch_image_features = vlmodel.encode_image(image_inputs)
318 |                 # batch_image_features /= batch_image_features.norm(dim=-1, keepdim=True)  # keep unnormalized CLIP embeddings for reconstruction training
319 |
320 | image_features_list.append(batch_image_features)
321 |
322 | image_features = torch.cat(image_features_list, dim=0)
323 |
324 | return image_features
325 |
326 | def __getitem__(self, index):
327 | # Get the data and label corresponding to "index"
328 | # index: (subjects * classes * 10 * 4)
329 | x = self.data[index]
330 | label = self.labels[index]
331 |
332 | if self.pictures is None:
333 | if self.classes is None:
334 | index_n_sub_train = self.n_cls * 10 * 4
335 | index_n_sub_test = self.n_cls * 1 * 80
336 | else:
337 | index_n_sub_test = len(self.classes)* 1 * 80
338 | index_n_sub_train = len(self.classes)* 10 * 4
339 | # text_index: classes
340 | if self.train:
341 | text_index = (index % index_n_sub_train) // (10 * 4)
342 | else:
343 | text_index = (index % index_n_sub_test)
344 | # img_index: classes * 10
345 | if self.train:
346 | img_index = (index % index_n_sub_train) // (4)
347 | else:
348 | img_index = (index % index_n_sub_test)
349 | else:
350 | if self.classes is None:
351 | index_n_sub_train = self.n_cls * 1 * 4
352 | index_n_sub_test = self.n_cls * 1 * 80
353 | else:
354 | index_n_sub_test = len(self.classes)* 1 * 80
355 | index_n_sub_train = len(self.classes)* 1 * 4
356 | # text_index: classes
357 | if self.train:
358 | text_index = (index % index_n_sub_train) // (1 * 4)
359 | else:
360 | text_index = (index % index_n_sub_test)
361 | # img_index: classes * 10
362 | if self.train:
363 | img_index = (index % index_n_sub_train) // (4)
364 | else:
365 | img_index = (index % index_n_sub_test)
366 | # print("text_index", text_index)
367 | # print("self.text", self.text)
368 | # print("self.text", len(self.text))
369 | text = self.text[text_index]
370 | img = self.img[img_index]
371 |
372 | text_features = self.text_features[text_index]
373 | img_features = self.img_features[img_index]
374 |
375 | return x, label, text, text_features, img, img_features
376 |
377 | def __len__(self):
378 | return self.data.shape[0] # or self.labels.shape[0] which should be the same
379 |
380 | if __name__ == "__main__":
381 | # Instantiate the dataset and dataloader
382 | # data_path = "/home/ldy/Workspace/THINGS/EEG/osfstorage-archive" # Replace with the path to your data
383 |     data_path = data_path  # path loaded from data_config.json above
384 |     train_dataset = EEGDataset(data_path, subjects=['sub-01'], train=True)
385 |     test_dataset = EEGDataset(data_path, subjects=['sub-01'], train=False)
386 | # train_dataset = EEGDataset(data_path, exclude_subject = 'sub-01', train=True)
387 | # test_dataset = EEGDataset(data_path, exclude_subject = 'sub-01', train=False)
388 | # train_dataset = EEGDataset(data_path, train=True)
389 | # test_dataset = EEGDataset(data_path, train=False)
390 |
391 |
392 |
393 |
394 | # 100 Hz
395 | train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
396 | test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)
397 |
398 | i = 80*1-1
399 | x, label, text, text_features, img, img_features = test_dataset[i]
400 | print(f"Index {i}, Label: {label}, text: {text}")
401 | Image.open(img)
402 |
403 |
404 |
405 |
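406 | # Note: this Generation copy appears to differ from Retrieval/eegdatasets_leaveone.py
407 | # only in ImageEncoder, which keeps unnormalized CLIP image embeddings for
408 | # reconstruction training. A normalized copy can still be derived downstream, e.g.:
409 | #     img_features_norm = img_features / img_features.norm(dim=-1, keepdim=True)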
--------------------------------------------------------------------------------