├── WeightFreezing_中文AI翻译版(未校对).pdf ├── compared_models.py ├── README.md └── modelsWithWeightFreezing.py /WeightFreezing_中文AI翻译版(未校对).pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiaoZhengQing/WeightFreezing/HEAD/WeightFreezing_中文AI翻译版(未校对).pdf -------------------------------------------------------------------------------- /compared_models.py: -------------------------------------------------------------------------------- 1 | from torchsummary import summary 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def weights_init(m): 7 | if isinstance(m, nn.Conv2d): 8 | nn.init.xavier_uniform_(m.weight) 9 | # nn.init.constant(m.bias, 0) # bias may be none 10 | 11 | elif isinstance(m, nn.BatchNorm2d): 12 | nn.init.constant_(m.weight, 1) 13 | nn.init.constant_(m.bias, 0) 14 | 15 | elif isinstance(m, nn.Linear): 16 | nn.init.xavier_uniform_(m.weight) 17 | nn.init.constant_(m.bias, 0) 18 | 19 | 20 | 21 | def square_activation(x): 22 | return torch.square(x) 23 | 24 | 25 | def safe_log(x): 26 | return torch.clip(torch.log(x), min=1e-7, max=1e7) 27 | 28 | 29 | class ShallowConvNet(nn.Module): 30 | def __init__(self, num_classes, chans, samples=1125): 31 | super(ShallowConvNet, self).__init__() 32 | self.conv_nums = 40 33 | self.features = nn.Sequential( 34 | nn.Conv2d(1, self.conv_nums, (1, 25)), 35 | nn.Conv2d(self.conv_nums, self.conv_nums, (chans, 1), bias=False), 36 | nn.BatchNorm2d(self.conv_nums) 37 | ) 38 | self.avgpool = nn.AvgPool2d(kernel_size=(1, 75), stride=(1, 15)) 39 | self.dropout = nn.Dropout() 40 | 41 | out = torch.ones((1, 1, chans, samples)) 42 | out = self.features(out) 43 | out = self.avgpool(out) 44 | n_out_time = out.cpu().data.numpy().shape 45 | self.classifier = nn.Linear(n_out_time[-1] * n_out_time[-2] * n_out_time[-3], num_classes) 46 | 47 | def forward(self, x): 48 | x = self.features(x) 49 | x = square_activation(x) 50 | x = self.avgpool(x) 51 | x = safe_log(x) 52 | x 
= self.dropout(x) 53 | 54 | features = torch.flatten(x, 1) 55 | cls = self.classifier(features) 56 | return cls 57 | 58 | 59 | class EEGNet(nn.Module): 60 | def __init__(self, num_classes, chans, samples=1125, dropout_rate=0.5, kernel_length=64, F1=8, 61 | F2=16,): 62 | super(EEGNet, self).__init__() 63 | 64 | self.features = nn.Sequential( 65 | nn.Conv2d(1, F1, kernel_size=(1, kernel_length), bias=False), 66 | nn.BatchNorm2d(F1), 67 | nn.Conv2d(F1, F1, kernel_size=(chans, 1), groups=F1, bias=False), # groups=F1 for depthWiseConv 68 | nn.BatchNorm2d(F1), 69 | nn.ELU(inplace=True), 70 | # nn.ReLU(), 71 | nn.AvgPool2d((1, 4)), 72 | nn.Dropout(dropout_rate), 73 | # for SeparableCon2D 74 | # SeparableConv2D(F1, F2, kernel1_size=(1, 16), bias=False), 75 | nn.Conv2d(F1, F1, kernel_size=(1, 16), groups=F1, bias=False), # groups=F1 for depthWiseConv 76 | nn.BatchNorm2d(F1), 77 | nn.ELU(inplace=True), 78 | # nn.ReLU(), 79 | nn.Conv2d(F1, F2, kernel_size=(1, 1), groups=1, bias=False), # point-wise cnn 80 | nn.BatchNorm2d(F2), 81 | # nn.ReLU(), 82 | nn.ELU(inplace=True), 83 | nn.AvgPool2d((1, 8)), 84 | nn.Dropout(p=dropout_rate), 85 | # nn.Dropout(p=0.5), 86 | ) 87 | out = torch.ones((1, 1, chans, samples)) 88 | out = self.features(out) 89 | n_out_time = out.cpu().data.numpy().shape 90 | self.classifier = nn.Linear(n_out_time[-1] * n_out_time[-2] * n_out_time[-3], num_classes) 91 | 92 | def forward(self, x): 93 | conv_features = self.features(x) 94 | features = torch.flatten(conv_features, 1) 95 | cls = self.classifier(features) 96 | return cls 97 | 98 | 99 | class LMDA(nn.Module): 100 | """ 101 | LMDA-Net for the paper 102 | """ 103 | def __init__(self, chans=22, samples=1125, num_classes=4, depth=9, kernel=75, channel_depth1=24, channel_depth2=9, 104 | ave_depth=1, avepool=5): 105 | super(LMDA, self).__init__() 106 | self.ave_depth = ave_depth 107 | self.channel_weight = nn.Parameter(torch.randn(depth, 1, chans), requires_grad=True) 108 | 
nn.init.xavier_uniform_(self.channel_weight.data) 109 | 110 | 111 | self.time_conv = nn.Sequential( 112 | nn.Conv2d(depth, channel_depth1, kernel_size=(1, 1), groups=1, bias=False), 113 | nn.BatchNorm2d(channel_depth1), 114 | nn.Conv2d(channel_depth1, channel_depth1, kernel_size=(1, kernel), 115 | groups=channel_depth1, bias=False), 116 | nn.BatchNorm2d(channel_depth1), 117 | nn.GELU(), 118 | ) 119 | # self.avgPool1 = nn.AvgPool2d((1, 24)) 120 | self.chanel_conv = nn.Sequential( 121 | nn.Conv2d(channel_depth1, channel_depth2, kernel_size=(1, 1), groups=1, bias=False), 122 | nn.BatchNorm2d(channel_depth2), 123 | nn.Conv2d(channel_depth2, channel_depth2, kernel_size=(chans, 1), groups=channel_depth2, bias=False), 124 | nn.BatchNorm2d(channel_depth2), 125 | nn.GELU(), 126 | ) 127 | 128 | self.norm = nn.Sequential( 129 | nn.AvgPool3d(kernel_size=(1, 1, avepool)), 130 | # nn.AdaptiveAvgPool3d((9, 1, 35)), 131 | nn.Dropout(p=0.65), 132 | ) 133 | 134 | # 定义自动填充模块 135 | out = torch.ones((1, 1, chans, samples)) 136 | out = torch.einsum('bdcw, hdc->bhcw', out, self.channel_weight) 137 | out = self.time_conv(out) 138 | out = self.chanel_conv(out) 139 | out = self.norm(out) 140 | n_out_time = out.cpu().data.numpy().shape 141 | print('In ShallowNet, n_out_time shape: ', n_out_time) 142 | self.classifier = nn.Linear(n_out_time[-1]*n_out_time[-2]*n_out_time[-3], num_classes) 143 | 144 | def EEGDepthAttention(self, x): 145 | # x: input features with shape [N, C, H, W] 146 | 147 | N, C, H, W = x.size() 148 | # K = W if W % 2 else W + 1 149 | k = 7 150 | adaptive_pool = nn.AdaptiveAvgPool2d((1, W)) 151 | conv = nn.Conv2d(1, 1, kernel_size=(k, 1), padding=(k//2, 0), bias=True).to(x.device) # original kernel k 152 | nn.init.xavier_uniform_(conv.weight) 153 | nn.init.constant_(conv.bias, 0) 154 | softmax = nn.Softmax(dim=-2) 155 | x_pool = adaptive_pool(x) 156 | x_transpose = x_pool.transpose(-2, -3) 157 | y = conv(x_transpose) 158 | y = softmax(y) 159 | y = y.transpose(-2, -3) 160 | 
return y * C * x 161 | 162 | def forward(self, x): 163 | x = torch.einsum('bdcw, hdc->bhcw', x, self.channel_weight) 164 | 165 | x_time = self.time_conv(x) # batch, depth1, channel, samples_ 166 | x_time = self.EEGDepthAttention(x_time) # DA1 167 | 168 | x = self.chanel_conv(x_time) # batch, depth2, 1, samples_ 169 | x = self.norm(x) 170 | 171 | features = torch.flatten(x, 1) 172 | cls = self.classifier(features) 173 | return cls 174 | 175 | 176 | if __name__ == '__main__': 177 | model = ShallowConvNet(num_classes=4, chans=22, samples=1125).cuda() 178 | a = torch.randn(12, 1, 3, 875).cuda().float() 179 | l2 = model(a) 180 | model_optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2) 181 | summary(model, show_input=True) 182 | 183 | print(l2.shape) 184 | 185 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WeightFreezing 2 | Submitted to Neural Networks, [Arxiv version](https://arxiv.org/pdf/2306.05775.pdf). 3 | 4 | # Description 5 | Source code for the paper: Weight-Freezing: A Regularization Approach for Fully Connected Layers with an Application in EEG Classification 6 | 7 | Due to my current inability to afford high open-access fees, the final version of this work may not be open-access. I apologize for any inconvenience caused. 8 | 9 | ## Important Statement 10 | This is my first work on designing a transformer specifically for EEG. The preliminary experiments for a related project have been conducted, but due to the pressing deadline of my doctoral thesis, I have had to temporarily pause this aspect of my work. After October 2023, I hope to continue researching this topic if given the opportunity (That's when I might visit the University of Vienna). 
11 | My doctoral research focuses on the deep integration of artificial neural networks and neuroscience, advocating for the use of lightweight artificial neural network technologies to enhance EEG decoding. My doctoral thesis and published papers may not have Chinese versions. However, a high-quality Chinese translation of the LMDA-Net paper is available on the WeChat official account (脑机接口社区), which interested readers can search for. 12 | 13 | In the future, I hope to collaborate with internationally renowned research groups to further explore the applications of lightweight artificial neural networks in BCI. My research strengths in this field lie in the deep integration of digital signal processing, machine learning, deep learning, and neuroscience systems. I possess strong problem-solving abilities and have a solid foundation in mathematics and programming. I am comfortable working in an all-English office environment and capable of independently completing research tasks in this field. Additionally, I have previous experience in research on autonomous driving platforms, which has provided me with knowledge in areas such as computer vision and circuitry. I also possess strong teamwork skills. 14 | If your research group is seeking to recruit a postdoctoral researcher in this field, I would greatly appreciate the opportunity for an interview (mzq@tju.edu.cn). 15 | 16 | # Requirements 17 | - Python >= 3.6 18 | - PyTorch >= 1.10 19 | - GPU is required. 20 | 21 | # Contributions 22 | - To the best of our knowledge, this paper is the first to study the impact of the classifier in ANNs on EEG decoding performance. For this purpose, Weight-Freezing is proposed, which suppresses the influence of some input neurons on certain decision results by freezing some parameters in the fully connected layer, thereby achieving higher classification accuracy.
23 | - Weight-Freezing is also a novel regularization method that produces sparse connections in the fully connected layer. 24 | - Weight-Freezing is thoroughly validated and analyzed on three classic decoding networks and three highly cited public EEG datasets. The experimental results confirm the superiority of Weight-Freezing in classification, and the networks equipped with Weight-Freezing achieve state-of-the-art classification performance (averaged across all participants) on all three datasets. 25 | 26 | The primary contribution of this study is that it facilitates the application and deployment of Artificial Neural Network (ANN) models within Brain-Computer Interface (BCI) systems. At the same time, it sets a new performance benchmark for future EEG decoding efforts that use larger models, such as transformers. 27 | Emerging research is increasingly adopting transformer networks for EEG signal decoding. These approaches can be viewed as enrichments of existing ANN models, as they raise EEG classification accuracy through more sophisticated feature extraction networks. However, these enhancements have inadvertently complicated the deployment of these ANN models in real-world BCI systems. 28 | In stark contrast, our study introduces Weight-Freezing as an innovative, subtractive strategy that refines existing ANN models. With Weight-Freezing, some lightweight and shallow decoding networks surpass all current transformer-based methods in classification performance on the same public datasets. 29 | The incorporation of Weight-Freezing not only simplifies the deployment of ANN models within BCI systems but also sets a new performance standard for the deployment of larger models, such as transformers, in the future. Moreover, it raises an intriguing question in EEG decoding: is the deployment of large models like transformers for EEG feature extraction truly indispensable?
30 | 31 | # Results 32 | ![33f3428681103234abb0acb07c6a6ca](https://github.com/MiaoZhengQing/WeightFreezing/assets/116713490/abb617bd-f3ae-418f-9dd5-5ffb24cbbb4f) 33 | ![6b598f8a5dfeff920c909b9f93f4a09](https://github.com/MiaoZhengQing/WeightFreezing/assets/116713490/5a86123d-852c-405d-b98b-539e039243a6) 34 | 35 | # Models Implemented 36 | - [LMDA-Net](https://doi.org/10.1016/j.neuroimage.2023.120209) 37 | - [EEGNet](https://github.com/vlawhern/arl-eegmodels) 38 | - [ShallowConvNet](https://github.com/TNTLFreiburg/braindecode) 39 | 40 | # Related work 41 | - This paper is a follow-up to [SDDA](https://arxiv.org/pdf/2202.09559.pdf) and [LMDA-Net](https://doi.org/10.1016/j.neuroimage.2023.120209); the preprocessing method is inherited from SDDA. 42 | 43 | 44 | # Paper Citation 45 | If you use this idea or code in a scientific publication, please cite us as: 46 | % Weight-Freezing 47 | Miao Z, Zhao M. Weight Freezing: A Regularization Approach for Fully Connected Layers with an Application in EEG Classification[J]. arXiv preprint arXiv:2306.05775, 2023. 48 | 49 | % LMDA-Net 50 | Miao Z, Zhao M, Zhang X, Ming D. LMDA-Net: A lightweight multi-dimensional attention network for general EEG-based brain-computer interfaces and interpretability[J]. NeuroImage, 2023, 276: 120209. 51 | 52 | % SDDA 53 | Miao Z, Zhang X, Menon C, et al. Priming Cross-Session Motor Imagery Classification with A Universal Deep Domain Adaptation Framework[J]. arXiv preprint arXiv:2202.09559, 2022.
54 | 55 | ``` 56 | % Weight-Freezing 57 | @article{miao2023weight, 58 | title={Weight Freezing: A Regularization Approach for Fully Connected Layers with an Application in EEG Classification}, 59 | author={Miao, Zhengqing and Zhao, Meirong}, 60 | journal={arXiv preprint arXiv:2306.05775}, 61 | year={2023}, 62 | doi={https://doi.org/10.48550/arXiv.2306.05775}, 63 | } 64 | 65 | % LMDA 66 | @article{miao2023lmda, 67 | title = {LMDA-Net: A lightweight multi-dimensional attention network for general EEG-based brain-computer interfaces and interpretability}, 68 | journal = {NeuroImage}, 69 | volume = {276}, 70 | pages = {120209}, 71 | year = {2023}, 72 | issn = {1053-8119}, 73 | doi = {https://doi.org/10.1016/j.neuroimage.2023.120209}, 74 | url = {https://www.sciencedirect.com/science/article/pii/S1053811923003609}, 75 | author = {Zhengqing Miao and Meirong Zhao and Xin Zhang and Dong Ming}, 76 | keywords = {Attention, Brain-computer interface (BCI), Electroencephalography (EEG), Model interpretability, Neural networks}, 77 | abstract = {Electroencephalography (EEG)-based brain-computer interfaces (BCIs) pose a challenge for decoding due to their low spatial resolution and signal-to-noise ratio. Typically, EEG-based recognition of activities and states involves the use of prior neuroscience knowledge to generate quantitative EEG features, which may limit BCI performance. Although neural network-based methods can effectively extract features, they often encounter issues such as poor generalization across datasets, high predicting volatility, and low model interpretability. To address these limitations, we propose a novel lightweight multi-dimensional attention network, called LMDA-Net.
By incorporating two novel attention modules designed specifically for EEG signals, the channel attention module and the depth attention module, LMDA-Net is able to effectively integrate features from multiple dimensions, resulting in improved classification performance across various BCI tasks. LMDA-Net was evaluated on four high-impact public datasets, including motor imagery (MI) and P300-Speller, and was compared with other representative models. The experimental results demonstrate that LMDA-Net outperforms other representative methods in terms of classification accuracy and predicting volatility, achieving the highest accuracy in all datasets within 300 training epochs. Ablation experiments further confirm the effectiveness of the channel attention module and the depth attention module. To facilitate an in-depth understanding of the features extracted by LMDA-Net, we propose class-specific neural network feature interpretability algorithms that are suitable for evoked responses and endogenous activities. By mapping the output of the specific layer of LMDA-Net to the time or spatial domain through class activation maps, the resulting feature visualizations can provide interpretable analysis and establish connections with EEG time-spatial analysis in neuroscience. 
In summary, LMDA-Net shows great potential as a general decoding model for various EEG tasks.} 78 | } 79 | 80 | % TSFF-Net 81 | @article{miao2023time, 82 | title={Time-space-frequency feature Fusion for 3-channel motor imagery classification}, 83 | author={Miao, Zhengqing and Zhao, Meirong}, 84 | journal={arXiv preprint arXiv:2304.01461}, 85 | year={2023}, 86 | doi={https://doi.org/10.48550/arXiv.2304.01461}, 87 | } 88 | 89 | % SDDA 90 | @article{miao2022priming, 91 | title={Priming Cross-Session Motor Imagery Classification with A Universal Deep Domain Adaptation Framework}, 92 | author={Miao, Zhengqing and Zhang, Xin and Menon, Carlo and Zheng, Yelong and Zhao, Meirong and Ming, Dong}, 93 | journal={arXiv preprint arXiv:2202.09559}, 94 | year={2022} 95 | } 96 | ``` 97 | 98 | # Contact 99 | Email: mzq@tju.edu.cn 100 | -------------------------------------------------------------------------------- /modelsWithWeightFreezing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from compared_models import square_activation, safe_log 4 | import math 5 | import torch.nn.functional as F 6 | 7 | 8 | class WeightFreezing(nn.Module): 9 | def __init__(self, input_dim, output_dim, shared_ratio=0.3, multiple=0): 10 | super(WeightFreezing, self).__init__() 11 | 12 | self.weight = nn.Parameter(torch.Tensor(output_dim, input_dim)) 13 | self.bias = nn.Parameter(torch.Tensor(output_dim)) 14 | 15 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 16 | if self.bias is not None: 17 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) 18 | bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 19 | nn.init.uniform_(self.bias, -bound, bound) 20 | 21 | mask = torch.rand(input_dim, output_dim) < shared_ratio 22 | self.register_buffer('shared_mask', mask) 23 | self.register_buffer('independent_mask', ~mask) 24 | 25 | self.multiple = multiple 26 | 27 | def forward(self, x, shared_weight): 28 | 
combined_weight = torch.where(self.shared_mask, shared_weight*self.multiple, self.weight.t()) 29 | output = F.linear(x, combined_weight.t(), self.bias) 30 | return output 31 | 32 | 33 | class ConvNetWeightFreezing(nn.Module): 34 | def __init__(self, num_classes=4, chans=22, samples=1125, shared_ratio=0.1): 35 | super(ConvNetWeightFreezing, self).__init__() 36 | self.conv_nums = 40 37 | self.num_classes = num_classes 38 | self.features = nn.Sequential( 39 | nn.Conv2d(1, self.conv_nums, (1, 25)), 40 | nn.Conv2d(self.conv_nums, self.conv_nums, (chans, 1), bias=False), 41 | nn.BatchNorm2d(self.conv_nums) 42 | ) 43 | self.avgpool = nn.AvgPool2d(kernel_size=(1, 75), stride=(1, 15)) 44 | self.dropout = nn.Dropout() 45 | 46 | out = torch.ones((1, 1, chans, samples)) 47 | out = self.features(out) 48 | out = self.avgpool(out) 49 | n_out_time = out.cpu().data.numpy().shape # [batch, self.conv_nums, 1, times] 50 | # share part weights 51 | shared_ratio = shared_ratio 52 | 53 | self.classifier = WeightFreezing(n_out_time[-1] * n_out_time[-2] * n_out_time[-3], num_classes, 54 | shared_ratio=shared_ratio) 55 | 56 | self.shared_weights = nn.Parameter(torch.Tensor(num_classes, n_out_time[-1] * n_out_time[-2] * n_out_time[-3]), 57 | requires_grad=False) 58 | self.bias = nn.Parameter(torch.Tensor(num_classes)) 59 | 60 | nn.init.kaiming_uniform_(self.shared_weights, a=math.sqrt(5)) 61 | if self.bias is not None: 62 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.shared_weights) 63 | bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 64 | nn.init.uniform_(self.bias, -bound, bound) 65 | 66 | self.fixed_weight = self.shared_weights.t() * self.classifier.shared_mask 67 | 68 | def forward(self, x): 69 | x = self.features(x) 70 | x = square_activation(x) 71 | x = self.avgpool(x) 72 | x = safe_log(x) 73 | x = self.dropout(x) 74 | # x: [batch, 40, 1, times] 75 | features = torch.flatten(x, 1) # 使用卷积网络代替全连接层进行分类, 因此需要返回x和卷积层个数 76 | 77 | cls = self.classifier(features, 
self.fixed_weight.to(features.device)) 78 | 79 | return cls 80 | 81 | 82 | class EEGNetWeightFreezing(nn.Module): 83 | def __init__(self, num_classes, chans, samples=1125, dropout_rate=0.5, kernel_length=64, F1=8, 84 | F2=16, shared_ratio=0.1): 85 | super(EEGNetWeightFreezing, self).__init__() 86 | self.num_classes = num_classes 87 | 88 | self.features = nn.Sequential( 89 | nn.Conv2d(1, F1, kernel_size=(1, kernel_length), bias=False), 90 | nn.BatchNorm2d(F1), 91 | nn.Conv2d(F1, F1, kernel_size=(chans, 1), groups=F1, bias=False), # groups=F1 for depthWiseConv 92 | nn.BatchNorm2d(F1), 93 | nn.ELU(inplace=True), 94 | # nn.ReLU(), 95 | nn.AvgPool2d((1, 4)), 96 | nn.Dropout(dropout_rate), 97 | # for SeparableCon2D 98 | # SeparableConv2D(F1, F2, kernel1_size=(1, 16), bias=False), 99 | nn.Conv2d(F1, F1, kernel_size=(1, 16), groups=F1, bias=False), # groups=F1 for depthWiseConv 100 | nn.BatchNorm2d(F1), 101 | nn.ELU(inplace=True), 102 | # nn.ReLU(), 103 | nn.Conv2d(F1, F2, kernel_size=(1, 1), groups=1, bias=False), # point-wise cnn 104 | nn.BatchNorm2d(F2), 105 | # nn.ReLU(), 106 | nn.ELU(inplace=True), 107 | nn.AvgPool2d((1, 8)), 108 | nn.Dropout(p=dropout_rate), 109 | ) 110 | out = torch.ones((1, 1, chans, samples)) 111 | out = self.features(out) 112 | n_out_time = out.cpu().data.numpy().shape 113 | # share part weights 114 | shared_ratio = shared_ratio 115 | 116 | self.classifier = WeightFreezing(n_out_time[-1] * n_out_time[-2] * n_out_time[-3], num_classes, 117 | shared_ratio=shared_ratio) 118 | 119 | self.shared_weights = nn.Parameter(torch.Tensor(num_classes, n_out_time[-1] * n_out_time[-2] * n_out_time[-3]), 120 | requires_grad=False) 121 | self.bias = nn.Parameter(torch.Tensor(num_classes)) 122 | 123 | nn.init.kaiming_uniform_(self.shared_weights, a=math.sqrt(5)) 124 | if self.bias is not None: 125 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.shared_weights) 126 | bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 127 | nn.init.uniform_(self.bias, 
-bound, bound) 128 | 129 | self.fixed_weight = self.shared_weights.t() * self.classifier.shared_mask 130 | 131 | def forward(self, x): 132 | x = self.features(x) 133 | features = torch.flatten(x, 1) 134 | cls = self.classifier(features, self.fixed_weight.to(features.device)) 135 | return cls 136 | 137 | 138 | class LMDAWeightFreezing(nn.Module): 139 | 140 | def __init__(self, chans=22, samples=1125, num_classes=4, depth=9, kernel=75, channel_depth1=40, channel_depth2=40, 141 | ave_depth=1, avepool=5, shared_ratio=0.1): 142 | super(LMDAWeightFreezing, self).__init__() 143 | self.ave_depth = ave_depth 144 | self.channel_weight = nn.Parameter(torch.randn(depth, 1, chans), requires_grad=True) 145 | nn.init.xavier_uniform_(self.channel_weight.data) 146 | 147 | self.num_classes = num_classes 148 | 149 | self.time_conv = nn.Sequential( 150 | nn.Conv2d(depth, channel_depth1, kernel_size=(1, 1), groups=1, bias=False), 151 | nn.BatchNorm2d(channel_depth1), 152 | nn.Conv2d(channel_depth1, channel_depth1, kernel_size=(1, kernel), 153 | groups=channel_depth1, bias=False), 154 | nn.BatchNorm2d(channel_depth1), 155 | nn.GELU(), 156 | ) 157 | # self.avgPool1 = nn.AvgPool2d((1, 24)) 158 | self.chanel_conv = nn.Sequential( 159 | nn.Conv2d(channel_depth1, channel_depth2, kernel_size=(1, 1), groups=1, bias=False), 160 | nn.BatchNorm2d(channel_depth2), 161 | nn.Conv2d(channel_depth2, channel_depth2, kernel_size=(chans, 1), groups=channel_depth2, bias=False), 162 | nn.BatchNorm2d(channel_depth2), 163 | nn.GELU(), 164 | ) 165 | 166 | self.norm = nn.Sequential( 167 | nn.AvgPool3d(kernel_size=(1, 1, avepool)), 168 | # nn.AdaptiveAvgPool3d((9, 1, 35)), 169 | nn.Dropout(p=0.65), 170 | ) 171 | 172 | # 定义自动填充模块 173 | out = torch.ones((1, 1, chans, samples)) 174 | out = torch.einsum('bdcw, hdc->bhcw', out, self.channel_weight) 175 | out = self.time_conv(out) 176 | # out = self.avgPool1(out) 177 | out = self.chanel_conv(out) 178 | out = self.norm(out) 179 | n_out_time = 
out.cpu().data.numpy().shape 180 | print('In ShallowNet, n_out_time shape: ', n_out_time) 181 | # share part weights 182 | shared_ratio = shared_ratio 183 | 184 | self.classifier = WeightFreezing(n_out_time[-1] * n_out_time[-2] * n_out_time[-3], num_classes, 185 | shared_ratio=shared_ratio) 186 | 187 | self.shared_weights = nn.Parameter(torch.Tensor(num_classes, n_out_time[-1] * n_out_time[-2] * n_out_time[-3]), 188 | requires_grad=False) 189 | self.bias = nn.Parameter(torch.Tensor(num_classes)) 190 | 191 | nn.init.kaiming_uniform_(self.shared_weights, a=math.sqrt(5)) 192 | if self.bias is not None: 193 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.shared_weights) 194 | bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 195 | nn.init.uniform_(self.bias, -bound, bound) 196 | 197 | self.fixed_weight = self.shared_weights.t() * self.classifier.shared_mask 198 | 199 | def EEGDepthAttention(self, x): 200 | # x: input features with shape [N, C, H, W] 201 | N, C, H, W = x.size() 202 | # K = W if W % 2 else W + 1 203 | k = 7 204 | adaptive_pool = nn.AdaptiveAvgPool2d((1, W)) 205 | conv = nn.Conv2d(1, 1, kernel_size=(k, 1), padding=(k//2, 0), bias=True).to(x.device) # original kernel k 206 | softmax = nn.Softmax(dim=-2) 207 | x_pool = adaptive_pool(x) 208 | x_transpose = x_pool.transpose(-2, -3) 209 | y = conv(x_transpose) 210 | y = softmax(y) 211 | y = y.transpose(-2, -3) 212 | return y * C * x 213 | 214 | def forward(self, x): 215 | x = torch.einsum('bdcw, hdc->bhcw', x, self.channel_weight) 216 | 217 | x_time = self.time_conv(x) # batch, depth1, channel, samples_ 218 | x_time = self.EEGDepthAttention(x_time) # DA1 219 | 220 | x = self.chanel_conv(x_time) # batch, depth2, 1, samples_ 221 | x = self.norm(x) 222 | 223 | features = torch.flatten(x, 1) 224 | 225 | cls = self.classifier(features, self.fixed_weight.to(features.device)) 226 | return cls 227 | 228 | 229 | 230 | if __name__ == '__main__': 231 | a = torch.randn(32, 1, 22, 1125) 232 | model = 
ConvNetWeightFreezing() 233 | print(model(a).shape) 234 | --------------------------------------------------------------------------------