├── example.ipynb ├── README.md ├── meet └── models │ ├── build.py │ ├── bab.py │ ├── vit_utils.py │ └── vit.py ├── utils.py ├── eegUtils.py ├── Utils_Bashivan.py ├── train.py └── LICENSE /example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "08fe0c59", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch\n", 11 | "from meet.models.vit import meet_small_patch8 as create_model" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 3, 17 | "id": "652fb03e", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "model = create_model(num_classes=3)\n", 22 | "\n", 23 | "dummy_eeg = torch.randn(8, 5, 6, 32, 32) # (batch x bands x frames x height x width)\n", 24 | "\n", 25 | "pred = model(dummy_eeg) # (8, 3)" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 6, 31 | "id": "83de13c5-791c-4db7-aba4-6d29ce88584e", 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "assert pred.shape == (8, 3)" 36 | ] 37 | } 38 | ], 39 | "metadata": { 40 | "kernelspec": { 41 | "display_name": "Python 3", 42 | "language": "python", 43 | "name": "python3" 44 | }, 45 | "language_info": { 46 | "codemirror_mode": { 47 | "name": "ipython", 48 | "version": 3 49 | }, 50 | "file_extension": ".py", 51 | "mimetype": "text/x-python", 52 | "name": "python", 53 | "nbconvert_exporter": "python", 54 | "pygments_lexer": "ipython3", 55 | "version": "3.9.4" 56 | } 57 | }, 58 | "nbformat": 4, 59 | "nbformat_minor": 5 60 | } 61 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MEET 2 | 3 | This is an official pytorch implementation of our paper "MEET: Multi-band EEG Transformer". In this repository, we provide PyTorch code for training and testing our proposed MEET model. MEET provides an efficient EEG classification framework that achieves state-of-the-art results on several EEG benchmarks such as SEED. 4 | 5 | If you find MEET useful in your research, please use the following BibTeX entry for citation. 6 | 7 | 8 | 9 | # Model Variants 10 | 11 | We provide the following three variants of MEET. 12 | 13 | | Model | Depth | Heads | Hidden size | MLP size | Params | 14 | | ---------- | ----- | ----- | ----------- | -------- | ------ | 15 | | MEET-Small | 3 | 3 | 768 | 3072 | 30M | 16 | | MEET-Base | 6 | 12 | 768 | 3072 | 61M | 17 | | MEET-Large | 12 | 16 | 1024 | 4096 | 215M | 18 | 19 | To simplify the model and provide a cleaner code repository, we have made minor adjustments to the model's implementation. Therefore, there might be a small difference in performance compared to the results reported in the paper. 
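These three variants correspond to the factory functions defined in `meet/models/vit.py`. A minimal sketch of instantiating them (the function names and the `num_classes` argument are taken from that file; the comments simply restate the table above):

```python
from meet.models.vit import meet_small_patch8, meet_base_patch16, meet_large_patch16

model_small = meet_small_patch8(num_classes=3)   # depth 3,  3 heads,  hidden size 768
model_base  = meet_base_patch16(num_classes=3)   # depth 6,  12 heads, hidden size 768
model_large = meet_large_patch16(num_classes=3)  # depth 12, 16 heads, hidden size 1024
```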
20 | 21 | 22 | 23 | # Usage 24 | 25 | You can use MEET as follows: 26 | 27 | ```python 28 | import torch 29 | from meet.models.vit import meet_small_patch8 as create_model 30 | 31 | model = create_model(num_classes=3) 32 | dummy_eeg = torch.randn(8, 5, 6, 32, 32) # (batch x bands x frames x height x width) 33 | pred = model(dummy_eeg) # (8, 3) 34 | 35 | assert pred.shape == (8, 3) 36 | ``` 37 | 38 | -------------------------------------------------------------------------------- /meet/models/build.py: -------------------------------------------------------------------------------- 1 | """Model construction functions.""" 2 | 3 | import torch 4 | from fvcore.common.registry import Registry 5 | 6 | MODEL_REGISTRY = Registry("MODEL") 7 | MODEL_REGISTRY.__doc__ = """ 8 | Registry for video model. 9 | 10 | The registered object will be called with `obj(cfg)`. 11 | The call should return a `torch.nn.Module` object. 12 | """ 13 | 14 | 15 | def build_model(cfg, gpu_id=None): 16 | """ 17 | Builds the video model. 18 | Args: 19 | cfg (configs): configs that contains the hyper-parameters to build the 20 | backbone. Details can be seen in slowfast/config/defaults.py. 21 | gpu_id (Optional[int]): specify the gpu index to build model. 22 | """ 23 | if torch.cuda.is_available(): 24 | assert ( 25 | cfg.NUM_GPUS <= torch.cuda.device_count() 26 | ), "Cannot use more GPU devices than available" 27 | else: 28 | assert ( 29 | cfg.NUM_GPUS == 0 30 | ), "Cuda is not available. Please set `NUM_GPUS: 0 for running on CPUs." 31 | 32 | # Construct the model 33 | name = cfg.MODEL.MODEL_NAME 34 | model = MODEL_REGISTRY.get(name)(cfg) 35 | 36 | if cfg.NUM_GPUS: 37 | if gpu_id is None: 38 | # Determine the GPU used by the current process 39 | cur_device = torch.cuda.current_device() 40 | else: 41 | cur_device = gpu_id 42 | # Transfer the model to the current GPU device 43 | model = model.cuda(device=cur_device) 44 | 45 | 46 | # Use multi-process data parallel model in the multi-gpu setting 47 | if cfg.NUM_GPUS > 1: 48 | # Make model replica operate on the current device 49 | model = torch.nn.parallel.DistributedDataParallel( 50 | module=model, device_ids=[cur_device], output_device=cur_device 51 | ) 52 | return model 53 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import pickle 5 | import random 6 | 7 | import torch 8 | from tqdm import tqdm 9 | 10 | import matplotlib.pyplot as plt 11 | 12 | 13 | def plot_data_loader_image(data_loader): 14 | batch_size = data_loader.batch_size 15 | plot_num = min(batch_size, 4) 16 | 17 | json_path = './class_indices.json' 18 | assert os.path.exists(json_path), json_path + " does not exist." 
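    # Note (added comment, not in the original file): class_indices.json is expected to map
    # label indices serialized as strings to human-readable class names, because labels are
    # looked up below via class_indices[str(label)]. An illustrative (hypothetical) file:
    #   {"0": "negative", "1": "neutral", "2": "positive"}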
19 | json_file = open(json_path, 'r') 20 | class_indices = json.load(json_file) 21 | 22 | for data in data_loader: 23 | images, labels = data 24 | for i in range(plot_num): 25 | # [C, H, W] -> [H, W, C] 26 | img = images[i].numpy().transpose(1, 2, 0) 27 | img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255 28 | label = labels[i].item() 29 | plt.subplot(1, plot_num, i+1) 30 | plt.xlabel(class_indices[str(label)]) 31 | plt.xticks([]) 32 | plt.yticks([]) 33 | plt.imshow(img.astype('uint8')) 34 | plt.show() 35 | 36 | 37 | def write_pickle(list_info: list, file_name: str): 38 | with open(file_name, 'wb') as f: 39 | pickle.dump(list_info, f) 40 | 41 | 42 | def read_pickle(file_name: str) -> list: 43 | with open(file_name, 'rb') as f: 44 | info_list = pickle.load(f) 45 | return info_list 46 | 47 | 48 | def train_one_epoch(model, optimizer, data_loader, device, epoch): 49 | model.train() 50 | loss_function = torch.nn.CrossEntropyLoss() 51 | accu_loss = torch.zeros(1).to(device) 52 | accu_num = torch.zeros(1).to(device) 53 | optimizer.zero_grad() 54 | 55 | sample_num = 0 56 | data_loader = tqdm(data_loader, file=sys.stdout) 57 | for step, data in enumerate(data_loader): 58 | images, labels = data 59 | sample_num += images.shape[0] 60 | 61 | pred = model(images.to(device)) 62 | pred_classes = torch.max(pred, dim=1)[1] 63 | accu_num += torch.eq(pred_classes, labels.to(device)).sum() 64 | 65 | loss = loss_function(pred, labels.to(device)) 66 | loss.backward() 67 | accu_loss += loss.detach() 68 | 69 | data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch, 70 | accu_loss.item() / (step + 1), 71 | accu_num.item() / sample_num) 72 | 73 | if not torch.isfinite(loss): 74 | print('WARNING: non-finite loss, ending training ', loss) 75 | sys.exit(1) 76 | 77 | optimizer.step() 78 | optimizer.zero_grad() 79 | 80 | return accu_loss.item() / (step + 1), accu_num.item() / sample_num 81 | 82 | 83 | @torch.no_grad() 84 | def evaluate(model, data_loader, device, epoch): 85 | loss_function = torch.nn.CrossEntropyLoss() 86 | 87 | model.eval() 88 | 89 | accu_num = torch.zeros(1).to(device) 90 | accu_loss = torch.zeros(1).to(device) 91 | 92 | sample_num = 0 93 | data_loader = tqdm(data_loader, file=sys.stdout) 94 | for step, data in enumerate(data_loader): 95 | images, labels = data 96 | sample_num += images.shape[0] 97 | 98 | pred = model(images.to(device)) 99 | pred_classes = torch.max(pred, dim=1)[1] 100 | accu_num += torch.eq(pred_classes, labels.to(device)).sum() 101 | 102 | loss = loss_function(pred, labels.to(device)) 103 | accu_loss += loss 104 | 105 | data_loader.desc = "[valid epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch, 106 | accu_loss.item() / (step + 1), 107 | accu_num.item() / sample_num) 108 | 109 | return accu_loss.item() / (step + 1), accu_num.item() / sample_num 110 | -------------------------------------------------------------------------------- /meet/models/bab.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class BasicConv(nn.Module): 7 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False): 8 | super(BasicConv, self).__init__() 9 | self.out_channels = out_planes 10 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) 11 | self.bn = 
nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None 12 | self.relu = nn.ReLU() if relu else None 13 | 14 | def forward(self, x): 15 | x = self.conv(x) 16 | if self.bn is not None: 17 | x = self.bn(x) 18 | if self.relu is not None: 19 | x = self.relu(x) 20 | return x 21 | 22 | class Flatten(nn.Module): 23 | def forward(self, x): 24 | return x.view(x.size(0), -1) 25 | 26 | class ChannelGate(nn.Module): 27 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): 28 | super(ChannelGate, self).__init__() 29 | self.gate_channels = gate_channels 30 | self.mlp = nn.Sequential( 31 | Flatten(), 32 | nn.Linear(gate_channels, gate_channels // reduction_ratio), 33 | nn.ReLU(), 34 | nn.Linear(gate_channels // reduction_ratio, gate_channels) 35 | ) 36 | self.pool_types = pool_types 37 | def forward(self, x): 38 | channel_att_sum = None 39 | for pool_type in self.pool_types: 40 | if pool_type=='avg': 41 | avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 42 | channel_att_raw = self.mlp( avg_pool ) 43 | elif pool_type=='max': 44 | max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 45 | channel_att_raw = self.mlp( max_pool ) 46 | elif pool_type=='lp': 47 | lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 48 | channel_att_raw = self.mlp( lp_pool ) 49 | elif pool_type=='lse': 50 | # LSE pool only 51 | lse_pool = logsumexp_2d(x) 52 | channel_att_raw = self.mlp( lse_pool ) 53 | 54 | if channel_att_sum is None: 55 | channel_att_sum = channel_att_raw 56 | else: 57 | channel_att_sum = channel_att_sum + channel_att_raw 58 | 59 | scale = F.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x) 60 | return x * scale 61 | 62 | def logsumexp_2d(tensor): 63 | tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1) 64 | s, _ = torch.max(tensor_flatten, dim=2, keepdim=True) 65 | outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log() 66 | return outputs 67 | 68 | class ChannelPool(nn.Module): 69 | def forward(self, x): 70 | return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 ) 71 | 72 | class SpatialGate(nn.Module): 73 | def __init__(self): 74 | super(SpatialGate, self).__init__() 75 | kernel_size = 7 76 | self.compress = ChannelPool() 77 | self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False) 78 | def forward(self, x): 79 | x_compress = self.compress(x) 80 | x_out = self.spatial(x_compress) 81 | scale = F.sigmoid(x_out) # broadcasting 82 | return x * scale 83 | 84 | class BAB(nn.Module): 85 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False): 86 | super(BAB, self).__init__() 87 | self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types) 88 | self.no_spatial=no_spatial 89 | if not no_spatial: 90 | self.SpatialGate = SpatialGate() 91 | def forward(self, x): 92 | # x: [8,256,16,16] 93 | x_out = self.ChannelGate(x) # [8,256,16,16] 94 | # if not self.no_spatial: 95 | # x_out = self.SpatialGate(x_out) 96 | return x_out 97 | -------------------------------------------------------------------------------- /eegUtils.py: -------------------------------------------------------------------------------- 1 | # Source: Bashivan, et al."Learning Representations from EEG with Deep Recurrent-Convolutional Neural Networks." International conference on learning representations (2016). 
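# (Added summary, not part of the original header) This module bundles a k-fold index splitter
# (kfold), a Dataset wrapper around precomputed EEG images (EEGImagesDataset), and reference
# training / evaluation loops (TrainTest_Model, Test_Model) for those images.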
2 | # Modified by 1061413241 3 | 4 | from torch.utils.data.dataset import Dataset 5 | import torch 6 | 7 | import scipy.io as sio 8 | import torch.optim as optim 9 | import torch.nn as nn 10 | import numpy as np 11 | 12 | def kfold(length, n_fold): 13 | tot_id = np.arange(length) 14 | np.random.shuffle(tot_id) 15 | len_fold = int(length/n_fold) 16 | train_id = [] 17 | test_id = [] 18 | for i in range(n_fold): 19 | test_id.append(tot_id[i*len_fold:(i+1)*len_fold]) 20 | train_id.append(np.hstack([tot_id[0:i*len_fold],tot_id[(i+1)*len_fold:-1]])) 21 | return train_id, test_id 22 | 23 | 24 | class EEGImagesDataset(Dataset): 25 | """EEG Images Dataset from EEG.""" 26 | 27 | def __init__(self, label, image): 28 | self.label = label 29 | self.Images = image.astype(np.float32) 30 | 31 | def __len__(self): 32 | return len(self.label) 33 | 34 | def __getitem__(self, idx): 35 | if torch.is_tensor(idx): 36 | idx = idx.tolist() 37 | image = self.Images[idx] 38 | label = self.label[idx] 39 | sample = (image, label) 40 | 41 | return sample 42 | 43 | 44 | 45 | def Test_Model(net, Testloader, criterion, is_cuda=True): 46 | running_loss = 0.0 47 | evaluation = [] 48 | for i, data in enumerate(Testloader, 0): 49 | input_img, labels = data 50 | input_img = input_img.to(torch.float32) 51 | if is_cuda: 52 | input_img = input_img.cuda() 53 | outputs = net(input_img) 54 | _, predicted = torch.max(outputs.cpu().data, 1) 55 | _, eva_label = torch.max(labels.cpu().data, 1) 56 | evaluation.append((predicted==eva_label).tolist()) 57 | loss = criterion(outputs, labels.cuda()) 58 | running_loss += loss.item() 59 | running_loss = running_loss/(i+1) 60 | evaluation = [item for sublist in evaluation for item in sublist] 61 | running_acc = sum(evaluation)/len(evaluation) 62 | return running_loss, running_acc 63 | 64 | 65 | def TrainTest_Model(model, trainloader, testloader, n_epoch=30, opti='SGD', learning_rate=0.0001, is_cuda=True, print_epoch =5, verbose=False): 66 | if is_cuda: 67 | net = model.cuda() 68 | else : 69 | net = model 70 | 71 | criterion = nn.CrossEntropyLoss() 72 | 73 | if opti=='SGD': 74 | optimizer = optim.SGD(net.parameters(), lr=learning_rate) 75 | elif opti =='Adam': 76 | optimizer = optim.Adam(net.parameters(), lr=learning_rate) 77 | else: 78 | print("Optimizer: "+optim+" not implemented.") 79 | 80 | for epoch in range(n_epoch): 81 | running_loss = 0.0 82 | evaluation = [] 83 | for i, data in enumerate(trainloader, 0): 84 | # get the inputs; data is a list of [inputs, labels] 85 | inputs, labels = data 86 | # zero the parameter gradients 87 | optimizer.zero_grad() 88 | 89 | # forward + backward + optimize 90 | outputs = net(inputs.to(torch.float32).cuda()) 91 | _, predicted = torch.max(outputs.cpu().data, 1) 92 | _, eva_label = torch.max(labels.cpu().data, 1) 93 | evaluation.append((predicted==eva_label).tolist()) 94 | loss = criterion(outputs, labels.to(torch.long).cuda()) 95 | loss.backward() 96 | optimizer.step() 97 | 98 | running_loss += loss.item() 99 | 100 | running_loss = running_loss/(i+1) 101 | evaluation = [item for sublist in evaluation for item in sublist] 102 | running_acc = sum(evaluation)/len(evaluation) 103 | validation_loss, validation_acc = Test_Model(net, testloader, criterion,True) 104 | 105 | if epoch%print_epoch==(print_epoch-1): 106 | print('[%d, %3d]\tloss: %.3f\tAccuracy : %.3f\t\tval-loss: %.3f\tval-Accuracy : %.3f' % 107 | (epoch+1, n_epoch, running_loss, running_acc, validation_loss, validation_acc)) 108 | if verbose: 109 | print('Finished Training \n loss: %.3f\tAccuracy : 
%.3f\t\tval-loss: %.3f\tval-Accuracy : %.3f' % 110 | (running_loss, running_acc, validation_loss,validation_acc)) 111 | 112 | return (running_loss, running_acc, validation_loss,validation_acc) -------------------------------------------------------------------------------- /Utils_Bashivan.py: -------------------------------------------------------------------------------- 1 | # Source: Bashivan, et al. "Learning Representations from EEG with Deep Recurrent-Convolutional Neural Networks." International conference on learning representations (2016). 2 | # Modified by 1061413241 3 | 4 | import numpy as np 5 | np.random.seed(123) 6 | import scipy.io 7 | 8 | from scipy.interpolate import griddata 9 | from sklearn.preprocessing import scale 10 | import math as m 11 | from sklearn.decomposition import PCA 12 | 13 | def azim_proj(pos): 14 | """ 15 | Computes the Azimuthal Equidistant Projection of input point in 3D Cartesian Coordinates. 16 | Imagine a plane being placed against (tangent to) a globe. If 17 | a light source inside the globe projects the graticule onto 18 | the plane the result would be a planar, or azimuthal, map 19 | projection. 20 | 21 | :param pos: position in 3D Cartesian coordinates 22 | :return: projected coordinates using Azimuthal Equidistant Projection 23 | """ 24 | [r, elev, az] = cart2sph(pos[0], pos[1], pos[2]) 25 | return pol2cart(az, m.pi / 2 - elev) 26 | 27 | def cart2sph(x, y, z): 28 | """ 29 | Transform Cartesian coordinates to spherical 30 | :param x: X coordinate 31 | :param y: Y coordinate 32 | :param z: Z coordinate 33 | :return: radius, elevation, azimuth 34 | """ 35 | x2_y2 = x**2 + y**2 36 | r = m.sqrt(x2_y2 + z**2) # r 37 | elev = m.atan2(z, m.sqrt(x2_y2)) # Elevation 38 | az = m.atan2(y, x) # Azimuth 39 | return r, elev, az 40 | 41 | 42 | def pol2cart(theta, rho): 43 | """ 44 | Transform polar coordinates to Cartesian 45 | :param theta: angle value 46 | :param rho: radius value 47 | :return: X, Y 48 | """ 49 | return rho * m.cos(theta), rho * m.sin(theta) 50 | 51 | def gen_images(locs, features, n_gridpoints, normalize=True, 52 | augment=False, pca=False, std_mult=0.1, n_components=2, edgeless=False): 53 | """ 54 | Generates EEG images given electrode locations in 2D space and multiple feature values for each electrode 55 | 56 | :param locs: An array with shape [n_electrodes, 2] containing X, Y 57 | coordinates for each electrode. 58 | :param features: Feature matrix as [n_samples, n_features] i.e. [2670,192] 59 | Features are as columns. 60 | Features corresponding to each frequency band are concatenated. 61 | (alpha1, alpha2, ..., beta1, beta2,...) 62 | :param n_gridpoints: Number of pixels in the output images 63 | :param normalize: Flag for whether to normalize each band over all samples 64 | :param augment: Flag for generating augmented images 65 | :param pca: Flag for PCA based data augmentation 66 | :param std_mult Multiplier for std of added noise 67 | :param n_components: Number of components in PCA to retain for augmentation 68 | :param edgeless: If True generates edgeless images by adding artificial channels 69 | at four corners of the image with value = 0 (default=False). 70 | :return: Tensor of size [samples, colors, W, H] containing generated 71 | images. 
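    Example (added; shapes are illustrative only, assuming 64 electrodes and 5 frequency bands):
        >>> images = gen_images(locs_2d, feats, n_gridpoints=32)
        >>> # locs_2d: [64, 2], feats: [n_samples, 5 * 64] -> images: [n_samples, 5, 32, 32]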
72 | """ 73 | feat_array_temp = [] 74 | nElectrodes = locs.shape[0] # Number of electrodes 75 | 76 | # Test whether the feature vector length is divisible by number of electrodes 77 | assert features.shape[1] % nElectrodes == 0 # 192%64==3 78 | n_colors = int(features.shape[1] / nElectrodes) 79 | for c in range(n_colors): 80 | feat_array_temp.append(features[:, c * nElectrodes : nElectrodes * (c+1)]) # [2670,64] 81 | if augment: 82 | if pca: 83 | for c in range(n_colors): 84 | feat_array_temp[c] = augment_EEG(feat_array_temp[c], std_mult, pca=True, n_components=n_components) 85 | else: 86 | for c in range(n_colors): 87 | feat_array_temp[c] = augment_EEG(feat_array_temp[c], std_mult, pca=False, n_components=n_components) 88 | n_samples = features.shape[0] # 2670 89 | 90 | # Interpolate the values 91 | grid_x, grid_y = np.mgrid[ 92 | min(locs[:, 0]):max(locs[:, 0]):n_gridpoints*1j, 93 | min(locs[:, 1]):max(locs[:, 1]):n_gridpoints*1j 94 | ] 95 | temp_interp = [] 96 | for c in range(n_colors): 97 | temp_interp.append(np.zeros([n_samples, n_gridpoints, n_gridpoints])) 98 | 99 | # Generate edgeless images 100 | if edgeless: 101 | min_x, min_y = np.min(locs, axis=0) 102 | max_x, max_y = np.max(locs, axis=0) 103 | locs = np.append(locs, np.array([[min_x, min_y], [min_x, max_y], [max_x, min_y], [max_x, max_y]]), axis=0) 104 | for c in range(n_colors): 105 | feat_array_temp[c] = np.append(feat_array_temp[c], np.zeros((n_samples, 4)), axis=1) 106 | 107 | # Interpolating 108 | for i in range(n_samples): 109 | for c in range(n_colors): 110 | temp_interp[c][i, :, :] = griddata(locs, feat_array_temp[c][i, :], (grid_x, grid_y), 111 | method='cubic', fill_value=np.nan) 112 | print('Interpolating {0}/{1}\r'.format(i + 1, n_samples), end='\r') 113 | 114 | # Normalizing 115 | for c in range(n_colors): 116 | if normalize: 117 | temp_interp[c][~np.isnan(temp_interp[c])] = \ 118 | scale(temp_interp[c][~np.isnan(temp_interp[c])]) 119 | temp_interp[c] = np.nan_to_num(temp_interp[c]) 120 | return np.swapaxes(np.asarray(temp_interp), 0, 1) # swap axes to have [samples, colors, W, H] 121 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import argparse 4 | from matplotlib.pyplot import axis 5 | 6 | import torch 7 | import torch.optim as optim 8 | import torch.optim.lr_scheduler as lr_scheduler 9 | from torch.utils.tensorboard import SummaryWriter 10 | from torch.utils.data import DataLoader,random_split 11 | import scipy.io as sio 12 | 13 | from meet.models.vit import meet_small_patch8 as create_model 14 | from utils import train_one_epoch, evaluate 15 | from eegUtils import * 16 | 17 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 18 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 19 | 20 | seed = 123 21 | np.random.seed(seed) 22 | torch.manual_seed(seed) 23 | torch.cuda.manual_seed(seed) 24 | torch.cuda.manual_seed_all(seed) 25 | 26 | def main(args): 27 | device = torch.device(args.device if torch.cuda.is_available() else "cpu") 28 | 29 | if os.path.exists("./weights/SEED/" + args.task) is False: 30 | os.makedirs("./weights/SEED/" + args.task) 31 | if os.path.exists("./Results/SEED/") is False: 32 | os.makedirs("./Results/SEED/") 33 | 34 | tb_writer = SummaryWriter("./Summary/SEED/" + args.task + "/") 35 | 36 | # Load data 37 | for sbj in range(1,2): 38 | raw_data = sio.loadmat("./dataset/SEED/eeg_convert_to_image/S" + str(sbj) + "_32@32.mat") # Change to your 
dataset path 39 | sbj_data = raw_data["img"] 40 | Label = (raw_data['label']).astype(int) 41 | if sbj == 1: 42 | All_data = sbj_data 43 | else: 44 | All_data = np.concatenate((All_data, sbj_data), axis=0) 45 | 46 | Label = np.tile(Label, (1,1)) 47 | EEG_Images = np.transpose(All_data, (0,2,1,3,4)) 48 | Label = np.reshape(Label, (-1,1))[:,0] 49 | 50 | EEG = EEGImagesDataset(label=Label, image=EEG_Images) 51 | lengths = [int(len(EEG)*0.8), int(len(EEG)*0.2)] 52 | if sum(lengths) != len(EEG): 53 | lengths[0] = lengths[0] + 1 54 | Train, Test = random_split(EEG, lengths) 55 | batch_size = args.batch_size 56 | Trainloader = DataLoader(Train, batch_size=batch_size,shuffle=True) 57 | Testloader = DataLoader(Test, batch_size=batch_size,shuffle=True) 58 | 59 | nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers 60 | print('Using {} dataloader workers every process'.format(nw)) 61 | 62 | model = create_model(num_classes=3).to(device) 63 | 64 | if args.weights != "": 65 | assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights) 66 | weights_dict = torch.load(args.weights, map_location=device) 67 | # Delete unnecessary weights 68 | del_keys = ['head.weight', 'head.bias'] if False \ 69 | else ['pre_logits.fc.weight', 'pre_logits.fc.bias', 'head.weight', 'head.bias'] 70 | for k in del_keys: 71 | del weights_dict[k] 72 | print(model.load_state_dict(weights_dict, strict=False)) 73 | 74 | if args.freeze_layers: 75 | for name, para in model.named_parameters(): 76 | if "head" not in name and "pre_logits" not in name: 77 | para.requires_grad_(False) 78 | else: 79 | print("training {}".format(name)) 80 | 81 | pg = [p for p in model.parameters() if p.requires_grad] 82 | optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=5E-5) 83 | # Scheduler https://arxiv.org/pdf/1812.01187.pdf 84 | lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf # cosine 85 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) 86 | 87 | best_acc = 0. 
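    # (Added note, not in the original) The LambdaLR scheduler above scales args.lr by
    #   ((1 + cos(pi * epoch / args.epochs)) / 2) * (1 - args.lrf) + args.lrf,
    # i.e. a cosine decay from args.lr down to args.lr * args.lrf. With the defaults
    # (lr=0.001, lrf=0.01) the learning rate goes from 1e-3 at epoch 0 to 1e-5 at the last epoch.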
88 | acc_log = [] 89 | for epoch in range(args.epochs): 90 | # train 91 | train_loss, train_acc = train_one_epoch(model=model, 92 | optimizer=optimizer, 93 | data_loader=Trainloader, 94 | device=device, 95 | epoch=epoch) 96 | 97 | scheduler.step() 98 | 99 | # validate 100 | val_loss, val_acc = evaluate(model=model, 101 | data_loader=Testloader, 102 | device=device, 103 | epoch=epoch) 104 | 105 | tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"] 106 | tb_writer.add_scalar(tags[0], train_loss, epoch) 107 | tb_writer.add_scalar(tags[1], train_acc, epoch) 108 | tb_writer.add_scalar(tags[2], val_loss, epoch) 109 | tb_writer.add_scalar(tags[3], val_acc, epoch) 110 | tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch) 111 | acc_log.append([train_loss, train_acc, val_loss, val_acc, optimizer.param_groups[0]["lr"]]) 112 | 113 | # Save the weights 114 | if best_acc < val_acc: 115 | torch.save(model.state_dict(), "./weights/SEED/" + args.task + "/best_model.pth") 116 | best_acc = val_acc 117 | 118 | # Save the results 119 | result_file = os.path.join("./Results/SEED/", 'Result_%s.mat'%args.task) 120 | Acc_Log = np.array(acc_log) 121 | sio.savemat(result_file, {'Acclog': Acc_Log.astype(np.double)}) 122 | 123 | print(best_acc) 124 | 125 | 126 | if __name__ == '__main__': 127 | parser = argparse.ArgumentParser() 128 | parser.add_argument('--epochs', type=int, default=1000) 129 | parser.add_argument('--batch-size', type=int, default=8) 130 | parser.add_argument('--lr', type=float, default=0.001) 131 | parser.add_argument('--lrf', type=float, default=0.01) 132 | 133 | # Task type 134 | parser.add_argument('--task', default="SEED_S1") 135 | 136 | # The path of pretrained weight, set to null if you don't want to load it 137 | parser.add_argument('--weights', type=str, default='./weights/SEED/SEED_S1/best_model.pth', 138 | help='initial weights path') 139 | 140 | # Freeze weight or not 141 | parser.add_argument('--freeze-layers', type=bool, default=False) 142 | parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)') 143 | opt = parser.parse_args() 144 | 145 | main(opt) -------------------------------------------------------------------------------- /meet/models/vit_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Ross Wightman 2 | # Various utility functions 3 | 4 | import torch 5 | import torch.nn as nn 6 | from functools import partial 7 | import math 8 | import warnings 9 | import torch.nn.functional as F 10 | 11 | from itertools import repeat 12 | from torch._six import container_abcs 13 | 14 | DEFAULT_CROP_PCT = 0.875 15 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) 16 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) 17 | IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) 18 | IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) 19 | IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255) 20 | IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3) 21 | 22 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 23 | def norm_cdf(x): 24 | # Computes standard normal cumulative distribution function 25 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 26 | 27 | if (mean < a - 2 * std) or (mean > b + 2 * std): 28 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. 
" 29 | "The distribution of values may be incorrect.", 30 | stacklevel=2) 31 | 32 | with torch.no_grad(): 33 | # Values are generated by using a truncated uniform distribution and 34 | # then using the inverse CDF for the normal distribution. 35 | # Get upper and lower cdf values 36 | l = norm_cdf((a - mean) / std) 37 | u = norm_cdf((b - mean) / std) 38 | 39 | # Uniformly fill tensor with values from [l, u], then translate to 40 | # [2l-1, 2u-1]. 41 | tensor.uniform_(2 * l - 1, 2 * u - 1) 42 | 43 | # Use inverse cdf transform for normal distribution to get truncated 44 | # standard normal 45 | tensor.erfinv_() 46 | 47 | # Transform to proper mean, std 48 | tensor.mul_(std * math.sqrt(2.)) 49 | tensor.add_(mean) 50 | 51 | # Clamp to ensure it's in the proper range 52 | tensor.clamp_(min=a, max=b) 53 | return tensor 54 | 55 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 56 | # type: (Tensor, float, float, float, float) -> Tensor 57 | r"""Fills the input Tensor with values drawn from a truncated 58 | normal distribution. The values are effectively drawn from the 59 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 60 | with values outside :math:`[a, b]` redrawn until they are within 61 | the bounds. The method used for generating the random values works 62 | best when :math:`a \leq \text{mean} \leq b`. 63 | Args: 64 | tensor: an n-dimensional `torch.Tensor` 65 | mean: the mean of the normal distribution 66 | std: the standard deviation of the normal distribution 67 | a: the minimum cutoff value 68 | b: the maximum cutoff value 69 | Examples: 70 | >>> w = torch.empty(3, 5) 71 | >>> nn.init.trunc_normal_(w) 72 | """ 73 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 74 | 75 | # From PyTorch internals 76 | def _ntuple(n): 77 | def parse(x): 78 | if isinstance(x, container_abcs.Iterable): 79 | return x 80 | return tuple(repeat(x, n)) 81 | return parse 82 | to_2tuple = _ntuple(2) 83 | 84 | # Calculate symmetric padding for a convolution 85 | def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int: 86 | padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 87 | return padding 88 | 89 | def get_padding_value(padding, kernel_size, **kwargs): 90 | dynamic = False 91 | if isinstance(padding, str): 92 | # for any string padding, the padding will be calculated for you, one of three ways 93 | padding = padding.lower() 94 | if padding == 'same': 95 | # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact 96 | if is_static_pad(kernel_size, **kwargs): 97 | # static case, no extra overhead 98 | padding = get_padding(kernel_size, **kwargs) 99 | else: 100 | # dynamic 'SAME' padding, has runtime/GPU memory overhead 101 | padding = 0 102 | dynamic = True 103 | elif padding == 'valid': 104 | # 'VALID' padding, same as padding=0 105 | padding = 0 106 | else: 107 | # Default to PyTorch style 'same'-ish symmetric padding 108 | padding = get_padding(kernel_size, **kwargs) 109 | return padding, dynamic 110 | 111 | # Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution 112 | def get_same_padding(x: int, k: int, s: int, d: int): 113 | return max((int(math.ceil(x // s)) - 1) * s + (k - 1) * d + 1 - x, 0) 114 | 115 | 116 | # Can SAME padding for given args be done statically? 
117 | def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_): 118 | return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 119 | 120 | 121 | # Dynamically pad input x with 'SAME' padding for conv with specified args 122 | #def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0): 123 | def pad_same(x, k, s, d=(1, 1), value= 0): 124 | ih, iw = x.size()[-2:] 125 | pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1]) 126 | if pad_h > 0 or pad_w > 0: 127 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value) 128 | return x 129 | 130 | def adaptive_pool_feat_mult(pool_type='avg'): 131 | if pool_type == 'catavgmax': 132 | return 2 133 | else: 134 | return 1 135 | 136 | def drop_path(x, drop_prob: float = 0., training: bool = False): 137 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 138 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 139 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 140 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 141 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 142 | 'survival rate' as the argument. 143 | """ 144 | if drop_prob == 0. or not training: 145 | return x 146 | keep_prob = 1 - drop_prob 147 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 148 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 149 | random_tensor.floor_() # binarize 150 | output = x.div(keep_prob) * random_tensor 151 | return output 152 | 153 | class DropPath(nn.Module): 154 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 155 | """ 156 | def __init__(self, drop_prob=None): 157 | super(DropPath, self).__init__() 158 | self.drop_prob = drop_prob 159 | 160 | def forward(self, x): 161 | return drop_path(x, self.drop_prob, self.training) 162 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /meet/models/vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | import math 5 | import warnings 6 | import torch.nn.functional as F 7 | import numpy as np 8 | 9 | from meet.models.vit_utils import DropPath, to_2tuple, trunc_normal_ 10 | 11 | from .build import MODEL_REGISTRY 12 | from torch import einsum 13 | from einops import rearrange, reduce, repeat 14 | 15 | from .bab import BAB 16 | 17 | 18 | class Mlp(nn.Module): 19 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 20 | super().__init__() 21 | out_features = out_features or in_features 22 | hidden_features = hidden_features or in_features 23 | self.fc1 = nn.Linear(in_features, hidden_features) 24 | self.act = act_layer() 25 | self.fc2 = nn.Linear(hidden_features, out_features) 26 | self.drop = nn.Dropout(drop) 27 | 28 | def forward(self, x): 29 | x = self.fc1(x) 30 | x = self.act(x) 31 | x = self.drop(x) 32 | x = self.fc2(x) 33 | x = self.drop(x) 34 | return x 35 | 36 | class Attention(nn.Module): 37 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., with_qkv=True): 38 | super().__init__() 39 | self.num_heads = num_heads 40 | head_dim = dim // num_heads 41 | self.scale = qk_scale or head_dim ** -0.5 42 | self.with_qkv = with_qkv 43 | if self.with_qkv: 44 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 45 | self.proj = nn.Linear(dim, dim) 46 | self.proj_drop = nn.Dropout(proj_drop) 47 | self.attn_drop = nn.Dropout(attn_drop) 48 | 49 | def forward(self, x): 50 | B, N, C = x.shape 51 | if self.with_qkv: 52 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 53 | q, k, v = qkv[0], qkv[1], qkv[2] 54 | else: 55 | qkv = x.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 56 | q, k, v = qkv, qkv, qkv 57 | 58 | attn = (q @ k.transpose(-2, -1)) * self.scale 59 | attn = attn.softmax(dim=-1) 60 | attn = self.attn_drop(attn) 61 | 62 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 63 | if self.with_qkv: 64 | x = self.proj(x) 65 | x = self.proj_drop(x) 66 | return x 67 | 68 | class Block(nn.Module): 69 | 70 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 71 | drop_path=0.1, act_layer=nn.GELU, norm_layer=nn.LayerNorm, attention_type='divided_space_time'): 72 | super().__init__() 73 | self.attention_type = attention_type 74 | assert(attention_type in ['divided_space_time', 'space_only','joint_space_time']) 75 | 76 | self.norm1 = norm_layer(dim) 77 | self.attn = Attention( 78 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 79 | 80 | ## Temporal Attention Parameters 81 | if self.attention_type == 'divided_space_time': 82 | self.temporal_norm1 = norm_layer(dim) 83 | self.temporal_attn = Attention( 84 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 85 | self.temporal_fc = nn.Linear(dim, dim) 86 | 87 | ## drop path 88 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 89 | self.norm2 = norm_layer(dim) 90 | mlp_hidden_dim = int(dim * mlp_ratio) 91 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 92 | 93 | 94 | def forward(self, x, B, T, W): 95 | num_spatial_tokens = (x.size(1) - 1) // T 96 | H = num_spatial_tokens // W 97 | 98 | if self.attention_type in ['space_only', 'joint_space_time']: 99 | x = x + self.drop_path(self.attn(self.norm1(x))) 100 | x = x + self.drop_path(self.mlp(self.norm2(x))) 101 | return x 102 | elif self.attention_type == 'divided_space_time': 103 | ## Temporal 104 | xt = x[:,1:,:] 105 | xt = rearrange(xt, 'b (h w t) m -> (b h w) t m',b=B,h=H,w=W,t=T) 106 | res_temporal = self.drop_path(self.temporal_attn(self.temporal_norm1(xt))) 107 | res_temporal = rearrange(res_temporal, '(b h w) t m -> b (h w t) m',b=B,h=H,w=W,t=T) 108 | res_temporal = self.temporal_fc(res_temporal) 109 | xt = x[:,1:,:] + res_temporal 110 | 111 | ## Spatial 112 | init_cls_token = x[:,0,:].unsqueeze(1) 113 | cls_token = init_cls_token.repeat(1, T, 1) 114 | cls_token = rearrange(cls_token, 'b t m -> (b t) m',b=B,t=T).unsqueeze(1) 115 | xs = xt 116 | xs = rearrange(xs, 'b (h w t) m -> (b t) (h w) m',b=B,h=H,w=W,t=T) 117 | xs = torch.cat((cls_token, xs), 1) 118 | res_spatial = self.drop_path(self.attn(self.norm1(xs))) 119 | 120 | ### Taking care of CLS token 121 | cls_token = res_spatial[:,0,:] 122 | cls_token = rearrange(cls_token, '(b t) m -> b t m',b=B,t=T) 123 | cls_token = torch.mean(cls_token,1,True) ## averaging for every frame 124 | res_spatial = res_spatial[:,1:,:] 125 | res_spatial = rearrange(res_spatial, '(b t) (h w) m -> b (h w t) m',b=B,h=H,w=W,t=T) 126 | res = res_spatial 127 | x = xt 128 | 129 | ## Mlp 130 | x = torch.cat((init_cls_token, x), 1) + torch.cat((cls_token, res), 1) 131 | x = x + self.drop_path(self.mlp(self.norm2(x))) 132 | return x 133 | 134 | class PatchEmbed(nn.Module): 135 | """ Image to Patch Embedding 136 | """ 137 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 138 | super().__init__() 139 | img_size = to_2tuple(img_size) 140 | patch_size = to_2tuple(patch_size) 141 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 142 | self.img_size = img_size 143 | self.patch_size = patch_size 144 | self.num_patches = num_patches 145 | 146 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 147 | 148 | def forward(self, x): 149 | B, C, T, H, W = x.shape 150 | x = rearrange(x, 'b c t h w -> (b t) c h w') 151 | x = self.proj(x) 152 | W = x.size(-1) 153 | x = x.flatten(2).transpose(1, 2) 154 | return x, T, W 155 | 156 | 157 | class VisionTransformer(nn.Module): 158 | """ Vision Transformere 159 | """ 160 | def __init__(self, img_size=224, patch_size=16, in_chans=5, num_classes=1000, embed_dim=768, depth=12, 161 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 162 | drop_path_rate=0.1, hybrid_backbone=None, norm_layer=nn.LayerNorm, num_frames=8, attention_type='divided_space_time', dropout=0.): 163 | super().__init__() 164 | self.attention_type = attention_type 165 | self.depth = depth 166 | self.dropout = nn.Dropout(dropout) 167 | self.num_classes = num_classes 168 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 169 | self.patch_embed = PatchEmbed( 170 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 171 | num_patches = self.patch_embed.num_patches 172 | 173 
| ## Positional Embeddings 174 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 175 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches+1, embed_dim)) 176 | self.pos_drop = nn.Dropout(p=drop_rate) 177 | if self.attention_type != 'space_only': 178 | self.time_embed = nn.Parameter(torch.zeros(1, num_frames, embed_dim)) 179 | self.time_drop = nn.Dropout(p=drop_rate) 180 | 181 | ## Attention Blocks 182 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.depth)] # stochastic depth decay rule 183 | self.blocks = nn.ModuleList([ 184 | Block( 185 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 186 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, attention_type=self.attention_type) 187 | for i in range(self.depth)]) 188 | self.norm = norm_layer(embed_dim) 189 | 190 | # Classifier head 191 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 192 | 193 | trunc_normal_(self.pos_embed, std=.02) 194 | trunc_normal_(self.cls_token, std=.02) 195 | self.apply(self._init_weights) 196 | 197 | ## initialization of temporal attention weights 198 | if self.attention_type == 'divided_space_time': 199 | i = 0 200 | for m in self.blocks.modules(): 201 | m_str = str(m) 202 | if 'Block' in m_str: 203 | if i > 0: 204 | nn.init.constant_(m.temporal_fc.weight, 0) 205 | nn.init.constant_(m.temporal_fc.bias, 0) 206 | i += 1 207 | 208 | def _init_weights(self, m): 209 | if isinstance(m, nn.Linear): 210 | trunc_normal_(m.weight, std=.02) 211 | if isinstance(m, nn.Linear) and m.bias is not None: 212 | nn.init.constant_(m.bias, 0) 213 | elif isinstance(m, nn.LayerNorm): 214 | nn.init.constant_(m.bias, 0) 215 | nn.init.constant_(m.weight, 1.0) 216 | 217 | @torch.jit.ignore 218 | def no_weight_decay(self): 219 | return {'pos_embed', 'cls_token', 'time_embed'} 220 | 221 | def get_classifier(self): 222 | return self.head 223 | 224 | def reset_classifier(self, num_classes, global_pool=''): 225 | self.num_classes = num_classes 226 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 227 | 228 | def forward_features(self, x): 229 | B = x.shape[0] 230 | x, T, W = self.patch_embed(x) 231 | cls_tokens = self.cls_token.expand(x.size(0), -1, -1) 232 | x = torch.cat((cls_tokens, x), dim=1) 233 | 234 | ## resizing the positional embeddings in case they don't match the input at inference 235 | if x.size(1) != self.pos_embed.size(1): 236 | pos_embed = self.pos_embed 237 | cls_pos_embed = pos_embed[0,0,:].unsqueeze(0).unsqueeze(1) 238 | other_pos_embed = pos_embed[0,1:,:].unsqueeze(0).transpose(1, 2) 239 | P = int(other_pos_embed.size(2) ** 0.5) 240 | H = x.size(1) // W 241 | other_pos_embed = other_pos_embed.reshape(1, x.size(2), P, P) 242 | new_pos_embed = F.interpolate(other_pos_embed, size=(H, W), mode='nearest') 243 | new_pos_embed = new_pos_embed.flatten(2) 244 | new_pos_embed = new_pos_embed.transpose(1, 2) 245 | new_pos_embed = torch.cat((cls_pos_embed, new_pos_embed), 1) 246 | x = x + new_pos_embed 247 | else: 248 | x = x + self.pos_embed 249 | x = self.pos_drop(x) 250 | 251 | 252 | ## Time Embeddings 253 | if self.attention_type != 'space_only': 254 | cls_tokens = x[:B, 0, :].unsqueeze(1) 255 | x = x[:,1:] 256 | x = rearrange(x, '(b t) n m -> (b n) t m',b=B,t=T) 257 | ## Resizing time embeddings in case they don't match 258 | if T != self.time_embed.size(1): 259 | time_embed = self.time_embed.transpose(1, 2) 260 | new_time_embed = 
F.interpolate(time_embed, size=(T), mode='nearest') 261 | new_time_embed = new_time_embed.transpose(1, 2) 262 | x = x + new_time_embed 263 | else: 264 | x = x + self.time_embed 265 | x = self.time_drop(x) 266 | x = rearrange(x, '(b n) t m -> b (n t) m',b=B,t=T) 267 | x = torch.cat((cls_tokens, x), dim=1) 268 | 269 | ## Attention blocks 270 | for blk in self.blocks: 271 | x = blk(x, B, T, W) 272 | 273 | ### Predictions for space-only baseline 274 | if self.attention_type == 'space_only': 275 | x = rearrange(x, '(b t) n m -> b t n m',b=B,t=T) 276 | x = torch.mean(x, 1) # averaging predictions for every frame 277 | 278 | x = self.norm(x) 279 | return x[:, 0] 280 | 281 | def forward(self, x): 282 | x = self.forward_features(x) 283 | x = self.head(x) 284 | return x 285 | 286 | def _conv_filter(state_dict, patch_size=16): 287 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 288 | out_dict = {} 289 | for k, v in state_dict.items(): 290 | if 'patch_embed.proj.weight' in k: 291 | if v.shape[-1] != patch_size: 292 | patch_size = v.shape[-1] 293 | v = v.reshape((v.shape[0], 3, patch_size, patch_size)) 294 | out_dict[k] = v 295 | return out_dict 296 | 297 | @MODEL_REGISTRY.register() 298 | class MEET(nn.Module): 299 | def __init__(self, img_size=32, patch_size=16, embed_dim=768, 300 | depth=12, num_heads=12, num_classes=4, num_frames=6, 301 | attention_type='divided_space_time', pretrained_model='', **kwargs): 302 | super(MEET, self).__init__() 303 | self.model = VisionTransformer(img_size=img_size, num_classes=num_classes, patch_size=patch_size, 304 | embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=4, 305 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), 306 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 307 | num_frames=num_frames, attention_type=attention_type, **kwargs) 308 | self.bab = BAB(5, 5) 309 | 310 | def forward(self, x): 311 | # x: [8,5,6,32,32] 312 | B, C, T, h, w = x.shape 313 | # BAB 314 | x = torch.reshape(x, (B*T,C,h,w)) 315 | x = self.bab(x) 316 | x = torch.reshape(x, (B,C,T,h,w)) 317 | x = self.model(x) 318 | return x 319 | 320 | def meet_small_patch8(num_classes: int = 3): 321 | """ 322 | MEET-Small model from original paper (). 323 | """ 324 | model = MEET(img_size=32, 325 | patch_size=8, 326 | embed_dim=768, 327 | depth=3, 328 | num_heads=3, 329 | num_classes=num_classes, 330 | num_frames=1, 331 | attention_type='divided_space_time', 332 | pretrained_model='') 333 | return model 334 | 335 | def meet_base_patch16(num_classes: int = 3): 336 | """ 337 | MEET-Base model from original paper (). 338 | """ 339 | model = MEET(img_size=32, 340 | patch_size=16, 341 | embed_dim=768, 342 | depth=6, 343 | num_heads=12, 344 | num_classes=num_classes, 345 | num_frames=1, 346 | attention_type='divided_space_time', 347 | pretrained_model='') 348 | return model 349 | 350 | def meet_large_patch16(num_classes: int = 3): 351 | """ 352 | MEET-Large model from original paper (). 353 | """ 354 | model = MEET(img_size=32, 355 | patch_size=16, 356 | embed_dim=1024, 357 | depth=12, 358 | num_heads=16, 359 | num_classes=num_classes, 360 | num_frames=1, 361 | attention_type='divided_space_time', 362 | pretrained_model='') 363 | return model --------------------------------------------------------------------------------