├── .DS_Store ├── .idea ├── .gitignore ├── PPDA_PLUS.iml ├── inspectionProfiles │ └── Project_Default.xml ├── misc.xml ├── modules.xml └── vcs.xml ├── GradientReverseLayer.py ├── LICENSE ├── Readme.md ├── T-SNE ├── generatePlotByTSNE.py ├── model │ ├── 1_pretrain_model.pth │ └── 1_tune_model.pth └── plot │ ├── origin_label.jpg │ ├── origin_subject.jpg │ ├── pretrain_label.jpg │ ├── pretrain_subject.jpg │ ├── tune_label.jpg │ └── tune_subject.jpg ├── __pycache__ ├── GradientReverseLayer.cpython-36.pyc ├── model.cpython-36.pyc ├── preprocess.cpython-36.pyc ├── test.cpython-36.pyc └── train.cpython-36.pyc ├── ablation ├── withoutBothMixAndNoise.py ├── withoutNoise.py └── witoutMix.py ├── main.py ├── model.py ├── noiseInjectionMethods ├── Dropout.py ├── channelsShuffling.py ├── maskChannels.py └── maskTimeSteps.py ├── preprocess.py ├── requirements.txt ├── test.py └── train.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeBreathing/DMMR/5fd5e6ebf1a14d7afa1c07213a71aa97980b5b8b/.DS_Store -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/PPDA_PLUS.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 11 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 17 | 
-------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /GradientReverseLayer.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Function 2 | 3 | #gradient reversal 4 | class ReverseLayerF(Function): 5 | @staticmethod 6 | def forward(ctx, x, m): 7 | ctx.m = m 8 | return x.view_as(x) 9 | 10 | @staticmethod 11 | def backward(ctx, grad_output): 12 | output = grad_output.neg() * ctx.m 13 | return output, None -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | # DMMR 2 | This is the official PyTorch implementation for our AAAI'24 paper DMMR: Cross-Subject Domain Generalization for EEG-Based Emotion Recognition via Denoising Mixed Mutual Reconstruction 3 | [Paper link](https://ojs.aaai.org/index.php/AAAI/article/view/27819) 4 | 5 | ## Datasets 6 | The publicly available datasets (SEED and SEED-IV) can be downloaded from https://bcmi.sjtu.edu.cn/home/seed/index.html 7 | 8 | To facilitate data retrieval, the data from the first session of all subjects is used in both datasets. The file structure of the datasets should look like: 9 | ``` 10 | ExtractedFeatures/ 11 | 1/ 12 | eeg_feature_smooth/ 13 | 1/ 14 | ``` 15 | Kindly change the file path in main.py. 16 | 17 | ## Usage of DMMR 18 | Run `python main.py`, and the results will be recorded in TensorBoard. 19 | The `dataset_name` argument should be set to `seed3` for the SEED dataset and to `seed4` for the SEED-IV dataset. 
20 | 21 | ## Ablation Studies 22 | Run `python ablation/witoutMix.py` 23 | Run `python ablation/withoutNoise.py` 24 | Run `python ablation/withoutBothMixAndNoise.py` 25 | 26 | ## Other Noise Injection Methods 27 | Run `python noiseInjectionMethods/maskChannels.py` 28 | Run `python noiseInjectionMethods/maskTimeSteps.py` 29 | Run `python noiseInjectionMethods/channelsShuffling.py` 30 | Run `python noiseInjectionMethods/Dropout.py` 31 | 32 | ## Plot with t-SNE 33 | Run `python T-SNE/generatePlotByTSNE.py` 34 | 35 | ## Citation 36 | If you find our work useful for your research, please cite our work: 37 | ``` 38 | @inproceedings{wang2024dmmr, 39 | title={DMMR: Cross-Subject Domain Generalization for EEG-Based Emotion Recognition via Denoising Mixed Mutual Reconstruction}, 40 | author={Wang, Yiming and Zhang, Bin and Tang, Yujiao}, 41 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, 42 | volume={38}, 43 | number={1}, 44 | pages={628--636}, 45 | year={2024} 46 | } 47 | ``` 48 | We thank the following repositories for providing helpful functions used in our work: 49 | [MS-MDA](https://github.com/VoiceBeer/MS-MDA) 50 | [DANN](https://github.com/fungtion/DANN) 51 | 52 | 53 | -------------------------------------------------------------------------------- /T-SNE/generatePlotByTSNE.py: -------------------------------------------------------------------------------- 1 | from preprocess import getDataLoaders 2 | import math 3 | import argparse 4 | from train import * 5 | import random 6 | import os 7 | import numpy as np 8 | 9 | os.environ["CUDA_VISIBLE_DEVICES"] = '3' 10 | 11 | def set_seed(seed=3): 12 | random.seed(seed) 13 | np.random.seed(seed) 14 | os.environ['PYTHONHASHSEED'] = str(seed) 15 | torch.manual_seed(seed) 16 | torch.cuda.manual_seed(seed) 17 | torch.cuda.manual_seed_all(seed) 18 | torch.backends.cudnn.deterministic = True 19 | 20 | def main(data_loader_dict,args, cuda, seed=3): 21 | set_seed(seed) 22 | acc = 
TSNEForDMMR(data_loader_dict, cuda, args) 23 | return acc 24 | 25 | if __name__ == '__main__': 26 | cuda = torch.cuda.is_available() 27 | parser = argparse.ArgumentParser(description='DMMR') 28 | 29 | #config of experiment 30 | parser.add_argument("--way", type=str, default='TSNE', help="name of current way") 31 | parser.add_argument("--index", type=str, default='plot features', help="tensorboard index") 32 | parser.add_argument("--num_workers_train", type=int, default=0, help="classes of dataset") 33 | parser.add_argument("--num_workers_test", type=int, default=0, help="classes of dataset") 34 | parser.add_argument("--path", type=str, default="../eeg_data/ExtractedFeatures/", help="classes of dataset") 35 | 36 | #config of dataset 37 | parser.add_argument("--dataset_name", type=str, nargs='?', default='seed3', help="all subject numbers") 38 | parser.add_argument("--session", type=str, nargs='?', default='1', help="select session") 39 | parser.add_argument("--subjects", type=int, choices=[15], default=15, help="all subject numbers") 40 | parser.add_argument("--cls_classes", type=int, choices=[3], default=3, help="classes of dataset") 41 | parser.add_argument("--dim", type=int, default=310, help="dim of input") 42 | 43 | #config of net 44 | parser.add_argument("--input_dim", type=int, default=310, help="input dim is the same with sample's last dim") 45 | parser.add_argument("--hid_dim", type=int, default=64, help="hid dim is for hidden layer of lstm") 46 | parser.add_argument("--n_layers", type=int, default=2, help="num of layers of lstm") 47 | parser.add_argument("--batch_size", type=int, default=512, help="batch size") 48 | parser.add_argument("--time_steps", type=int, choices=[30], default=30, help="window size") 49 | parser.add_argument("--epoch_preTraining", type=int, default=300, help="epoch of baseModel") 50 | parser.add_argument("--epoch_fineTuning", type=int, default=500, help="epoch of baseModel4") 51 | parser.add_argument("--lr", type=int, default=1e-3, 
help="epoch of pretrain") 52 | parser.add_argument("--weight_decay", type=float, default=0.0005, help="weight decay") 53 | parser.add_argument("--beta", type=float, default=0.05, help="balancing hyperparameter in the loss of pretraining phase") 54 | 55 | args = parser.parse_args() 56 | args.source_subjects = args.subjects-1 57 | net_config = {"fts": args.dim, "cls": args.cls_classes} 58 | optim_config = {"lr": args.lr, "weight_decay": args.weight_decay} 59 | 60 | acc_list=[] 61 | #use the subject 1 as the target subject for testing 62 | one_subject = 1 63 | # 1.data preparation 64 | source_loaders, test_loader = getDataLoaders(one_subject, args) 65 | data_loader_dict = {"source_loader": source_loaders, "test_loader":test_loader} 66 | # 2. call main 67 | acc = main(data_loader_dict, args, cuda) -------------------------------------------------------------------------------- /T-SNE/model/1_pretrain_model.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeBreathing/DMMR/5fd5e6ebf1a14d7afa1c07213a71aa97980b5b8b/T-SNE/model/1_pretrain_model.pth -------------------------------------------------------------------------------- /T-SNE/model/1_tune_model.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeBreathing/DMMR/5fd5e6ebf1a14d7afa1c07213a71aa97980b5b8b/T-SNE/model/1_tune_model.pth -------------------------------------------------------------------------------- /T-SNE/plot/origin_label.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeBreathing/DMMR/5fd5e6ebf1a14d7afa1c07213a71aa97980b5b8b/T-SNE/plot/origin_label.jpg -------------------------------------------------------------------------------- /T-SNE/plot/origin_subject.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CodeBreathing/DMMR/5fd5e6ebf1a14d7afa1c07213a71aa97980b5b8b/T-SNE/plot/origin_subject.jpg -------------------------------------------------------------------------------- /T-SNE/plot/pretrain_label.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeBreathing/DMMR/5fd5e6ebf1a14d7afa1c07213a71aa97980b5b8b/T-SNE/plot/pretrain_label.jpg -------------------------------------------------------------------------------- /T-SNE/plot/pretrain_subject.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeBreathing/DMMR/5fd5e6ebf1a14d7afa1c07213a71aa97980b5b8b/T-SNE/plot/pretrain_subject.jpg -------------------------------------------------------------------------------- /T-SNE/plot/tune_label.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeBreathing/DMMR/5fd5e6ebf1a14d7afa1c07213a71aa97980b5b8b/T-SNE/plot/tune_label.jpg -------------------------------------------------------------------------------- /T-SNE/plot/tune_subject.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeBreathing/DMMR/5fd5e6ebf1a14d7afa1c07213a71aa97980b5b8b/T-SNE/plot/tune_subject.jpg -------------------------------------------------------------------------------- /__pycache__/GradientReverseLayer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeBreathing/DMMR/5fd5e6ebf1a14d7afa1c07213a71aa97980b5b8b/__pycache__/GradientReverseLayer.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CodeBreathing/DMMR/5fd5e6ebf1a14d7afa1c07213a71aa97980b5b8b/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/preprocess.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeBreathing/DMMR/5fd5e6ebf1a14d7afa1c07213a71aa97980b5b8b/__pycache__/preprocess.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/test.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeBreathing/DMMR/5fd5e6ebf1a14d7afa1c07213a71aa97980b5b8b/__pycache__/test.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/train.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeBreathing/DMMR/5fd5e6ebf1a14d7afa1c07213a71aa97980b5b8b/__pycache__/train.cpython-36.pyc -------------------------------------------------------------------------------- /ablation/withoutBothMixAndNoise.py: -------------------------------------------------------------------------------- 1 | from preprocess import getDataLoaders 2 | import math 3 | from torch.utils.tensorboard import SummaryWriter 4 | import argparse 5 | from train import * 6 | import random 7 | import os 8 | import numpy as np 9 | 10 | os.environ["CUDA_VISIBLE_DEVICES"] = '3' 11 | 12 | def set_seed(seed=3): 13 | random.seed(seed) 14 | np.random.seed(seed) 15 | os.environ['PYTHONHASHSEED'] = str(seed) 16 | torch.manual_seed(seed) 17 | torch.cuda.manual_seed(seed) 18 | torch.cuda.manual_seed_all(seed) 19 | torch.backends.cudnn.deterministic = True 20 | 21 | def main(data_loader_dict, args, optim_config, cuda, writer, one_subject, seed=3): 22 | set_seed(seed) 23 | if args.dataset_name == 'seed3': 24 | 
iteration = 7 25 | elif args.dataset_name == 'seed4': 26 | iteration = 3 27 | acc = trainDMMR_WithoutBothMixAndNoise(data_loader_dict, optim_config, cuda, args, iteration, writer, one_subject) 28 | return acc 29 | 30 | if __name__ == '__main__': 31 | cuda = torch.cuda.is_available() 32 | parser = argparse.ArgumentParser(description='DMMR') 33 | 34 | #config of experiment 35 | parser.add_argument("--way", type=str, default='DMMR/seed3', help="name of current way") 36 | parser.add_argument("--index", type=str, default='withoutBothMixAndNoise', help="tensorboard index") 37 | 38 | #config of dataset 39 | parser.add_argument("--dataset_name", type=str, nargs='?', default='seed3', help="the dataset name, supporting seed3 and seed4") 40 | parser.add_argument("--session", type=str, nargs='?', default='1', help="selected session") 41 | parser.add_argument("--subjects", type=int, choices=[15], default=15, help="the number of all subject") 42 | parser.add_argument("--dim", type=int, default=310, help="dim of input") 43 | 44 | #config of DMMR 45 | parser.add_argument("--input_dim", type=int, default=310, help="input dim is the same with sample's last dim") 46 | parser.add_argument("--hid_dim", type=int, default=64, help="hid dim is for hidden layer of lstm") 47 | parser.add_argument("--n_layers", type=int, default=1, help="num of layers of lstm") 48 | parser.add_argument("--epoch_fineTuning", type=int, default=500, help="epoch of the fine-tuning phase") 49 | parser.add_argument("--lr", type=int, default=1e-3, help="epoch of calModel") 50 | parser.add_argument("--weight_decay", type=float, default=0.0005, help="weight decay") 51 | parser.add_argument("--beta", type=float, default=0.05, help="balancing hyperparameter in the loss of pretraining phase") 52 | 53 | 54 | args = parser.parse_args() 55 | args.source_subjects = args.subjects-1 56 | args.seed3_path = "../eeg_data/ExtractedFeatures/" 57 | args.seed4_path = "../eeg_data/eeg_feature_smooth/" 58 | if cuda: 59 | 
args.num_workers_train = 4 60 | args.num_workers_test = 2 61 | else: 62 | args.num_workers_train = 0 63 | args.num_workers_test = 0 64 | if args.dataset_name == "seed3": 65 | args.path = args.seed3_path 66 | args.cls_classes = 3 67 | args.time_steps = 30 68 | args.batch_size = 512 #batch_size 69 | args.epoch_preTraining = 300 #epoch of the pre-training phase 70 | elif args.dataset_name == "seed4": 71 | args.path = args.seed4_path 72 | args.cls_classes = 4 73 | args.time_steps = 10 74 | args.batch_size = 256 #batch_size 75 | args.epoch_preTraining = 400 #epoch of the pre-training phase 76 | else: 77 | print("need to define the input dataset") 78 | optim_config = {"lr": args.lr, "weight_decay": args.weight_decay} 79 | # leave-one-subject-out cross-validation 80 | acc_list=[] 81 | writer = SummaryWriter("data/session"+args.session+"/"+args.way+"/" + args.index) 82 | for one_subject in range(0, args.subjects): 83 | # 1.data preparation 84 | source_loaders, test_loader = getDataLoaders(one_subject, args) 85 | data_loader_dict = {"source_loader": source_loaders, "test_loader":test_loader} 86 | # 2. 
main 87 | acc = main(data_loader_dict, args, optim_config, cuda, writer, one_subject) 88 | writer.add_scalars('single experiment acc: ', 89 | {'test acc': acc}, one_subject + 1) 90 | writer.flush() 91 | acc_list.append(acc) 92 | writer.add_text('final acc avg', str(np.average(acc_list))) 93 | writer.add_text('final acc std', str(np.std(acc_list))) 94 | acc_list_str = [str(x) for x in acc_list] 95 | writer.add_text('final each acc', ",".join(acc_list_str)) 96 | writer.add_scalars('final experiment acc scala: /avg', 97 | {'test acc': np.average(acc_list)}) 98 | writer.add_scalars('final experiment acc scala: /std', 99 | {'test acc': np.std(acc_list)}) 100 | writer.close() 101 | -------------------------------------------------------------------------------- /ablation/withoutNoise.py: -------------------------------------------------------------------------------- 1 | from preprocess import getDataLoaders 2 | import math 3 | from torch.utils.tensorboard import SummaryWriter 4 | import argparse 5 | from train import * 6 | import random 7 | import os 8 | import numpy as np 9 | 10 | os.environ["CUDA_VISIBLE_DEVICES"] = '3' 11 | 12 | def set_seed(seed=3): 13 | random.seed(seed) 14 | np.random.seed(seed) 15 | os.environ['PYTHONHASHSEED'] = str(seed) 16 | torch.manual_seed(seed) 17 | torch.cuda.manual_seed(seed) 18 | torch.cuda.manual_seed_all(seed) 19 | torch.backends.cudnn.deterministic = True 20 | 21 | def main(data_loader_dict, args, optim_config, cuda, writer, one_subject, seed=3): 22 | set_seed(seed) 23 | if args.dataset_name == 'seed3': 24 | iteration = 7 25 | elif args.dataset_name == 'seed4': 26 | iteration = 3 27 | acc = trainDMMR_WithoutNoise(data_loader_dict, optim_config, cuda, args, iteration, writer, one_subject) 28 | return acc 29 | 30 | if __name__ == '__main__': 31 | cuda = torch.cuda.is_available() 32 | parser = argparse.ArgumentParser(description='DMMR') 33 | 34 | #config of experiment 35 | parser.add_argument("--way", type=str, default='DMMR/seed3', 
help="name of current way") 36 | parser.add_argument("--index", type=str, default='0', help="tensorboard index") 37 | 38 | #config of dataset 39 | parser.add_argument("--dataset_name", type=str, nargs='?', default='seed3', help="the dataset name, supporting seed3 and seed4") 40 | parser.add_argument("--session", type=str, nargs='?', default='1', help="selected session") 41 | parser.add_argument("--subjects", type=int, choices=[15], default=15, help="the number of all subject") 42 | parser.add_argument("--dim", type=int, default=310, help="dim of input") 43 | 44 | #config of DMMR 45 | parser.add_argument("--input_dim", type=int, default=310, help="input dim is the same with sample's last dim") 46 | parser.add_argument("--hid_dim", type=int, default=64, help="hid dim is for hidden layer of lstm") 47 | parser.add_argument("--n_layers", type=int, default=1, help="num of layers of lstm") 48 | parser.add_argument("--epoch_fineTuning", type=int, default=500, help="epoch of the fine-tuning phase") 49 | parser.add_argument("--lr", type=int, default=1e-3, help="epoch of calModel") 50 | parser.add_argument("--weight_decay", type=float, default=0.0005, help="weight decay") 51 | parser.add_argument("--beta", type=float, default=0.05, help="balancing hyperparameter in the loss of pretraining phase") 52 | 53 | 54 | args = parser.parse_args() 55 | args.source_subjects = args.subjects-1 56 | args.seed3_path = "../eeg_data/ExtractedFeatures/" 57 | args.seed4_path = "../eeg_data/eeg_feature_smooth/" 58 | if cuda: 59 | args.num_workers_train = 4 60 | args.num_workers_test = 2 61 | else: 62 | args.num_workers_train = 0 63 | args.num_workers_test = 0 64 | if args.dataset_name == "seed3": 65 | args.path = args.seed3_path 66 | args.cls_classes = 3 67 | args.time_steps = 30 68 | args.batch_size = 512 #batch_size 69 | args.epoch_preTraining = 300 #epoch of the pre-training phase 70 | elif args.dataset_name == "seed4": 71 | args.path = args.seed4_path 72 | args.cls_classes = 4 73 | 
args.time_steps = 10 74 | args.batch_size = 256 #batch_size 75 | args.epoch_preTraining = 400 #epoch of the pre-training phase 76 | else: 77 | print("need to define the input dataset") 78 | optim_config = {"lr": args.lr, "weight_decay": args.weight_decay} 79 | # leave-one-subject-out cross-validation 80 | acc_list=[] 81 | writer = SummaryWriter("data/session"+args.session+"/"+args.way+"/" + args.index) 82 | for one_subject in range(0, args.subjects): 83 | # 1.data preparation 84 | source_loaders, test_loader = getDataLoaders(one_subject, args) 85 | data_loader_dict = {"source_loader": source_loaders, "test_loader":test_loader} 86 | # 2. main 87 | acc = main(data_loader_dict, args, optim_config, cuda, writer, one_subject) 88 | writer.add_scalars('single experiment acc: ', 89 | {'test acc': acc}, one_subject + 1) 90 | writer.flush() 91 | acc_list.append(acc) 92 | writer.add_text('final acc avg', str(np.average(acc_list))) 93 | writer.add_text('final acc std', str(np.std(acc_list))) 94 | acc_list_str = [str(x) for x in acc_list] 95 | writer.add_text('final each acc', ",".join(acc_list_str)) 96 | writer.add_scalars('final experiment acc scala: /avg', 97 | {'test acc': np.average(acc_list)}) 98 | writer.add_scalars('final experiment acc scala: /std', 99 | {'test acc': np.std(acc_list)}) 100 | writer.close() 101 | -------------------------------------------------------------------------------- /ablation/witoutMix.py: -------------------------------------------------------------------------------- 1 | from preprocess import getDataLoaders 2 | import math 3 | from torch.utils.tensorboard import SummaryWriter 4 | import argparse 5 | from train import * 6 | import random 7 | import os 8 | import numpy as np 9 | 10 | os.environ["CUDA_VISIBLE_DEVICES"] = '3' 11 | 12 | def set_seed(seed=3): 13 | random.seed(seed) 14 | np.random.seed(seed) 15 | os.environ['PYTHONHASHSEED'] = str(seed) 16 | torch.manual_seed(seed) 17 | torch.cuda.manual_seed(seed) 18 | 
torch.cuda.manual_seed_all(seed) 19 | torch.backends.cudnn.deterministic = True 20 | 21 | def main(data_loader_dict, args, optim_config, cuda, writer, one_subject, seed=3): 22 | set_seed(seed) 23 | if args.dataset_name == 'seed3': 24 | iteration = 7 25 | elif args.dataset_name == 'seed4': 26 | iteration = 3 27 | acc = trainDMMR_WithoutMix(data_loader_dict, optim_config, cuda, args, iteration, writer, one_subject) 28 | return acc 29 | 30 | if __name__ == '__main__': 31 | cuda = torch.cuda.is_available() 32 | parser = argparse.ArgumentParser(description='DMMR') 33 | 34 | #config of experiment 35 | parser.add_argument("--way", type=str, default='DMMR/seed3', help="name of current way") 36 | parser.add_argument("--index", type=str, default='withoutMix', help="tensorboard index") 37 | 38 | #config of dataset 39 | parser.add_argument("--dataset_name", type=str, nargs='?', default='seed3', help="the dataset name, supporting seed3 and seed4") 40 | parser.add_argument("--session", type=str, nargs='?', default='1', help="selected session") 41 | parser.add_argument("--subjects", type=int, choices=[15], default=15, help="the number of all subject") 42 | parser.add_argument("--dim", type=int, default=310, help="dim of input") 43 | 44 | #config of DMMR 45 | parser.add_argument("--input_dim", type=int, default=310, help="input dim is the same with sample's last dim") 46 | parser.add_argument("--hid_dim", type=int, default=64, help="hid dim is for hidden layer of lstm") 47 | parser.add_argument("--n_layers", type=int, default=1, help="num of layers of lstm") 48 | parser.add_argument("--epoch_fineTuning", type=int, default=500, help="epoch of the fine-tuning phase") 49 | parser.add_argument("--lr", type=int, default=1e-3, help="epoch of calModel") 50 | parser.add_argument("--weight_decay", type=float, default=0.0005, help="weight decay") 51 | parser.add_argument("--beta", type=float, default=0.05, help="balancing hyperparameter in the loss of pretraining phase") 52 | 53 | 54 | args 
= parser.parse_args() 55 | args.source_subjects = args.subjects-1 56 | args.seed3_path = "../eeg_data/ExtractedFeatures/" 57 | args.seed4_path = "../eeg_data/eeg_feature_smooth/" 58 | if cuda: 59 | args.num_workers_train = 4 60 | args.num_workers_test = 2 61 | else: 62 | args.num_workers_train = 0 63 | args.num_workers_test = 0 64 | if args.dataset_name == "seed3": 65 | args.path = args.seed3_path 66 | args.cls_classes = 3 67 | args.time_steps = 30 68 | args.batch_size = 512 #batch_size 69 | args.epoch_preTraining = 300 #epoch of the pre-training phase 70 | elif args.dataset_name == "seed4": 71 | args.path = args.seed4_path 72 | args.cls_classes = 4 73 | args.time_steps = 10 74 | args.batch_size = 256 #batch_size 75 | args.epoch_preTraining = 400 #epoch of the pre-training phase 76 | else: 77 | print("need to define the input dataset") 78 | optim_config = {"lr": args.lr, "weight_decay": args.weight_decay} 79 | # leave-one-subject-out cross-validation 80 | acc_list=[] 81 | writer = SummaryWriter("data/session"+args.session+"/"+args.way+"/" + args.index) 82 | for one_subject in range(0, args.subjects): 83 | # 1.data preparation 84 | source_loaders, test_loader = getDataLoaders(one_subject, args) 85 | data_loader_dict = {"source_loader": source_loaders, "test_loader":test_loader} 86 | # 2. 
main 87 | acc = main(data_loader_dict, args, optim_config, cuda, writer, one_subject) 88 | writer.add_scalars('single experiment acc: ', 89 | {'test acc': acc}, one_subject + 1) 90 | writer.flush() 91 | acc_list.append(acc) 92 | writer.add_text('final acc avg', str(np.average(acc_list))) 93 | writer.add_text('final acc std', str(np.std(acc_list))) 94 | acc_list_str = [str(x) for x in acc_list] 95 | writer.add_text('final each acc', ",".join(acc_list_str)) 96 | writer.add_scalars('final experiment acc scala: /avg', 97 | {'test acc': np.average(acc_list)}) 98 | writer.add_scalars('final experiment acc scala: /std', 99 | {'test acc': np.std(acc_list)}) 100 | writer.close() 101 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from preprocess import getDataLoaders 2 | import math 3 | from torch.utils.tensorboard import SummaryWriter 4 | import argparse 5 | from train import * 6 | import random 7 | import os 8 | import numpy as np 9 | 10 | os.environ["CUDA_VISIBLE_DEVICES"] = '3' 11 | 12 | def set_seed(seed=3): 13 | random.seed(seed) 14 | np.random.seed(seed) 15 | os.environ['PYTHONHASHSEED'] = str(seed) 16 | torch.manual_seed(seed) 17 | torch.cuda.manual_seed(seed) 18 | torch.cuda.manual_seed_all(seed) 19 | torch.backends.cudnn.deterministic = True 20 | 21 | def main(data_loader_dict, args, optim_config, cuda, writer, one_subject, seed=3): 22 | set_seed(seed) 23 | if args.dataset_name == 'seed3': 24 | iteration = 7 25 | elif args.dataset_name == 'seed4': 26 | iteration = 3 27 | acc = trainDMMR(data_loader_dict, optim_config, cuda, args, iteration, writer, one_subject) 28 | return acc 29 | 30 | if __name__ == '__main__': 31 | cuda = torch.cuda.is_available() 32 | parser = argparse.ArgumentParser(description='DMMR') 33 | 34 | #config of experiment 35 | parser.add_argument("--way", type=str, default='DMMR/seed3', help="name of current way") 36 | 
parser.add_argument("--index", type=str, default='0', help="tensorboard index") 37 | 38 | #config of dataset 39 | parser.add_argument("--dataset_name", type=str, nargs='?', default='seed3', help="the dataset name, supporting seed3 and seed4") 40 | parser.add_argument("--session", type=str, nargs='?', default='1', help="selected session") 41 | parser.add_argument("--subjects", type=int, choices=[15], default=15, help="the number of all subject") 42 | parser.add_argument("--dim", type=int, default=310, help="dim of input") 43 | 44 | #config of DMMR 45 | parser.add_argument("--input_dim", type=int, default=310, help="input dim is the same with sample's last dim") 46 | parser.add_argument("--hid_dim", type=int, default=64, help="hid dim is for hidden layer of lstm") 47 | parser.add_argument("--n_layers", type=int, default=1, help="num of layers of lstm") 48 | parser.add_argument("--epoch_fineTuning", type=int, default=500, help="epoch of the fine-tuning phase") 49 | parser.add_argument("--lr", type=int, default=1e-3, help="epoch of calModel") 50 | parser.add_argument("--weight_decay", type=float, default=0.0005, help="weight decay") 51 | parser.add_argument("--beta", type=float, default=0.05, help="balancing hyperparameter in the loss of pretraining phase") 52 | 53 | 54 | args = parser.parse_args() 55 | args.source_subjects = args.subjects-1 56 | args.seed3_path = "../eeg_data/ExtractedFeatures/" 57 | args.seed4_path = "../eeg_data/eeg_feature_smooth/" 58 | if cuda: 59 | args.num_workers_train = 4 60 | args.num_workers_test = 2 61 | else: 62 | args.num_workers_train = 0 63 | args.num_workers_test = 0 64 | if args.dataset_name == "seed3": 65 | args.path = args.seed3_path 66 | args.cls_classes = 3 67 | args.time_steps = 30 68 | args.batch_size = 512 #batch_size 69 | args.epoch_preTraining = 300 #epoch of the pre-training phase 70 | elif args.dataset_name == "seed4": 71 | args.path = args.seed4_path 72 | args.cls_classes = 4 73 | args.time_steps = 10 74 | args.batch_size 
= 256 #batch_size 75 | args.epoch_preTraining = 400 #epoch of the pre-training phase 76 | else: 77 | print("need to define the input dataset") 78 | optim_config = {"lr": args.lr, "weight_decay": args.weight_decay} 79 | # leave-one-subject-out cross-validation 80 | acc_list=[] 81 | writer = SummaryWriter("data/session"+args.session+"/"+args.way+"/" + args.index) 82 | for one_subject in range(0, args.subjects): 83 | # 1.data preparation 84 | source_loaders, test_loader = getDataLoaders(one_subject, args) 85 | data_loader_dict = {"source_loader": source_loaders, "test_loader":test_loader} 86 | # 2. main 87 | acc = main(data_loader_dict, args, optim_config, cuda, writer, one_subject) 88 | writer.add_scalars('single experiment acc: ', 89 | {'test acc': acc}, one_subject + 1) 90 | writer.flush() 91 | acc_list.append(acc) 92 | writer.add_text('final acc avg', str(np.average(acc_list))) 93 | writer.add_text('final acc std', str(np.std(acc_list))) 94 | acc_list_str = [str(x) for x in acc_list] 95 | writer.add_text('final each acc', ",".join(acc_list_str)) 96 | writer.add_scalars('final experiment acc scala: /avg', 97 | {'test acc': np.average(acc_list)}) 98 | writer.add_scalars('final experiment acc scala: /std', 99 | {'test acc': np.std(acc_list)}) 100 | writer.close() 101 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from GradientReverseLayer import ReverseLayerF 5 | import random 6 | import copy 7 | 8 | # The ABP module 9 | class Attention(nn.Module): 10 | def __init__(self, cuda, input_dim): 11 | super(Attention, self).__init__() 12 | self.input_dim = input_dim 13 | if cuda: 14 | self.w_linear = nn.Parameter(torch.randn(input_dim, input_dim).cuda()) 15 | self.u_linear = nn.Parameter(torch.randn(input_dim).cuda()) 16 | else: 17 | self.w_linear = 
# The ABP module
class Attention(nn.Module):
    """Attention-based pooling: re-weight the features of every time step.

    Every time step is handled as an independent row of ``input_dim``
    features. A softmax over ``row @ w_linear + u_linear`` gives one weight
    per feature and the row is multiplied element-wise by those weights.

    Args:
        cuda: when true, both parameters are moved to the current CUDA device.
        input_dim: size of the last (feature) dimension of the input.
    """

    def __init__(self, cuda, input_dim):
        super(Attention, self).__init__()
        self.input_dim = input_dim
        # Values are sampled on the CPU (w first, then u) and moved afterwards,
        # which is the same sampling order as before for a given seed.
        w_init = torch.randn(input_dim, input_dim)
        u_init = torch.randn(input_dim)
        if cuda:
            w_init, u_init = w_init.cuda(), u_init.cuda()
        self.w_linear = nn.Parameter(w_init)
        self.u_linear = nn.Parameter(u_init)

    def forward(self, x, batch_size, time_steps):
        """Return ``x`` re-weighted, shaped ``(-1, time_steps, input_dim)``.

        ``batch_size`` is kept for backward compatibility only: the batch axis
        is inferred from ``x`` so a stale value (for example the training
        batch size on a smaller final batch) no longer breaks the reshape.
        """
        rows = x.reshape(-1, self.input_dim)
        weights = F.softmax(torch.mm(rows, self.w_linear) + self.u_linear, dim=1)
        weighted = torch.mul(weights, rows)
        return weighted.reshape(-1, time_steps, self.input_dim)


class LSTM(nn.Module):
    """Batch-first LSTM that returns the output of one selected time step.

    Args:
        input_dim: feature size of the input sequence.
        output_dim: hidden size of the LSTM.
        layers: number of stacked LSTM layers.
        location: index of the time step whose output is returned (-1 = last).
    """

    def __init__(self, input_dim=310, output_dim=64, layers=2, location=-1):
        super(LSTM, self).__init__()
        self.lstm = nn.LSTM(input_dim, output_dim, num_layers=layers, batch_first=True)
        self.location = location

    def forward(self, x):
        """Return ``(output at self.location, h_n, c_n)`` for the sequence ``x``."""
        feature, (hn, cn) = self.lstm(x)
        return feature[:, self.location, :], hn, cn


class Encoder(nn.Module):
    """Shared encoder: a thin wrapper around :class:`LSTM`."""

    def __init__(self, input_dim=310, hid_dim=64, n_layers=2):
        super(Encoder, self).__init__()
        self.theta = LSTM(input_dim, hid_dim, n_layers)

    def forward(self, x):
        """Return the ``(selected output, h_n, c_n)`` tuple produced by the LSTM."""
        return self.theta(x)


class Decoder(nn.Module):
    """Autoregressive LSTM decoder that rebuilds a sequence from one feature.

    Each projected output is fed back as the next LSTM input, so
    ``output_dim`` has to equal ``input_dim`` for ``forward`` to work.
    """

    def __init__(self, input_dim=310, hid_dim=64, n_layers=2, output_dim=310):
        super(Decoder, self).__init__()
        self.rnn = nn.LSTM(input_dim, hid_dim, n_layers)
        self.fc_out = nn.Linear(hid_dim, output_dim)

    def forward(self, input, hidden, cell, time_steps):
        """Generate ``time_steps`` steps starting from ``fc_out(input)``.

        Args:
            input: encoder feature, one row per sample.
            hidden: initial LSTM hidden state.
            cell: initial LSTM cell state.
            time_steps: number of steps to produce.

        Returns:
            ``(out, hidden, cell)`` with ``out`` batch-first. Steps are
            generated last-to-first and then reversed, so ``out[:, -1]`` is
            the direct projection of ``input``.
        """
        steps = []
        current = self.fc_out(input)
        steps.append(current)
        current = current.unsqueeze(0)  # [batch, dim] -> [1, batch, dim]
        for _ in range(time_steps - 1):
            output, (hidden, cell) = self.rnn(current, (hidden, cell))
            projected = self.fc_out(output.squeeze(0))  # [batch, output_dim]
            steps.append(projected)
            current = projected.unsqueeze(0)
        steps.reverse()
        out = torch.stack(steps).transpose(1, 0)  # batch first
        return out, hidden, cell
# Namely the Subject Classifier (SD)
class DomainClassifier(nn.Module):
    """Single linear layer mapping a shared feature to subject logits."""

    def __init__(self, input_dim=64, output_dim=14):
        super(DomainClassifier, self).__init__()
        self.classifier = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Return the raw (un-normalised) subject logits for ``x``."""
        return self.classifier(x)


# The MSE loss
class MSE(nn.Module):
    """Mean squared error averaged over every element."""

    def __init__(self):
        super(MSE, self).__init__()

    def forward(self, pred, real):
        """Return ``sum((real - pred) ** 2) / numel``."""
        diffs = torch.add(real, -pred)
        return torch.sum(diffs.pow(2)) / torch.numel(diffs.data)


def timeStepsShuffle(source_data):
    """Return a copy of ``source_data`` with all but the last time step shuffled.

    Axis 1 is the time axis. The same random order (drawn with the ``random``
    module) is applied to every sample, and the last step stays in place.
    """
    last_step = source_data[:, -1, :].unsqueeze(1)
    order = list(range(source_data.size(1) - 1))
    random.shuffle(order)
    # Indexing with a list copies, so the input is never modified.
    return torch.cat([source_data[:, order, :], last_step], dim=1)


def maskTimeSteps(source_data, rate):
    """Return a copy with ``int(T * rate)`` random time steps set to zero.

    ``T`` is the size of axis 1. The last time step is never masked because
    the candidate indices are drawn from the first ``T - 1`` steps only.
    """
    masked = source_data.clone()
    num_zeros = int(source_data.size(1) * rate)
    zero_indices = torch.randperm(masked.size(1) - 1)[:num_zeros]
    masked[:, zero_indices, :] = 0
    return masked


def maskChannels(source_data, args, rate):
    """Return a copy with ``int(62 * rate)`` random channels set to zero.

    The 310 features of every time step are viewed as ``5 x 62`` and the
    selected positions of the last (channel) axis are zeroed in all 5 groups.

    Args:
        source_data: tensor holding ``batch * args.time_steps * 310`` values.
        args: namespace providing ``time_steps``. ``args.batch_size`` is no
            longer used: the batch axis is inferred, so a final batch smaller
            than the configured size does not crash the reshape.
        rate: fraction of the 62 channels to mask.
    """
    # Reshape so the channel dimension can be indexed on its own.
    by_channel = source_data.reshape(-1, args.time_steps, 5, 62).clone()
    num_zeros = int(by_channel.size(-1) * rate)
    zero_indices = torch.randperm(by_channel.size(-1))[:num_zeros]
    by_channel[..., zero_indices] = 0
    return by_channel.reshape(-1, args.time_steps, 310)
def shuffleChannels(source_data, args):
    """Return a copy of ``source_data`` with the 62 channels randomly permuted.

    The 310 features of every time step are viewed as ``5 x 62`` and one
    random order (drawn with the ``random`` module) is applied to the last
    axis for every sample, time step and group.

    Args:
        source_data: tensor holding ``batch * args.time_steps * 310`` values.
        args: namespace providing ``time_steps``. ``args.batch_size`` is no
            longer used: the batch axis is inferred so smaller batches work.
    """
    # Reshape so the channel dimension can be indexed on its own.
    by_channel = source_data.reshape(-1, args.time_steps, 5, 62)
    order = list(range(by_channel.size(-1)))
    random.shuffle(order)
    return by_channel[..., order].reshape(-1, args.time_steps, 310)


# proposed DMMR model
class DMMRPreTrainingModel(nn.Module):
    """Pre-training stage of DMMR with time-step shuffling as injected noise.

    Holds the attention layer, the shared encoder, a subject discriminator
    behind a gradient-reversal layer, and one decoder per source subject
    registered as ``decoder0 .. decoder{N-1}``.

    Args:
        cuda: forwarded to ``Attention`` to place its parameters.
        number_of_source: number of source subjects (decoders and
            discriminator classes).
        number_of_category: unused here; kept for a uniform constructor.
        batch_size: stored on the instance.
        time_steps: length of every input sequence.
    """

    def __init__(self, cuda, number_of_source=14, number_of_category=3, batch_size=10, time_steps=15):
        super(DMMRPreTrainingModel, self).__init__()
        self.batch_size = batch_size
        self.time_steps = time_steps
        self.number_of_source = number_of_source
        self.attentionLayer = Attention(cuda, input_dim=310)
        self.sharedEncoder = Encoder(input_dim=310, hid_dim=64, n_layers=1)
        self.mse = MSE()
        # One logit per source subject (was hard-coded to 14, the default).
        self.domainClassifier = DomainClassifier(input_dim=64, output_dim=number_of_source)
        # setattr registers the sub-modules under the same names exec() used,
        # so state_dict keys are unchanged.
        for i in range(number_of_source):
            setattr(self, 'decoder' + str(i),
                    Decoder(input_dim=310, hid_dim=64, n_layers=1, output_dim=310))

    def _decoders(self):
        """Return the per-subject decoders in index order."""
        return [getattr(self, 'decoder' + str(i)) for i in range(self.number_of_source)]

    def forward(self, x, corres, subject_id, m=0, mark=0):
        """Compute the pre-training losses for one batch.

        Args:
            x: input batch; the attention layer shapes it to
                ``(x.shape[0], time_steps, 310)``.
            corres: reconstruction targets, split into ``number_of_source``
                chunks along dim 0; chunk ``i`` supervises decoder ``i``.
            subject_id: subject index of every sample in ``x``.
            m: coefficient handed to the gradient-reversal layer.
            mark: unused.

        Returns:
            ``(rec_loss, sim_loss)``: the summed second-stage reconstruction
            loss and the subject-discrimination loss.
        """
        # Noise Injection, with the proposed method Time Steps Shuffling
        x = timeStepsShuffle(x)
        # The ABP module
        x = self.attentionLayer(x, x.shape[0], self.time_steps)
        # Encode the weighted features with the one-layer LSTM
        shared_last_out, shared_hn, shared_cn = self.sharedEncoder(x)
        # The DG_DANN module: GRL followed by the subject discriminator
        reverse_feature = ReverseLayerF.apply(shared_last_out, m)
        subject_predict = F.log_softmax(self.domainClassifier(reverse_feature), dim=1)
        # The domain adversarial loss
        sim_loss = F.nll_loss(subject_predict, subject_id)

        # Build supervision for the decoders; the targets are weighted as well
        corres = self.attentionLayer(corres, corres.shape[0], self.time_steps)
        splitted_tensors = torch.chunk(corres, self.number_of_source, dim=0)
        decoders = self._decoders()
        # First stage: reconstruct with every decoder and mix (sum) the results
        mixSubjectFeature = 0
        for decoder in decoders:
            x_out, *_ = decoder(shared_last_out, shared_hn, shared_cn, self.time_steps)
            mixSubjectFeature += x_out
        shared_last_out_2, shared_hn_2, shared_cn_2 = self.sharedEncoder(mixSubjectFeature)
        # Second stage: only this reconstruction contributes to the loss
        rec_loss = 0
        for i, decoder in enumerate(decoders):
            x_out, *_ = decoder(shared_last_out_2, shared_hn_2, shared_cn_2, self.time_steps)
            rec_loss += self.mse(x_out, splitted_tensors[i])
        return rec_loss, sim_loss
class DMMRFineTuningModel(nn.Module):
    """Fine-tuning stage: pretrained attention + encoder with a new classifier.

    The pretrained model is deep-copied, so fine-tuning never touches the
    object passed in. Decoders ``decoder0 .. decoder{N-1}`` are still created
    (they are not used by ``forward``) so the module layout is unchanged.

    Args:
        cuda: unused here; kept for a uniform constructor.
        baseModel: pretrained model exposing ``attentionLayer`` and
            ``sharedEncoder``.
        number_of_source: number of decoders to create.
        number_of_category: number of emotion classes predicted.
        batch_size: stored on the instance.
        time_steps: length of every input sequence.
    """

    def __init__(self, cuda, baseModel, number_of_source=14, number_of_category=3, batch_size=10, time_steps=15):
        super(DMMRFineTuningModel, self).__init__()
        self.baseModel = copy.deepcopy(baseModel)
        self.batch_size = batch_size
        self.time_steps = time_steps
        self.number_of_source = number_of_source
        # The ABP module and sharedEncoder come from the pretrained model
        self.attentionLayer = self.baseModel.attentionLayer
        self.sharedEncoder = self.baseModel.sharedEncoder
        # A new emotion classifier for emotion recognition
        self.cls_fc = nn.Sequential(nn.Linear(64, 64, bias=False), nn.BatchNorm1d(64),
                                    nn.ReLU(inplace=True), nn.Linear(64, number_of_category, bias=True))
        self.mse = MSE()
        # setattr registers the sub-modules under the same names exec() used.
        for i in range(number_of_source):
            setattr(self, 'decoder' + str(i),
                    Decoder(input_dim=310, hid_dim=64, n_layers=1, output_dim=310))

    def forward(self, x, label_src=0):
        """Classify ``x`` and, when labels are given, compute the NLL loss.

        Args:
            x: input batch; the attention layer shapes it to
                ``(x.shape[0], time_steps, 310)``.
            label_src: tensor of class indices, one per sample (a trailing
                singleton axis is accepted). Any non-tensor value, including
                the default, means "no labels".

        Returns:
            ``(x_pred, x_logits, cls_loss)``: log-probabilities, raw logits
            and the loss (``None`` when no label tensor was supplied).
        """
        x = self.attentionLayer(x, x.shape[0], self.time_steps)
        shared_last_out, shared_hn, shared_cn = self.sharedEncoder(x)
        x_logits = self.cls_fc(shared_last_out)
        x_pred = F.log_softmax(x_logits, dim=1)
        cls_loss = None
        if torch.is_tensor(label_src):
            # reshape(-1) keeps the batch axis even for a batch of one, where
            # squeeze() would collapse the target to a scalar.
            cls_loss = F.nll_loss(x_pred, label_src.reshape(-1))
        return x_pred, x_logits, cls_loss
class DMMRTestModel(nn.Module):
    """Inference wrapper around a deep copy of a fine-tuned model.

    Args:
        baseModel: model exposing ``attentionLayer``, ``sharedEncoder``,
            ``cls_fc`` and ``time_steps``.
    """

    def __init__(self, baseModel):
        super(DMMRTestModel, self).__init__()
        self.baseModel = copy.deepcopy(baseModel)

    def forward(self, x):
        """Return the emotion logits for ``x``.

        The batch size is taken from ``x`` itself rather than from the value
        stored at training time, so batches of any size can be evaluated.
        """
        x = self.baseModel.attentionLayer(x, x.shape[0], self.baseModel.time_steps)
        shared_last_out, shared_hn, shared_cn = self.baseModel.sharedEncoder(x)
        return self.baseModel.cls_fc(shared_last_out)


############## other noise injection methods ##############
class PreTrainingWithMaskTimeSteps(nn.Module):
    """Pre-training model whose injected noise is time-step masking.

    Identical to the proposed pre-training model except that 20% of the time
    steps are zeroed instead of being shuffled.

    Args:
        cuda: forwarded to ``Attention`` to place its parameters.
        number_of_source: number of source subjects (decoders and
            discriminator classes).
        number_of_category: unused here; kept for a uniform constructor.
        batch_size: stored on the instance.
        time_steps: length of every input sequence.
    """

    def __init__(self, cuda, number_of_source=14, number_of_category=3, batch_size=10, time_steps=15):
        super(PreTrainingWithMaskTimeSteps, self).__init__()
        self.batch_size = batch_size
        self.time_steps = time_steps
        self.number_of_source = number_of_source
        self.attentionLayer = Attention(cuda, input_dim=310)
        self.sharedEncoder = Encoder(input_dim=310, hid_dim=64, n_layers=1)
        self.mse = MSE()
        # One logit per source subject (was hard-coded to 14, the default).
        self.domainClassifier = DomainClassifier(input_dim=64, output_dim=number_of_source)
        # setattr registers the sub-modules under the same names exec() used.
        for i in range(number_of_source):
            setattr(self, 'decoder' + str(i),
                    Decoder(input_dim=310, hid_dim=64, n_layers=1, output_dim=310))

    def _decoders(self):
        """Return the per-subject decoders in index order."""
        return [getattr(self, 'decoder' + str(i)) for i in range(self.number_of_source)]

    def forward(self, x, corres, subject_id, args, m=0, mark=0):
        """Return ``(rec_loss, sim_loss)`` for one batch.

        Args:
            x: input batch; shaped to ``(x.shape[0], time_steps, 310)``.
            corres: reconstruction targets, split into ``number_of_source``
                chunks along dim 0; chunk ``i`` supervises decoder ``i``.
            subject_id: subject index of every sample in ``x``.
            args: unused by this variant; kept for a uniform signature.
            m: coefficient handed to the gradient-reversal layer.
            mark: unused.
        """
        x = maskTimeSteps(x, 0.2)
        x = self.attentionLayer(x, x.shape[0], self.time_steps)
        shared_last_out, shared_hn, shared_cn = self.sharedEncoder(x)

        reverse_feature = ReverseLayerF.apply(shared_last_out, m)
        subject_predict = F.log_softmax(self.domainClassifier(reverse_feature), dim=1)
        sim_loss = F.nll_loss(subject_predict, subject_id)

        corres = self.attentionLayer(corres, corres.shape[0], self.time_steps)
        splitted_tensors = torch.chunk(corres, self.number_of_source, dim=0)
        decoders = self._decoders()
        # First stage: mix (sum) the reconstructions of all decoders
        mixSubjectFeature = 0
        for decoder in decoders:
            x_out, *_ = decoder(shared_last_out, shared_hn, shared_cn, self.time_steps)
            mixSubjectFeature += x_out
        shared_last_out_2, shared_hn_2, shared_cn_2 = self.sharedEncoder(mixSubjectFeature)
        # Second stage: only this reconstruction contributes to the loss
        rec_loss = 0
        for i, decoder in enumerate(decoders):
            x_out, *_ = decoder(shared_last_out_2, shared_hn_2, shared_cn_2, self.time_steps)
            rec_loss += self.mse(x_out, splitted_tensors[i])
        return rec_loss, sim_loss
class PreTrainingWithMaskChannels(nn.Module):
    """Pre-training model whose injected noise is channel masking.

    Identical to the proposed pre-training model except that 20% of the
    channels are zeroed instead of the time steps being shuffled.

    Args:
        cuda: forwarded to ``Attention`` to place its parameters.
        number_of_source: number of source subjects (decoders and
            discriminator classes).
        number_of_category: unused here; kept for a uniform constructor.
        batch_size: stored on the instance.
        time_steps: length of every input sequence.
    """

    def __init__(self, cuda, number_of_source=14, number_of_category=3, batch_size=10, time_steps=15):
        super(PreTrainingWithMaskChannels, self).__init__()
        self.batch_size = batch_size
        self.time_steps = time_steps
        self.number_of_source = number_of_source
        self.attentionLayer = Attention(cuda, input_dim=310)
        self.sharedEncoder = Encoder(input_dim=310, hid_dim=64, n_layers=1)
        self.mse = MSE()
        # One logit per source subject (was hard-coded to 14, the default).
        self.domainClassifier = DomainClassifier(input_dim=64, output_dim=number_of_source)
        # setattr registers the sub-modules under the same names exec() used.
        for i in range(number_of_source):
            setattr(self, 'decoder' + str(i),
                    Decoder(input_dim=310, hid_dim=64, n_layers=1, output_dim=310))

    def _decoders(self):
        """Return the per-subject decoders in index order."""
        return [getattr(self, 'decoder' + str(i)) for i in range(self.number_of_source)]

    def forward(self, x, corres, subject_id, args, m=0, mark=0):
        """Return ``(rec_loss, sim_loss)`` for one batch.

        Args:
            x: input batch; shaped to ``(x.shape[0], time_steps, 310)``.
            corres: reconstruction targets, split into ``number_of_source``
                chunks along dim 0; chunk ``i`` supervises decoder ``i``.
            subject_id: subject index of every sample in ``x``.
            args: namespace forwarded to ``maskChannels``.
            m: coefficient handed to the gradient-reversal layer.
            mark: unused.
        """
        x = maskChannels(x, args, 0.2)
        x = self.attentionLayer(x, x.shape[0], self.time_steps)
        shared_last_out, shared_hn, shared_cn = self.sharedEncoder(x)

        reverse_feature = ReverseLayerF.apply(shared_last_out, m)
        subject_predict = F.log_softmax(self.domainClassifier(reverse_feature), dim=1)
        sim_loss = F.nll_loss(subject_predict, subject_id)

        corres = self.attentionLayer(corres, corres.shape[0], self.time_steps)
        splitted_tensors = torch.chunk(corres, self.number_of_source, dim=0)
        decoders = self._decoders()
        # First stage: mix (sum) the reconstructions of all decoders
        mixSubjectFeature = 0
        for decoder in decoders:
            x_out, *_ = decoder(shared_last_out, shared_hn, shared_cn, self.time_steps)
            mixSubjectFeature += x_out
        shared_last_out_2, shared_hn_2, shared_cn_2 = self.sharedEncoder(mixSubjectFeature)
        # Second stage: only this reconstruction contributes to the loss
        rec_loss = 0
        for i, decoder in enumerate(decoders):
            x_out, *_ = decoder(shared_last_out_2, shared_hn_2, shared_cn_2, self.time_steps)
            rec_loss += self.mse(x_out, splitted_tensors[i])
        return rec_loss, sim_loss
class PreTrainingWithChannelsShuffling(nn.Module):
    """Pre-training model whose injected noise is channel shuffling.

    Identical to the proposed pre-training model except that the channels
    are permuted instead of the time steps being shuffled.

    Args:
        cuda: forwarded to ``Attention`` to place its parameters.
        number_of_source: number of source subjects (decoders and
            discriminator classes).
        number_of_category: unused here; kept for a uniform constructor.
        batch_size: stored on the instance.
        time_steps: length of every input sequence.
    """

    def __init__(self, cuda, number_of_source=14, number_of_category=3, batch_size=10, time_steps=15):
        super(PreTrainingWithChannelsShuffling, self).__init__()
        self.batch_size = batch_size
        self.time_steps = time_steps
        self.number_of_source = number_of_source
        self.attentionLayer = Attention(cuda, input_dim=310)
        self.sharedEncoder = Encoder(input_dim=310, hid_dim=64, n_layers=1)
        self.mse = MSE()
        # One logit per source subject (was hard-coded to 14, the default).
        self.domainClassifier = DomainClassifier(input_dim=64, output_dim=number_of_source)
        # setattr registers the sub-modules under the same names exec() used.
        for i in range(number_of_source):
            setattr(self, 'decoder' + str(i),
                    Decoder(input_dim=310, hid_dim=64, n_layers=1, output_dim=310))

    def _decoders(self):
        """Return the per-subject decoders in index order."""
        return [getattr(self, 'decoder' + str(i)) for i in range(self.number_of_source)]

    def forward(self, x, corres, subject_id, args, m=0, mark=0):
        """Return ``(rec_loss, sim_loss)`` for one batch.

        Args:
            x: input batch; shaped to ``(x.shape[0], time_steps, 310)``.
            corres: reconstruction targets, split into ``number_of_source``
                chunks along dim 0; chunk ``i`` supervises decoder ``i``.
            subject_id: subject index of every sample in ``x``.
            args: namespace forwarded to ``shuffleChannels``.
            m: coefficient handed to the gradient-reversal layer.
            mark: unused.
        """
        x = shuffleChannels(x, args)
        x = self.attentionLayer(x, x.shape[0], self.time_steps)
        shared_last_out, shared_hn, shared_cn = self.sharedEncoder(x)

        reverse_feature = ReverseLayerF.apply(shared_last_out, m)
        subject_predict = F.log_softmax(self.domainClassifier(reverse_feature), dim=1)
        sim_loss = F.nll_loss(subject_predict, subject_id)

        corres = self.attentionLayer(corres, corres.shape[0], self.time_steps)
        splitted_tensors = torch.chunk(corres, self.number_of_source, dim=0)
        decoders = self._decoders()
        # First stage: mix (sum) the reconstructions of all decoders
        mixSubjectFeature = 0
        for decoder in decoders:
            x_out, *_ = decoder(shared_last_out, shared_hn, shared_cn, self.time_steps)
            mixSubjectFeature += x_out
        shared_last_out_2, shared_hn_2, shared_cn_2 = self.sharedEncoder(mixSubjectFeature)
        # Second stage: only this reconstruction contributes to the loss
        rec_loss = 0
        for i, decoder in enumerate(decoders):
            x_out, *_ = decoder(shared_last_out_2, shared_hn_2, shared_cn_2, self.time_steps)
            rec_loss += self.mse(x_out, splitted_tensors[i])
        return rec_loss, sim_loss
class PreTrainingWithDropout(nn.Module):
    """Pre-training model whose injected noise is dropout on the raw input.

    Identical to the proposed pre-training model except that ``nn.Dropout``
    replaces time-step shuffling.

    Args:
        cuda: forwarded to ``Attention`` to place its parameters.
        number_of_source: number of source subjects (decoders and
            discriminator classes).
        number_of_category: unused here; kept for a uniform constructor.
        batch_size: stored on the instance.
        time_steps: length of every input sequence.
        dropout_rate: probability passed to ``nn.Dropout``.
    """

    def __init__(self, cuda, number_of_source=14, number_of_category=3, batch_size=10, time_steps=15, dropout_rate=0.2):
        super(PreTrainingWithDropout, self).__init__()
        self.batch_size = batch_size
        self.time_steps = time_steps
        self.number_of_source = number_of_source
        self.attentionLayer = Attention(cuda, input_dim=310)
        self.sharedEncoder = Encoder(input_dim=310, hid_dim=64, n_layers=1)
        self.mse = MSE()
        # One logit per source subject (was hard-coded to 14, the default).
        self.domainClassifier = DomainClassifier(input_dim=64, output_dim=number_of_source)
        self.dropout = nn.Dropout(dropout_rate)  # noise
        # setattr registers the sub-modules under the same names exec() used.
        for i in range(number_of_source):
            setattr(self, 'decoder' + str(i),
                    Decoder(input_dim=310, hid_dim=64, n_layers=1, output_dim=310))

    def _decoders(self):
        """Return the per-subject decoders in index order."""
        return [getattr(self, 'decoder' + str(i)) for i in range(self.number_of_source)]

    def forward(self, x, corres, subject_id, args, m=0, mark=0):
        """Return ``(rec_loss, sim_loss)`` for one batch.

        Args:
            x: input batch; shaped to ``(x.shape[0], time_steps, 310)``.
            corres: reconstruction targets, split into ``number_of_source``
                chunks along dim 0; chunk ``i`` supervises decoder ``i``.
            subject_id: subject index of every sample in ``x``.
            args: unused by this variant; kept for a uniform signature.
            m: coefficient handed to the gradient-reversal layer.
            mark: unused.
        """
        x = self.dropout(x)
        x = self.attentionLayer(x, x.shape[0], self.time_steps)
        shared_last_out, shared_hn, shared_cn = self.sharedEncoder(x)

        reverse_feature = ReverseLayerF.apply(shared_last_out, m)
        subject_predict = F.log_softmax(self.domainClassifier(reverse_feature), dim=1)
        sim_loss = F.nll_loss(subject_predict, subject_id)

        corres = self.attentionLayer(corres, corres.shape[0], self.time_steps)
        splitted_tensors = torch.chunk(corres, self.number_of_source, dim=0)
        decoders = self._decoders()
        # First stage: mix (sum) the reconstructions of all decoders
        mixSubjectFeature = 0
        for decoder in decoders:
            x_out, *_ = decoder(shared_last_out, shared_hn, shared_cn, self.time_steps)
            mixSubjectFeature += x_out
        shared_last_out_2, shared_hn_2, shared_cn_2 = self.sharedEncoder(mixSubjectFeature)
        # Second stage: only this reconstruction contributes to the loss
        rec_loss = 0
        for i, decoder in enumerate(decoders):
            x_out, *_ = decoder(shared_last_out_2, shared_hn_2, shared_cn_2, self.time_steps)
            rec_loss += self.mse(x_out, splitted_tensors[i])
        return rec_loss, sim_loss
############## ablation study ##############
# w/o mix
class PreTrainingWithoutMix(nn.Module):
    """Ablation: pre-training without the mix (second reconstruction) stage.

    Time-step shuffling is kept, but every decoder is supervised directly on
    its first-stage reconstruction.

    Args:
        cuda: forwarded to ``Attention`` to place its parameters.
        number_of_source: number of source subjects (decoders and
            discriminator classes).
        number_of_category: unused here; kept for a uniform constructor.
        batch_size: stored on the instance.
        time_steps: length of every input sequence.
    """

    def __init__(self, cuda, number_of_source=14, number_of_category=3, batch_size=10, time_steps=15):
        super(PreTrainingWithoutMix, self).__init__()
        self.batch_size = batch_size
        self.time_steps = time_steps
        self.number_of_source = number_of_source
        self.attentionLayer = Attention(cuda, input_dim=310)
        self.sharedEncoder = Encoder(input_dim=310, hid_dim=64, n_layers=1)
        self.mse = MSE()
        # One logit per source subject (was hard-coded to 14, the default).
        self.domainClassifier = DomainClassifier(input_dim=64, output_dim=number_of_source)
        # setattr registers the sub-modules under the same names exec() used.
        for i in range(number_of_source):
            setattr(self, 'decoder' + str(i),
                    Decoder(input_dim=310, hid_dim=64, n_layers=1, output_dim=310))

    def _decoders(self):
        """Return the per-subject decoders in index order."""
        return [getattr(self, 'decoder' + str(i)) for i in range(self.number_of_source)]

    def forward(self, x, corres, subject_id, m=0, mark=0):
        """Return ``(rec_loss, sim_loss)`` for one batch.

        Args:
            x: input batch; shaped to ``(x.shape[0], time_steps, 310)``.
            corres: reconstruction targets, split into ``number_of_source``
                chunks along dim 0; chunk ``i`` supervises decoder ``i``.
            subject_id: subject index of every sample in ``x``.
            m: coefficient handed to the gradient-reversal layer.
            mark: unused.
        """
        x = timeStepsShuffle(x)
        x = self.attentionLayer(x, x.shape[0], self.time_steps)
        shared_last_out, shared_hn, shared_cn = self.sharedEncoder(x)

        reverse_feature = ReverseLayerF.apply(shared_last_out, m)
        subject_predict = F.log_softmax(self.domainClassifier(reverse_feature), dim=1)
        sim_loss = F.nll_loss(subject_predict, subject_id)

        corres = self.attentionLayer(corres, corres.shape[0], self.time_steps)
        splitted_tensors = torch.chunk(corres, self.number_of_source, dim=0)
        # Single stage: each decoder is compared with its own target chunk
        rec_loss = 0
        for i, decoder in enumerate(self._decoders()):
            x_out, *_ = decoder(shared_last_out, shared_hn, shared_cn, self.time_steps)
            rec_loss += self.mse(x_out, splitted_tensors[i])
        return rec_loss, sim_loss
# w/o noise
class PreTrainingWithoutNoise(nn.Module):
    """Ablation: pre-training without any noise injection.

    The input goes straight to the attention layer; the two-stage mix
    reconstruction is kept.

    Args:
        cuda: forwarded to ``Attention`` to place its parameters.
        number_of_source: number of source subjects (decoders and
            discriminator classes).
        number_of_category: unused here; kept for a uniform constructor.
        batch_size: stored on the instance.
        time_steps: length of every input sequence.
    """

    def __init__(self, cuda, number_of_source=14, number_of_category=3, batch_size=10, time_steps=15):
        super(PreTrainingWithoutNoise, self).__init__()
        self.batch_size = batch_size
        self.time_steps = time_steps
        self.number_of_source = number_of_source
        self.attentionLayer = Attention(cuda, input_dim=310)
        self.sharedEncoder = Encoder(input_dim=310, hid_dim=64, n_layers=1)
        self.mse = MSE()
        # One logit per source subject (was hard-coded to 14, the default).
        self.domainClassifier = DomainClassifier(input_dim=64, output_dim=number_of_source)
        # setattr registers the sub-modules under the same names exec() used.
        for i in range(number_of_source):
            setattr(self, 'decoder' + str(i),
                    Decoder(input_dim=310, hid_dim=64, n_layers=1, output_dim=310))

    def _decoders(self):
        """Return the per-subject decoders in index order."""
        return [getattr(self, 'decoder' + str(i)) for i in range(self.number_of_source)]

    def forward(self, x, corres, subject_id, m=0, mark=0):
        """Return ``(rec_loss, sim_loss)`` for one batch.

        Args:
            x: input batch; shaped to ``(x.shape[0], time_steps, 310)``.
            corres: reconstruction targets, split into ``number_of_source``
                chunks along dim 0; chunk ``i`` supervises decoder ``i``.
            subject_id: subject index of every sample in ``x``.
            m: coefficient handed to the gradient-reversal layer.
            mark: unused.
        """
        x = self.attentionLayer(x, x.shape[0], self.time_steps)
        shared_last_out, shared_hn, shared_cn = self.sharedEncoder(x)

        reverse_feature = ReverseLayerF.apply(shared_last_out, m)
        subject_predict = F.log_softmax(self.domainClassifier(reverse_feature), dim=1)
        sim_loss = F.nll_loss(subject_predict, subject_id)

        corres = self.attentionLayer(corres, corres.shape[0], self.time_steps)
        splitted_tensors = torch.chunk(corres, self.number_of_source, dim=0)
        decoders = self._decoders()
        # First stage: mix (sum) the reconstructions of all decoders
        mixSubjectFeature = 0
        for decoder in decoders:
            x_out, *_ = decoder(shared_last_out, shared_hn, shared_cn, self.time_steps)
            mixSubjectFeature += x_out
        shared_last_out_2, shared_hn_2, shared_cn_2 = self.sharedEncoder(mixSubjectFeature)
        # Second stage: only this reconstruction contributes to the loss
        rec_loss = 0
        for i, decoder in enumerate(decoders):
            x_out, *_ = decoder(shared_last_out_2, shared_hn_2, shared_cn_2, self.time_steps)
            rec_loss += self.mse(x_out, splitted_tensors[i])
        return rec_loss, sim_loss
# w/o both
class PreTrainingWithoutBothMixAndNoise(nn.Module):
    """Ablation: pre-training with neither noise injection nor the mix stage.

    The input goes straight to the attention layer and every decoder is
    supervised directly on its first-stage reconstruction.

    Args:
        cuda: forwarded to ``Attention`` to place its parameters.
        number_of_source: number of source subjects (decoders and
            discriminator classes).
        number_of_category: unused here; kept for a uniform constructor.
        batch_size: stored on the instance.
        time_steps: length of every input sequence.
    """

    def __init__(self, cuda, number_of_source=14, number_of_category=3, batch_size=10, time_steps=15):
        super(PreTrainingWithoutBothMixAndNoise, self).__init__()
        self.batch_size = batch_size
        self.time_steps = time_steps
        self.number_of_source = number_of_source
        self.attentionLayer = Attention(cuda, input_dim=310)
        self.sharedEncoder = Encoder(input_dim=310, hid_dim=64, n_layers=1)
        self.mse = MSE()
        # One logit per source subject (was hard-coded to 14, the default).
        self.domainClassifier = DomainClassifier(input_dim=64, output_dim=number_of_source)
        # setattr registers the sub-modules under the same names exec() used.
        for i in range(number_of_source):
            setattr(self, 'decoder' + str(i),
                    Decoder(input_dim=310, hid_dim=64, n_layers=1, output_dim=310))

    def _decoders(self):
        """Return the per-subject decoders in index order."""
        return [getattr(self, 'decoder' + str(i)) for i in range(self.number_of_source)]

    def forward(self, x, corres, subject_id, m=0, mark=0):
        """Return ``(rec_loss, sim_loss)`` for one batch.

        Args:
            x: input batch; shaped to ``(x.shape[0], time_steps, 310)``.
            corres: reconstruction targets, split into ``number_of_source``
                chunks along dim 0; chunk ``i`` supervises decoder ``i``.
            subject_id: subject index of every sample in ``x``.
            m: coefficient handed to the gradient-reversal layer.
            mark: unused.
        """
        x = self.attentionLayer(x, x.shape[0], self.time_steps)
        shared_last_out, shared_hn, shared_cn = self.sharedEncoder(x)

        reverse_feature = ReverseLayerF.apply(shared_last_out, m)
        subject_predict = F.log_softmax(self.domainClassifier(reverse_feature), dim=1)
        sim_loss = F.nll_loss(subject_predict, subject_id)

        corres = self.attentionLayer(corres, corres.shape[0], self.time_steps)
        splitted_tensors = torch.chunk(corres, self.number_of_source, dim=0)
        # Single stage: each decoder is compared with its own target chunk
        rec_loss = 0
        for i, decoder in enumerate(self._decoders()):
            x_out, *_ = decoder(shared_last_out, shared_hn, shared_cn, self.time_steps)
            rec_loss += self.mse(x_out, splitted_tensors[i])
        return rec_loss, sim_loss
#return feature of shared feature for T_SNE plots
class ModelReturnFeatures(nn.Module):
    """Wrapper exposing a base model's intermediate features.

    ``forward`` runs only the base model's attention layer and shared encoder
    and returns both results, which the T-SNE plotting code consumes.

    Args:
        baseModel: object providing ``attentionLayer`` and ``sharedEncoder``.
        time_steps: value forwarded to the attention layer.
    """

    def __init__(self, baseModel, time_steps=15):
        super().__init__()
        self.baseModel = baseModel
        self.time_steps = time_steps

    def forward(self, x):
        """Return ``(attended_input, shared_feature)`` for the batch ``x``."""
        base = self.baseModel
        batch = x.shape[0]
        attended = base.attentionLayer(x, batch, self.time_steps)
        # The encoder yields three values; only the first is returned.
        encoded = base.sharedEncoder(attended)
        shared_feature, _hn, _cn = encoded
        return attended, shared_feature
def set_seed(seed=3):
    """Seed every random number generator used by the experiment.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED`` and requests deterministic cuDNN behaviour.
    """
    random.seed(seed)
    np.random.seed(seed)
    # PYTHONHASHSEED is read at interpreter start-up, so this only influences
    # processes launched after this point.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Auto-tuning may pick different kernels from run to run; keep it off.
    torch.backends.cudnn.benchmark = False


def main(data_loader_dict, args, optim_config, cuda, writer, one_subject, seed=3):
    """Train and evaluate the Dropout noise variant for one held-out subject.

    Args:
        data_loader_dict: loaders handed through to the training routine.
        args: parsed options; ``args.dataset_name`` selects the number of
            batches drawn per subject and epoch (7 for seed3, 3 for seed4).
        optim_config: keyword arguments for the optimizer.
        cuda: whether to run on the GPU.
        writer: TensorBoard writer.
        one_subject: zero-based index of the held-out subject.
        seed: seed passed to :func:`set_seed`.

    Returns:
        The accuracy reported by the training routine.

    Raises:
        ValueError: if ``args.dataset_name`` is not ``seed3`` or ``seed4``.
    """
    set_seed(seed)
    iterations_by_dataset = {'seed3': 7, 'seed4': 3}
    if args.dataset_name not in iterations_by_dataset:
        # Previously `iteration` stayed unbound and the call below failed
        # with an UnboundLocalError.
        raise ValueError("unsupported dataset name: " + str(args.dataset_name))
    iteration = iterations_by_dataset[args.dataset_name]
    acc = trainDMMR_Noise_Dropout(data_loader_dict, optim_config, cuda, args, iteration, writer, one_subject)
    return acc


if __name__ == '__main__':
    cuda = torch.cuda.is_available()
    parser = argparse.ArgumentParser(description='DMMR')

    #config of experiment
    parser.add_argument("--way", type=str, default='DMMR/seed3', help="name of current way")
    parser.add_argument("--index", type=str, default='Dropout', help="tensorboard index")

    #config of dataset
    parser.add_argument("--dataset_name", type=str, nargs='?', default='seed3', help="the dataset name, supporting seed3 and seed4")
    parser.add_argument("--session", type=str, nargs='?', default='1', help="selected session")
    parser.add_argument("--subjects", type=int, choices=[15], default=15, help="the number of all subject")
    parser.add_argument("--dim", type=int, default=310, help="dim of input")

    #config of DMMR
    parser.add_argument("--input_dim", type=int, default=310, help="input dim is the same with sample's last dim")
    parser.add_argument("--hid_dim", type=int, default=64, help="hid dim is for hidden layer of lstm")
    parser.add_argument("--n_layers", type=int, default=1, help="num of layers of lstm")
    parser.add_argument("--epoch_fineTuning", type=int, default=500, help="epoch of the fine-tuning phase")
    # type=int rejected every fractional value given on the command line.
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    parser.add_argument("--weight_decay", type=float, default=0.0005, help="weight decay")
    parser.add_argument("--beta", type=float, default=0.05, help="balancing hyperparameter in the loss of pretraining phase")

    args = parser.parse_args()
    args.source_subjects = args.subjects-1
    args.seed3_path = "../eeg_data/ExtractedFeatures/"
    args.seed4_path = "../eeg_data/eeg_feature_smooth/"
    if cuda:
        args.num_workers_train = 4
        args.num_workers_test = 2
    else:
        args.num_workers_train = 0
        args.num_workers_test = 0
    if args.dataset_name == "seed3":
        args.path = args.seed3_path
        args.cls_classes = 3
        args.time_steps = 30
        args.batch_size = 512 #batch_size
        args.epoch_preTraining = 300 #epoch of the pre-training phase
    elif args.dataset_name == "seed4":
        args.path = args.seed4_path
        args.cls_classes = 4
        args.time_steps = 10
        args.batch_size = 256 #batch_size
        args.epoch_preTraining = 400 #epoch of the pre-training phase
    else:
        # Stop here instead of failing later on the missing args.path.
        parser.error("need to define the input dataset")
    optim_config = {"lr": args.lr, "weight_decay": args.weight_decay}
    # leave-one-subject-out cross-validation
    acc_list = []
    writer = SummaryWriter("data/session"+args.session+"/"+args.way+"/" + args.index)
    for one_subject in range(0, args.subjects):
        # 1.data preparation
        source_loaders, test_loader = getDataLoaders(one_subject, args)
        data_loader_dict = {"source_loader": source_loaders, "test_loader": test_loader}
        # 2. main
        acc = main(data_loader_dict, args, optim_config, cuda, writer, one_subject)
        writer.add_scalars('single experiment acc: ',
                           {'test acc': acc}, one_subject + 1)
        writer.flush()
        acc_list.append(acc)
    # Summary over all held-out subjects.
    writer.add_text('final acc avg', str(np.average(acc_list)))
    writer.add_text('final acc std', str(np.std(acc_list)))
    acc_list_str = [str(x) for x in acc_list]
    writer.add_text('final each acc', ",".join(acc_list_str))
    writer.add_scalars('final experiment acc scala: /avg',
                       {'test acc': np.average(acc_list)})
    writer.add_scalars('final experiment acc scala: /std',
                       {'test acc': np.std(acc_list)})
    writer.close()
def set_seed(seed=3):
    """Seed every random number generator used by the experiment.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED`` and requests deterministic cuDNN behaviour.
    """
    random.seed(seed)
    np.random.seed(seed)
    # PYTHONHASHSEED is read at interpreter start-up, so this only influences
    # processes launched after this point.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Auto-tuning may pick different kernels from run to run; keep it off.
    torch.backends.cudnn.benchmark = False


def main(data_loader_dict, args, optim_config, cuda, writer, one_subject, seed=3):
    """Train and evaluate the channel-shuffling noise variant for one subject.

    Args:
        data_loader_dict: loaders handed through to the training routine.
        args: parsed options; ``args.dataset_name`` selects the number of
            batches drawn per subject and epoch (7 for seed3, 3 for seed4).
        optim_config: keyword arguments for the optimizer.
        cuda: whether to run on the GPU.
        writer: TensorBoard writer.
        one_subject: zero-based index of the held-out subject.
        seed: seed passed to :func:`set_seed`.

    Returns:
        The accuracy reported by the training routine.

    Raises:
        ValueError: if ``args.dataset_name`` is not ``seed3`` or ``seed4``.
    """
    set_seed(seed)
    iterations_by_dataset = {'seed3': 7, 'seed4': 3}
    if args.dataset_name not in iterations_by_dataset:
        # Previously `iteration` stayed unbound and the call below failed
        # with an UnboundLocalError.
        raise ValueError("unsupported dataset name: " + str(args.dataset_name))
    iteration = iterations_by_dataset[args.dataset_name]
    acc = trainDMMR_Noise_ChannelsShuffling(data_loader_dict, optim_config, cuda, args, iteration, writer, one_subject)
    return acc


if __name__ == '__main__':
    cuda = torch.cuda.is_available()
    parser = argparse.ArgumentParser(description='DMMR')

    #config of experiment
    parser.add_argument("--way", type=str, default='DMMR/seed3', help="name of current way")
    parser.add_argument("--index", type=str, default='channelsShuffling', help="tensorboard index")

    #config of dataset
    parser.add_argument("--dataset_name", type=str, nargs='?', default='seed3', help="the dataset name, supporting seed3 and seed4")
    parser.add_argument("--session", type=str, nargs='?', default='1', help="selected session")
    parser.add_argument("--subjects", type=int, choices=[15], default=15, help="the number of all subject")
    parser.add_argument("--dim", type=int, default=310, help="dim of input")

    #config of DMMR
    parser.add_argument("--input_dim", type=int, default=310, help="input dim is the same with sample's last dim")
    parser.add_argument("--hid_dim", type=int, default=64, help="hid dim is for hidden layer of lstm")
    parser.add_argument("--n_layers", type=int, default=1, help="num of layers of lstm")
    parser.add_argument("--epoch_fineTuning", type=int, default=500, help="epoch of the fine-tuning phase")
    # type=int rejected every fractional value given on the command line.
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    parser.add_argument("--weight_decay", type=float, default=0.0005, help="weight decay")
    parser.add_argument("--beta", type=float, default=0.05, help="balancing hyperparameter in the loss of pretraining phase")

    args = parser.parse_args()
    args.source_subjects = args.subjects-1
    args.seed3_path = "../eeg_data/ExtractedFeatures/"
    args.seed4_path = "../eeg_data/eeg_feature_smooth/"
    if cuda:
        args.num_workers_train = 4
        args.num_workers_test = 2
    else:
        args.num_workers_train = 0
        args.num_workers_test = 0
    if args.dataset_name == "seed3":
        args.path = args.seed3_path
        args.cls_classes = 3
        args.time_steps = 30
        args.batch_size = 512 #batch_size
        args.epoch_preTraining = 300 #epoch of the pre-training phase
    elif args.dataset_name == "seed4":
        args.path = args.seed4_path
        args.cls_classes = 4
        args.time_steps = 10
        args.batch_size = 256 #batch_size
        args.epoch_preTraining = 400 #epoch of the pre-training phase
    else:
        # Stop here instead of failing later on the missing args.path.
        parser.error("need to define the input dataset")
    optim_config = {"lr": args.lr, "weight_decay": args.weight_decay}
    # leave-one-subject-out cross-validation
    acc_list = []
    writer = SummaryWriter("data/session"+args.session+"/"+args.way+"/" + args.index)
    for one_subject in range(0, args.subjects):
        # 1.data preparation
        source_loaders, test_loader = getDataLoaders(one_subject, args)
        data_loader_dict = {"source_loader": source_loaders, "test_loader": test_loader}
        # 2. main
        acc = main(data_loader_dict, args, optim_config, cuda, writer, one_subject)
        writer.add_scalars('single experiment acc: ',
                           {'test acc': acc}, one_subject + 1)
        writer.flush()
        acc_list.append(acc)
    # Summary over all held-out subjects.
    writer.add_text('final acc avg', str(np.average(acc_list)))
    writer.add_text('final acc std', str(np.std(acc_list)))
    acc_list_str = [str(x) for x in acc_list]
    writer.add_text('final each acc', ",".join(acc_list_str))
    writer.add_scalars('final experiment acc scala: /avg',
                       {'test acc': np.average(acc_list)})
    writer.add_scalars('final experiment acc scala: /std',
                       {'test acc': np.std(acc_list)})
    writer.close()
def set_seed(seed=3):
    """Seed every random number generator used by the experiment.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED`` and requests deterministic cuDNN behaviour.
    """
    random.seed(seed)
    np.random.seed(seed)
    # PYTHONHASHSEED is read at interpreter start-up, so this only influences
    # processes launched after this point.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Auto-tuning may pick different kernels from run to run; keep it off.
    torch.backends.cudnn.benchmark = False


def main(data_loader_dict, args, optim_config, cuda, writer, one_subject, seed=3):
    """Train and evaluate the channel-masking noise variant for one subject.

    Args:
        data_loader_dict: loaders handed through to the training routine.
        args: parsed options; ``args.dataset_name`` selects the number of
            batches drawn per subject and epoch (7 for seed3, 3 for seed4).
        optim_config: keyword arguments for the optimizer.
        cuda: whether to run on the GPU.
        writer: TensorBoard writer.
        one_subject: zero-based index of the held-out subject.
        seed: seed passed to :func:`set_seed`.

    Returns:
        The accuracy reported by the training routine.

    Raises:
        ValueError: if ``args.dataset_name`` is not ``seed3`` or ``seed4``.
    """
    set_seed(seed)
    iterations_by_dataset = {'seed3': 7, 'seed4': 3}
    if args.dataset_name not in iterations_by_dataset:
        # Previously `iteration` stayed unbound and the call below failed
        # with an UnboundLocalError.
        raise ValueError("unsupported dataset name: " + str(args.dataset_name))
    iteration = iterations_by_dataset[args.dataset_name]
    acc = trainDMMR_Noise_MaskChannels(data_loader_dict, optim_config, cuda, args, iteration, writer, one_subject)
    return acc


if __name__ == '__main__':
    cuda = torch.cuda.is_available()
    parser = argparse.ArgumentParser(description='DMMR')

    #config of experiment
    parser.add_argument("--way", type=str, default='DMMR/seed3', help="name of current way")
    parser.add_argument("--index", type=str, default='maskChannels', help="tensorboard index")

    #config of dataset
    parser.add_argument("--dataset_name", type=str, nargs='?', default='seed3', help="the dataset name, supporting seed3 and seed4")
    parser.add_argument("--session", type=str, nargs='?', default='1', help="selected session")
    parser.add_argument("--subjects", type=int, choices=[15], default=15, help="the number of all subject")
    parser.add_argument("--dim", type=int, default=310, help="dim of input")

    #config of DMMR
    parser.add_argument("--input_dim", type=int, default=310, help="input dim is the same with sample's last dim")
    parser.add_argument("--hid_dim", type=int, default=64, help="hid dim is for hidden layer of lstm")
    parser.add_argument("--n_layers", type=int, default=1, help="num of layers of lstm")
    parser.add_argument("--epoch_fineTuning", type=int, default=500, help="epoch of the fine-tuning phase")
    # type=int rejected every fractional value given on the command line.
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    parser.add_argument("--weight_decay", type=float, default=0.0005, help="weight decay")
    parser.add_argument("--beta", type=float, default=0.05, help="balancing hyperparameter in the loss of pretraining phase")

    args = parser.parse_args()
    args.source_subjects = args.subjects-1
    args.seed3_path = "../eeg_data/ExtractedFeatures/"
    args.seed4_path = "../eeg_data/eeg_feature_smooth/"
    if cuda:
        args.num_workers_train = 4
        args.num_workers_test = 2
    else:
        args.num_workers_train = 0
        args.num_workers_test = 0
    if args.dataset_name == "seed3":
        args.path = args.seed3_path
        args.cls_classes = 3
        args.time_steps = 30
        args.batch_size = 512 #batch_size
        args.epoch_preTraining = 300 #epoch of the pre-training phase
    elif args.dataset_name == "seed4":
        args.path = args.seed4_path
        args.cls_classes = 4
        args.time_steps = 10
        args.batch_size = 256 #batch_size
        args.epoch_preTraining = 400 #epoch of the pre-training phase
    else:
        # Stop here instead of failing later on the missing args.path.
        parser.error("need to define the input dataset")
    optim_config = {"lr": args.lr, "weight_decay": args.weight_decay}
    # leave-one-subject-out cross-validation
    acc_list = []
    writer = SummaryWriter("data/session"+args.session+"/"+args.way+"/" + args.index)
    for one_subject in range(0, args.subjects):
        # 1.data preparation
        source_loaders, test_loader = getDataLoaders(one_subject, args)
        data_loader_dict = {"source_loader": source_loaders, "test_loader": test_loader}
        # 2. main
        acc = main(data_loader_dict, args, optim_config, cuda, writer, one_subject)
        writer.add_scalars('single experiment acc: ',
                           {'test acc': acc}, one_subject + 1)
        writer.flush()
        acc_list.append(acc)
    # Summary over all held-out subjects.
    writer.add_text('final acc avg', str(np.average(acc_list)))
    writer.add_text('final acc std', str(np.std(acc_list)))
    acc_list_str = [str(x) for x in acc_list]
    writer.add_text('final each acc', ",".join(acc_list_str))
    writer.add_scalars('final experiment acc scala: /avg',
                       {'test acc': np.average(acc_list)})
    writer.add_scalars('final experiment acc scala: /std',
                       {'test acc': np.std(acc_list)})
    writer.close()
def set_seed(seed=3):
    """Seed every random number generator used by the experiment.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED`` and requests deterministic cuDNN behaviour.
    """
    random.seed(seed)
    np.random.seed(seed)
    # PYTHONHASHSEED is read at interpreter start-up, so this only influences
    # processes launched after this point.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Auto-tuning may pick different kernels from run to run; keep it off.
    torch.backends.cudnn.benchmark = False


def main(data_loader_dict, args, optim_config, cuda, writer, one_subject, seed=3):
    """Train and evaluate the time-step-masking noise variant for one subject.

    Args:
        data_loader_dict: loaders handed through to the training routine.
        args: parsed options; ``args.dataset_name`` selects the number of
            batches drawn per subject and epoch (7 for seed3, 3 for seed4).
        optim_config: keyword arguments for the optimizer.
        cuda: whether to run on the GPU.
        writer: TensorBoard writer.
        one_subject: zero-based index of the held-out subject.
        seed: seed passed to :func:`set_seed`.

    Returns:
        The accuracy reported by the training routine.

    Raises:
        ValueError: if ``args.dataset_name`` is not ``seed3`` or ``seed4``.
    """
    set_seed(seed)
    iterations_by_dataset = {'seed3': 7, 'seed4': 3}
    if args.dataset_name not in iterations_by_dataset:
        # Previously `iteration` stayed unbound and the call below failed
        # with an UnboundLocalError.
        raise ValueError("unsupported dataset name: " + str(args.dataset_name))
    iteration = iterations_by_dataset[args.dataset_name]
    acc = trainDMMR_Noise_MaskTimeSteps(data_loader_dict, optim_config, cuda, args, iteration, writer, one_subject)
    return acc


if __name__ == '__main__':
    cuda = torch.cuda.is_available()
    parser = argparse.ArgumentParser(description='DMMR')

    #config of experiment
    parser.add_argument("--way", type=str, default='DMMR/seed3', help="name of current way")
    parser.add_argument("--index", type=str, default='maskTimeSteps', help="tensorboard index")

    #config of dataset
    parser.add_argument("--dataset_name", type=str, nargs='?', default='seed3', help="the dataset name, supporting seed3 and seed4")
    parser.add_argument("--session", type=str, nargs='?', default='1', help="selected session")
    parser.add_argument("--subjects", type=int, choices=[15], default=15, help="the number of all subject")
    parser.add_argument("--dim", type=int, default=310, help="dim of input")

    #config of DMMR
    parser.add_argument("--input_dim", type=int, default=310, help="input dim is the same with sample's last dim")
    parser.add_argument("--hid_dim", type=int, default=64, help="hid dim is for hidden layer of lstm")
    parser.add_argument("--n_layers", type=int, default=1, help="num of layers of lstm")
    parser.add_argument("--epoch_fineTuning", type=int, default=500, help="epoch of the fine-tuning phase")
    # type=int rejected every fractional value given on the command line.
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    parser.add_argument("--weight_decay", type=float, default=0.0005, help="weight decay")
    parser.add_argument("--beta", type=float, default=0.05, help="balancing hyperparameter in the loss of pretraining phase")

    args = parser.parse_args()
    args.source_subjects = args.subjects-1
    args.seed3_path = "../eeg_data/ExtractedFeatures/"
    args.seed4_path = "../eeg_data/eeg_feature_smooth/"
    if cuda:
        args.num_workers_train = 4
        args.num_workers_test = 2
    else:
        args.num_workers_train = 0
        args.num_workers_test = 0
    if args.dataset_name == "seed3":
        args.path = args.seed3_path
        args.cls_classes = 3
        args.time_steps = 30
        args.batch_size = 512 #batch_size
        args.epoch_preTraining = 300 #epoch of the pre-training phase
    elif args.dataset_name == "seed4":
        args.path = args.seed4_path
        args.cls_classes = 4
        args.time_steps = 10
        args.batch_size = 256 #batch_size
        args.epoch_preTraining = 400 #epoch of the pre-training phase
    else:
        # Stop here instead of failing later on the missing args.path.
        parser.error("need to define the input dataset")
    optim_config = {"lr": args.lr, "weight_decay": args.weight_decay}
    # leave-one-subject-out cross-validation
    acc_list = []
    writer = SummaryWriter("data/session"+args.session+"/"+args.way+"/" + args.index)
    for one_subject in range(0, args.subjects):
        # 1.data preparation
        source_loaders, test_loader = getDataLoaders(one_subject, args)
        data_loader_dict = {"source_loader": source_loaders, "test_loader": test_loader}
        # 2. main
        acc = main(data_loader_dict, args, optim_config, cuda, writer, one_subject)
        writer.add_scalars('single experiment acc: ',
                           {'test acc': acc}, one_subject + 1)
        writer.flush()
        acc_list.append(acc)
    # Summary over all held-out subjects.
    writer.add_text('final acc avg', str(np.average(acc_list)))
    writer.add_text('final acc std', str(np.std(acc_list)))
    acc_list_str = [str(x) for x in acc_list]
    writer.add_text('final each acc', ",".join(acc_list_str))
    writer.add_scalars('final experiment acc scala: /avg',
                       {'test acc': np.average(acc_list)})
    writer.add_scalars('final experiment acc scala: /std',
                       {'test acc': np.std(acc_list)})
    writer.close()
def get_data_path(file_path):
    """Return the joined paths of all non-hidden entries of ``file_path``.

    Entries whose name starts with "." are skipped.  The order is whatever
    ``os.listdir`` yields (it is not sorted).
    """
    return [os.path.join(file_path, f)
            for f in os.listdir(file_path)
            if not f.startswith(".")]


# Time-domain sliding window of the DE feature in an experiment
def window_slice(data, time_steps):
    """Cut one trial into overlapping windows of ``time_steps`` rows.

    ``data`` is transposed with axes ``(1, 0, 2)`` and flattened to rows of
    310 values; every run of ``time_steps`` consecutive rows (stride 1)
    becomes one window.

    Returns:
        Array of shape ``(rows - time_steps + 1, time_steps, 310)``.

    Raises:
        ValueError: if the trial has fewer than ``time_steps`` rows (no
            window can be formed).
    """
    data = np.transpose(data, (1, 0, 2)).reshape(-1, 310)
    window_count = data.shape[0] - time_steps + 1
    windows = [data[start: start + time_steps] for start in range(window_count)]
    return np.stack(windows)


def get_number_of_label_n_trial(dataset_name):
    '''
    description: get the trial number, the number of categories and the
        per-session trial labels of a dataset
    param dataset_name: 'seed3' or 'seed4'
    return:
        trial: int, number of trials per session
        label: int, number of categories
        labels: list with one row per session (3 rows), each row holding
            one label per trial
    raises:
        ValueError: for any other dataset name (previously a message was
            printed and None returned, which broke tuple-unpacking callers)
    '''
    label_seed4 = [[1,2,3,0,2,0,0,1,0,1,2,1,1,1,2,3,2,2,3,3,0,3,0,3],
                   [2,1,3,0,0,2,0,2,3,3,2,3,2,0,1,1,2,1,0,3,0,1,3,1],
                   [1,2,2,1,3,3,3,1,1,2,1,0,2,3,3,0,2,3,0,0,2,0,1,0]]
    label_seed3 = [[2,1,0,0,1,2,0,1,2,2,1,0,1,2,0],
                   [2,1,0,0,1,2,0,1,2,2,1,0,1,2,0],
                   [2,1,0,0,1,2,0,1,2,2,1,0,1,2,0]]
    if dataset_name == 'seed3':
        label = 3
        trial = 15
        return trial, label, label_seed3
    elif dataset_name == 'seed4':
        label = 4
        trial = 24
        return trial, label, label_seed4
    raise ValueError('Unexpected dataset name: ' + str(dataset_name))
#The return is a list of x and y, where the list contains data for each subject
def load_trained_data(samples_path_list, args):
    """Load windowed samples and labels for every ``.mat`` file given.

    For each file, every entry whose key starts with ``de_LDS`` is sliced
    into windows with ``window_slice``; the n-th such entry (in the order the
    loaded dict yields its keys) receives the n-th label of the session
    selected by ``args.session``.

    Args:
        samples_path_list: paths of the ``.mat`` files, one per subject.
        args: needs ``dataset_name``, ``session`` and ``time_steps``.

    Returns:
        ``(X_all, Y_all)``: two lists with one array per file.  Labels have
        shape ``(n_windows, 1)``.
    """
    # load the label data
    _, _, labels = get_number_of_label_n_trial(args.dataset_name)
    session_labels = np.asarray(labels[int(args.session) - 1])
    X_train_all = []
    Y_train_all = []
    #Iterate through each subject (there are 14 source subjects in both datasets)
    for path in samples_path_list:
        # load the sample data
        sample = scio.loadmat(path, verify_compressed_data_integrity=False)
        X_train = []
        y_train = []
        # NOTE(review): trial labels are assigned by key iteration order --
        # this assumes the file stores de_LDS entries in trial order.
        trial_index = 0
        for key, val in sample.items():
            if not key.startswith("de_LDS"):
                continue
            windows = window_slice(val, args.time_steps)
            X_train.append(windows)
            y_train.append(np.full((windows.shape[0], 1), session_labels[trial_index]))
            trial_index += 1
        X_train_all.append(np.concatenate(X_train))
        Y_train_all.append(np.concatenate(y_train))
    return X_train_all, Y_train_all


def normalize(features, select_dim=0):
    """Min-max scale ``features`` to ``[0, 1]`` along ``select_dim``.

    Minimum and maximum are taken over ``select_dim`` and broadcast back.
    Positions whose minimum equals their maximum are mapped to 0 instead of
    the NaN that the plain ``0 / 0`` division produced.
    """
    features_min = torch.min(features, dim=select_dim, keepdim=True)[0]
    features_max = torch.max(features, dim=select_dim, keepdim=True)[0]
    value_range = features_max - features_min
    # A zero range means a constant feature; divide by 1 so the result is 0.
    safe_range = torch.where(value_range == 0, torch.ones_like(value_range), value_range)
    return (features - features_min) / safe_range
# Load the data, return a list of samples and labels, including tensors of data for each subject
def load4train(samples_path_list, args):
    """
    load the SEED data set

    Returns two lists with one entry per file: min-max normalized float
    sample tensors and long label tensors.
    """
    train_sample, train_label = load_trained_data(samples_path_list, args)
    sample_res = []
    label_res = []
    for samples, labels in zip(train_sample, train_label):
        # transfer from ndarray to tensor
        one_subject_samples = torch.from_numpy(samples).type(torch.FloatTensor)
        one_subject_labels = torch.from_numpy(labels).type(torch.LongTensor)
        # normalize tensor
        sample_res.append(normalize(one_subject_samples))
        label_res.append(one_subject_labels)
    return sample_res, label_res


#The input is the full set of data from the first session.
def getDataLoaders(one_subject, args):
    """Build the leave-one-subject-out loaders for one held-out subject.

    The session directory ``args.path + args.session + "/"`` is listed; the
    file whose name starts with ``"<one_subject + 1>_"`` is the target and
    all remaining files are sources.

    Args:
        one_subject: zero-based index of the held-out subject.
        args: needs ``path``, ``session``, ``batch_size``,
            ``num_workers_train``, ``num_workers_test`` plus whatever
            ``load4train`` reads.

    Returns:
        ``(source_loaders, test_loader)``: a list with one shuffled loader
        per source subject, and an unshuffled loader for the target.

    Raises:
        FileNotFoundError: if no file matches the target subject (previously
            a message was printed and the function crashed on an unbound
            variable).
        ValueError: if more than one file matches the target subject.
    """
    file_path = args.path + args.session + "/"
    path_list = get_data_path(file_path)
    target_prefix = file_path + str(int(one_subject + 1)) + "_"
    target_path_list = [i for i in path_list if i.startswith(target_prefix)]
    if not target_path_list:
        raise FileNotFoundError("target data not exist")
    if len(target_path_list) > 1:
        raise ValueError("more than one file matches target subject prefix " + target_prefix)
    target_path = target_path_list[0]
    path_list.remove(target_path)
    source_path_list = path_list

    # read from DE feature
    sources_sample, sources_label = load4train(source_path_list, args)
    targets_sample, targets_label = load4train(target_path_list, args)
    target_sample = targets_sample[0]
    target_label = targets_label[0]

    # Generate Data loaders
    source_dsets = [torch.utils.data.TensorDataset(sample, label)
                    for sample, label in zip(sources_sample, sources_label)]
    target_dset = torch.utils.data.TensorDataset(target_sample, target_label)

    source_loaders = [
        torch.utils.data.DataLoader(dset, args.batch_size, shuffle=True,
                                    num_workers=args.num_workers_train, drop_last=True)
        for dset in source_dsets
    ]
    # NOTE(review): drop_last=True also applies to the test loader, so a
    # trailing partial batch of the target subject is never evaluated.
    test_loader = torch.utils.data.DataLoader(target_dset, args.batch_size, shuffle=False,
                                              num_workers=args.num_workers_test, drop_last=True)
    return source_loaders, test_loader
| 137 | source_loaders = [] 138 | for j in range(len(source_dsets)): 139 | source_loaders.append(torch.utils.data.DataLoader(source_dsets[j], args.batch_size, shuffle=True, num_workers=args.num_workers_train, drop_last=True)) 140 | test_loader = torch.utils.data.DataLoader(target_dset, args.batch_size, shuffle=False, num_workers=args.num_workers_test, drop_last=True) 141 | return source_loaders, test_loader -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | python==3.6.13 2 | numpy==1.19.2 3 | scipy==1.5.2 4 | torch==1.10.1 5 | torchvision==0.11.2 6 | tensorboard==1.15.0 7 | pillow==8.3.1 8 | matplotlib==3.3.4 -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | 4 | def testDMMR(dataLoader, DMMRTestModel, cuda, batch_size): 5 | print("testing DMMR") 6 | index = 0 7 | count = 0 8 | data_set_all = 0 9 | if cuda: 10 | DMMRTestModel = DMMRTestModel.cuda() 11 | DMMRTestModel.eval() 12 | with torch.no_grad(): 13 | for _, (test_input, label) in enumerate(dataLoader): 14 | if cuda: 15 | test_input, label = test_input.cuda(), label.cuda() 16 | test_input, label = Variable(test_input), Variable(label) 17 | data_set_all += len(label) 18 | x_shared_pred = DMMRTestModel(test_input) 19 | _, pred = torch.max(x_shared_pred, dim=1) 20 | count += pred.eq(label.squeeze().data.view_as(pred)).sum() 21 | index += batch_size 22 | acc = float(count) / data_set_all 23 | return acc 24 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from model import * 4 | import numpy as np 5 | from test import * 6 | from collections import 
defaultdict 7 | import random 8 | from sklearn.manifold import TSNE 9 | import matplotlib.pyplot as plt 10 | 11 | def trainDMMR(data_loader_dict, optimizer_config, cuda, args, iteration, writer, one_subject): 12 | # data of source subjects, which is used as the training set 13 | source_loader = data_loader_dict['source_loader'] 14 | # The pre-training phase 15 | preTrainModel = DMMRPreTrainingModel(cuda, number_of_source=len(source_loader), number_of_category=args.cls_classes, batch_size=args.batch_size, time_steps=args.time_steps) 16 | if cuda: 17 | preTrainModel = preTrainModel.cuda() 18 | source_iters = [] 19 | for i in range(len(source_loader)): 20 | source_iters.append(iter(source_loader[i])) 21 | optimizer_PreTraining = torch.optim.Adam(preTrainModel.parameters(), **optimizer_config) 22 | 23 | acc_final = 0 24 | for epoch in range(args.epoch_preTraining): 25 | print("epoch: "+str(epoch)) 26 | start_time_pretrain = time.time() 27 | preTrainModel.train() 28 | data_set_all = 0 29 | for i in range(1, iteration + 1): 30 | p = float(i + epoch * iteration) / args.epoch_preTraining / iteration 31 | m = 2. / (1. + np.exp(-10 * p)) - 1 # for the gradient reverse layer (GRL) 32 | batch_dict = defaultdict(list) #Pre-fetch a batch of data for each subject in advance and store them in this dictionary. 
33 | data_dict = defaultdict(list) #Store the data of each subject in the current batch 34 | label_dict = defaultdict(list) #Store the labels corresponding to the data of each subject in the current batch 35 | label_data_dict = defaultdict(set) 36 | for j in range(len(source_iters)): 37 | try: 38 | batch_dict[j] = next(source_iters[j]) 39 | except: 40 | source_iters[j] = iter(source_loader[j]) 41 | batch_dict[j]= next(source_iters[j]) 42 | index = 0 43 | for o in batch_dict[j][1]: 44 | cur_label = o[0].item() 45 | data_dict[j].append(batch_dict[j][0][index]) 46 | label_dict[j].append(cur_label) 47 | index+=1 48 | 49 | for j in range(len(source_iters)): 50 | # Assign a unique ID to each source subject 51 | subject_id = torch.ones(args.batch_size) 52 | subject_id = subject_id * j 53 | subject_id = subject_id.long() 54 | #the input of the model 55 | source_data, source_label = batch_dict[j] 56 | # Prepare corresponding new batch of each subject, the new batch has same label with current batch. 
57 | label_data_dict_list = [] 58 | for one_index in range(args.source_subjects): 59 | cur_data_list = data_dict[one_index] 60 | cur_label_list = label_dict[one_index] 61 | for one in range(args.batch_size): 62 | label_data_dict[cur_label_list[one]].add(cur_data_list[one]) 63 | label_data_dict_list.append(label_data_dict) 64 | label_data_dict = defaultdict(set) 65 | # Store the corresponding new batch of each subject, providing the supervision for different decoders 66 | corres_batch_data = [] 67 | for i in range(len(label_data_dict_list)): 68 | for one in source_label: 69 | label_cur = one[0].item() 70 | corres_batch_data.append(random.choice(list(label_data_dict_list[i][label_cur]))) 71 | corres_batch_data = torch.stack(corres_batch_data) 72 | if cuda: 73 | source_data = source_data.cuda() 74 | source_label = source_label.cuda() 75 | subject_id = subject_id.cuda() 76 | corres_batch_data = corres_batch_data.cuda() 77 | data_set_all += len(source_label) 78 | optimizer_PreTraining.zero_grad() 79 | # Call the pretraining model 80 | rec_loss, sim_loss = preTrainModel(source_data, corres_batch_data, subject_id, m, mark=j) 81 | # The loss of the pre-training phase, beta is the balancing hyperparameter 82 | loss_pretrain = rec_loss + args.beta * sim_loss 83 | loss_pretrain.backward() 84 | optimizer_PreTraining.step() 85 | print("data set amount: "+str(data_set_all)) 86 | writer.add_scalars('subject: '+str(one_subject+1)+' '+'train DMMR/loss', 87 | {'loss_pretrain':loss_pretrain.data,'rec_loss':rec_loss.data,'sim_loss':sim_loss.data}, epoch + 1) 88 | end_time_pretrain = time.time() 89 | pretrain_epoch_time = end_time_pretrain - start_time_pretrain 90 | print("The time required for one pre-training epoch is:", pretrain_epoch_time, "second") 91 | print("rec_loss: "+str(rec_loss)) 92 | 93 | # The fine-tuning phase 94 | source_iters2 = [] 95 | for i in range(len(source_loader)): 96 | source_iters2.append(iter(source_loader[i])) 97 | #Load the ABP module, the encoder from 
pretrained model and build a new model for the fine-tuning phase 98 | fineTuneModel = DMMRFineTuningModel(cuda, preTrainModel, number_of_source=len(source_loader), 99 | number_of_category=args.cls_classes, batch_size=args.batch_size, 100 | time_steps=args.time_steps) 101 | 102 | optimizer_FineTuning = torch.optim.Adam(fineTuneModel.parameters(), **optimizer_config) 103 | if cuda: 104 | fineTuneModel = fineTuneModel.cuda() 105 | for epoch in range(args.epoch_fineTuning): 106 | print("epoch: " + str(epoch)) 107 | start_time = time.time() 108 | fineTuneModel.train() 109 | count = 0 110 | data_set_all = 0 111 | for i in range(1, iteration + 1): 112 | batch_dict = defaultdict(list) 113 | data_dict = defaultdict(list) 114 | label_dict = defaultdict(list) 115 | 116 | for j in range(len(source_iters2)): 117 | try: 118 | batch_dict[j] = next(source_iters2[j]) 119 | except: 120 | source_iters2[j] = iter(source_loader[j]) 121 | batch_dict[j] = next(source_iters2[j]) 122 | index = 0 123 | for o in batch_dict[j][1]: 124 | cur_label = o[0].item() 125 | data_dict[j].append(batch_dict[j][0][index]) 126 | label_dict[j].append(cur_label) 127 | index += 1 128 | for j in range(len(source_iters)): 129 | source_data, source_label = batch_dict[j] 130 | if cuda: 131 | source_data = source_data.cuda() 132 | source_label = source_label.cuda() 133 | data_set_all += len(source_label) 134 | optimizer_FineTuning.zero_grad() 135 | # Call the fine-tuning model 136 | x_pred, x_logits, cls_loss = fineTuneModel(source_data, source_label) 137 | cls_loss.backward() 138 | optimizer_FineTuning.step() 139 | _, pred = torch.max(x_pred, dim=1) 140 | count += pred.eq(source_label.squeeze().data.view_as(pred)).sum() 141 | end_time = time.time() 142 | epoch_time = end_time - start_time 143 | print("The time required for one fine-tuning epoch is:", epoch_time, "second") 144 | print("data set amount: " + str(data_set_all)) 145 | acc = float(count) / data_set_all 146 | writer.add_scalars('subject: ' + 
############## Ablation studies ##############
# The ablation and noise-injection trainers below differ only in the
# pre-training model class they build and in whether `args` is forwarded to
# that model's forward call, so they share the private helpers defined here.


def _next_source_batch(source_iters, source_loader, j):
    """Return the next batch of source loader ``j``, restarting it when exhausted.

    Only ``StopIteration`` triggers the restart, so genuine loader errors and
    ``KeyboardInterrupt`` propagate instead of being silently retried.
    """
    try:
        return next(source_iters[j])
    except StopIteration:
        source_iters[j] = iter(source_loader[j])
        return next(source_iters[j])


def _group_samples_by_label(batches, args):
    """Group the current samples of every source subject by their label.

    Args:
        batches: list of ``(data, label)`` batches, one per source subject.
        args: run configuration; ``args.source_subjects`` and
            ``args.batch_size`` bound the subjects and samples visited.

    Returns:
        A list with one ``defaultdict(list)`` per source subject that maps a
        label (first element of each label row) to that subject's samples.
    """
    lookup = []
    for subject in range(args.source_subjects):
        data, labels = batches[subject]
        per_label = defaultdict(list)
        for k in range(args.batch_size):
            # A list is enough here: every indexed sample is a distinct tensor
            # object and tensors hash by identity, so the set used previously
            # never removed anything.
            per_label[labels[k][0].item()].append(data[k])
        lookup.append(per_label)
    return lookup


def _run_pretraining(preTrainModel, source_loader, optimizer_config, cuda, args,
                     iteration, writer, one_subject, pass_args_to_model):
    """Run the pre-training phase (reconstruction + similarity losses).

    Args:
        preTrainModel: pre-training model, already moved to the GPU if needed.
        source_loader: sequence of data loaders, one per source subject.
        optimizer_config: keyword arguments for ``torch.optim.Adam``.
        cuda: whether tensors are moved to the GPU.
        args: run configuration (``epoch_preTraining``, ``batch_size``,
            ``source_subjects``, ``beta``).
        iteration: number of iterations per epoch.
        writer: summary writer receiving the per-epoch losses.
        one_subject: zero-based index of the held-out subject (logging only).
        pass_args_to_model: forward ``args`` to the model call when True.
    """
    source_iters = [iter(loader) for loader in source_loader]
    optimizer_PreTraining = torch.optim.Adam(preTrainModel.parameters(), **optimizer_config)
    number_of_source = len(source_iters)
    loss_tag = 'subject: ' + str(one_subject + 1) + ' ' + 'train DMMR/loss'

    for epoch in range(args.epoch_preTraining):
        print("epoch: " + str(epoch))
        start_time_pretrain = time.time()
        preTrainModel.train()
        data_set_all = 0
        for i in range(1, iteration + 1):
            # Schedule value rising from 0 towards 1 over the whole
            # pre-training run; presumably the gradient-reversal weight
            # consumed by the model -- confirm against model.py.
            p = float(i + epoch * iteration) / args.epoch_preTraining / iteration
            m = 2. / (1. + np.exp(-10 * p)) - 1

            batches = [_next_source_batch(source_iters, source_loader, j)
                       for j in range(number_of_source)]
            # The lookup depends only on this iteration's batches, so it is
            # built once here rather than once per source subject.
            label_lookup = _group_samples_by_label(batches, args)

            for j in range(number_of_source):
                subject_id = (torch.ones(args.batch_size) * j).long()
                source_data, source_label = batches[j]
                # For every source subject, draw one random sample sharing the
                # label of each sample in the current batch; these supervise
                # the different decoders.
                corres_batch_data = torch.stack([
                    random.choice(per_label[one[0].item()])
                    for per_label in label_lookup
                    for one in source_label
                ])
                if cuda:
                    source_data = source_data.cuda()
                    source_label = source_label.cuda()
                    subject_id = subject_id.cuda()
                    corres_batch_data = corres_batch_data.cuda()
                data_set_all += len(source_label)
                optimizer_PreTraining.zero_grad()
                if pass_args_to_model:
                    rec_loss, sim_loss = preTrainModel(source_data, corres_batch_data, subject_id, args, m, mark=j)
                else:
                    rec_loss, sim_loss = preTrainModel(source_data, corres_batch_data, subject_id, m, mark=j)
                # beta balances the reconstruction and similarity losses.
                loss_pretrain = rec_loss + args.beta * sim_loss
                loss_pretrain.backward()
                optimizer_PreTraining.step()
        print("data set amount: " + str(data_set_all))
        # Only the losses of the last batch of the epoch are logged.
        writer.add_scalars(loss_tag,
                           {'loss_pretrain': loss_pretrain.data, 'rec_loss': rec_loss.data, 'sim_loss': sim_loss.data},
                           epoch + 1)
        end_time_pretrain = time.time()
        pretrain_epoch_time = end_time_pretrain - start_time_pretrain
        print("The time required for one pre-training epoch is:", pretrain_epoch_time, "second")
        print("rec_loss: " + str(rec_loss))


def _run_finetuning(preTrainModel, data_loader_dict, optimizer_config, cuda, args,
                    iteration, writer, one_subject):
    """Fine-tune on the source subjects, test every epoch, save the best models.

    Returns:
        The best test accuracy reported by ``testDMMR`` over all epochs
        (0 if no epoch exceeded 0).
    """
    source_loader = data_loader_dict['source_loader']
    source_iters2 = [iter(loader) for loader in source_loader]
    fineTuneModel = DMMRFineTuningModel(cuda, preTrainModel, number_of_source=len(source_loader),
                                        number_of_category=args.cls_classes, batch_size=args.batch_size,
                                        time_steps=args.time_steps)
    optimizer_FineTuning = torch.optim.Adam(fineTuneModel.parameters(), **optimizer_config)
    if cuda:
        fineTuneModel = fineTuneModel.cuda()

    tag_prefix = 'subject: ' + str(one_subject + 1) + ' '
    acc_final = 0
    best_states = None  # (pretrain, tune, test) state dicts of the best epoch
    for epoch in range(args.epoch_fineTuning):
        print("epoch: " + str(epoch))
        start_time = time.time()
        fineTuneModel.train()
        count = 0
        data_set_all = 0
        for _ in range(1, iteration + 1):
            batches = [_next_source_batch(source_iters2, source_loader, j)
                       for j in range(len(source_iters2))]
            for source_data, source_label in batches:
                if cuda:
                    source_data = source_data.cuda()
                    source_label = source_label.cuda()
                data_set_all += len(source_label)
                optimizer_FineTuning.zero_grad()
                x_pred, x_logits, cls_loss = fineTuneModel(source_data, source_label)
                cls_loss.backward()
                optimizer_FineTuning.step()
                _, pred = torch.max(x_pred, dim=1)
                count += pred.eq(source_label.squeeze().data.view_as(pred)).sum()
        end_time = time.time()
        epoch_time = end_time - start_time
        print("The time required for one fine-tuning epoch is:", epoch_time, "second")
        print("data set amount: " + str(data_set_all))
        acc = float(count) / data_set_all
        writer.add_scalars(tag_prefix + 'train DMMR/loss', {'cls_loss': cls_loss.data}, epoch + 1)
        writer.add_scalar(tag_prefix + 'train DMMR/train accuracy', acc, epoch + 1)
        print("acc: " + str(acc))

        # Test the fine-tuned model with the data of the unseen target subject.
        testModel = DMMRTestModel(fineTuneModel)
        acc_DMMR = testDMMR(data_loader_dict["test_loader"], testModel, cuda, args.batch_size)
        print("acc_DMMR: " + str(acc_DMMR))
        writer.add_scalars(tag_prefix + 'test DMMR/test acc', {'test acc': acc_DMMR}, epoch + 1)
        if acc_DMMR > acc_final:
            acc_final = acc_DMMR
            # Deep copies, because training keeps mutating the live weights.
            best_states = (copy.deepcopy(preTrainModel.state_dict()),
                           copy.deepcopy(fineTuneModel.state_dict()),
                           copy.deepcopy(testModel.state_dict()))

    if best_states is None:
        # Previously this path failed with a NameError on the unbound
        # best_*_model variables.
        print("No fine-tuning epoch improved on test accuracy 0; no checkpoint saved.")
        return acc_final

    modelDir = "model/" + args.way + "/" + args.index + "/"
    os.makedirs(modelDir, exist_ok=True)
    best_pretrain_model, best_tune_model, best_test_model = best_states
    torch.save(best_pretrain_model, modelDir + str(one_subject) + '_pretrain_model.pth')
    torch.save(best_tune_model, modelDir + str(one_subject) + '_tune_model.pth')
    torch.save(best_test_model, modelDir + str(one_subject) + '_test_model.pth')
    return acc_final


def _train_dmmr_variant(pretrain_model_cls, data_loader_dict, optimizer_config, cuda, args,
                        iteration, writer, one_subject, pass_args_to_model=False):
    """Build ``pretrain_model_cls``, pre-train it, then fine-tune and test.

    Returns:
        The best test accuracy obtained during fine-tuning.
    """
    source_loader = data_loader_dict['source_loader']
    # The pre-training phase
    preTrainModel = pretrain_model_cls(cuda, number_of_source=len(source_loader),
                                       number_of_category=args.cls_classes,
                                       batch_size=args.batch_size, time_steps=args.time_steps)
    if cuda:
        preTrainModel = preTrainModel.cuda()
    _run_pretraining(preTrainModel, source_loader, optimizer_config, cuda, args,
                     iteration, writer, one_subject, pass_args_to_model)
    # The fine-tuning phase
    return _run_finetuning(preTrainModel, data_loader_dict, optimizer_config, cuda, args,
                           iteration, writer, one_subject)


# w/o mix
def trainDMMR_WithoutMix(data_loader_dict, optimizer_config, cuda, args, iteration, writer, one_subject):
    """Train and evaluate the ablation that pre-trains with ``PreTrainingWithoutMix``.

    Args:
        data_loader_dict: dict with ``'source_loader'`` (one loader per source
            subject) and ``'test_loader'`` (target subject).
        optimizer_config: keyword arguments for ``torch.optim.Adam``.
        cuda: whether to run on the GPU.
        args: run configuration namespace.
        iteration: number of iterations per epoch.
        writer: summary writer for losses and accuracies.
        one_subject: zero-based index of the held-out subject.

    Returns:
        The best test accuracy obtained during fine-tuning.
    """
    return _train_dmmr_variant(PreTrainingWithoutMix, data_loader_dict, optimizer_config, cuda, args,
                               iteration, writer, one_subject)


# w/o noise
def trainDMMR_WithoutNoise(data_loader_dict, optimizer_config, cuda, args, iteration, writer, one_subject):
    """Train and evaluate the ablation that pre-trains with ``PreTrainingWithoutNoise``.

    Same arguments and return value as ``trainDMMR_WithoutMix``.
    """
    return _train_dmmr_variant(PreTrainingWithoutNoise, data_loader_dict, optimizer_config, cuda, args,
                               iteration, writer, one_subject)


# w/o both
def trainDMMR_WithoutBothMixAndNoise(data_loader_dict, optimizer_config, cuda, args, iteration, writer, one_subject):
    """Train and evaluate the ablation using ``PreTrainingWithoutBothMixAndNoise``.

    Same arguments and return value as ``trainDMMR_WithoutMix``.
    """
    return _train_dmmr_variant(PreTrainingWithoutBothMixAndNoise, data_loader_dict, optimizer_config, cuda, args,
                               iteration, writer, one_subject)


############## Other noise injection methods ##############
def trainDMMR_Noise_MaskChannels(data_loader_dict, optimizer_config, cuda, args, iteration, writer, one_subject):
    """Train and evaluate the variant that pre-trains with ``PreTrainingWithMaskChannels``.

    Unlike the ablation variants, ``args`` is also forwarded to the
    pre-training model's forward call. Same arguments and return value as
    ``trainDMMR_WithoutMix``.
    """
    return _train_dmmr_variant(PreTrainingWithMaskChannels, data_loader_dict, optimizer_config, cuda, args,
                               iteration, writer, one_subject, pass_args_to_model=True)
+ np.exp(-10 * p)) - 1 828 | batch_dict = defaultdict(list) 829 | data_dict = defaultdict(list) 830 | label_dict = defaultdict(list) 831 | label_data_dict = defaultdict(set) 832 | 833 | for j in range(len(source_iters)): 834 | try: 835 | batch_dict[j] = next(source_iters[j]) 836 | except: 837 | source_iters[j] = iter(source_loader[j]) 838 | batch_dict[j]= next(source_iters[j]) 839 | index = 0 840 | for o in batch_dict[j][1]: 841 | cur_label = o[0].item() 842 | data_dict[j].append(batch_dict[j][0][index]) 843 | label_dict[j].append(cur_label) 844 | index+=1 845 | 846 | for j in range(len(source_iters)): 847 | subject_id = torch.ones(args.batch_size) 848 | subject_id = subject_id * j 849 | subject_id = subject_id.long() 850 | 851 | source_data, source_label = batch_dict[j] 852 | # prepare corresponding new batch, the new batch has same label with current batch 853 | label_data_dict_list = [] 854 | for one_index in range(args.source_subjects): 855 | cur_data_list = data_dict[one_index] 856 | cur_label_list = label_dict[one_index] 857 | for one in range(args.batch_size): 858 | label_data_dict[cur_label_list[one]].add(cur_data_list[one]) 859 | label_data_dict_list.append(label_data_dict) 860 | label_data_dict = defaultdict(set) 861 | corres_batch_data = [] 862 | for i in range(len(label_data_dict_list)): 863 | for one in source_label: 864 | label_cur = one[0].item() 865 | corres_batch_data.append(random.choice(list(label_data_dict_list[i][label_cur]))) 866 | corres_batch_data = torch.stack(corres_batch_data) 867 | if cuda: 868 | source_data = source_data.cuda() 869 | source_label = source_label.cuda() 870 | subject_id = subject_id.cuda() 871 | corres_batch_data =corres_batch_data.cuda() 872 | data_set_all+=len(source_label) 873 | optimizer_PreTraining.zero_grad() 874 | rec_loss, sim_loss = preTrainModel(source_data, corres_batch_data, subject_id, args, m, mark=j) 875 | loss_pretrain = rec_loss + args.beta*sim_loss 876 | loss_pretrain.backward() 877 | 
optimizer_PreTraining.step() 878 | print("data set amount: "+str(data_set_all)) 879 | writer.add_scalars('subject: '+str(one_subject+1)+' '+'train DMMR/loss', 880 | {'loss_pretrain':loss_pretrain.data,'rec_loss':rec_loss.data,'sim_loss':sim_loss.data}, epoch + 1) 881 | end_time_pretrain = time.time() 882 | pretrain_epoch_time = end_time_pretrain - start_time_pretrain 883 | print("The time required for one pre-training epoch is:", pretrain_epoch_time, "second") 884 | print("rec_loss: "+str(rec_loss)) 885 | 886 | # The fine-tuning phase 887 | source_iters2 = [] 888 | for i in range(len(source_loader)): 889 | source_iters2.append(iter(source_loader[i])) 890 | fineTuneModel = DMMRFineTuningModel(cuda, preTrainModel, number_of_source=len(source_loader), 891 | number_of_category=args.cls_classes, batch_size=args.batch_size, 892 | time_steps=args.time_steps) 893 | 894 | optimizer_FineTuning = torch.optim.Adam(fineTuneModel.parameters(), **optimizer_config) 895 | if cuda: 896 | fineTuneModel = fineTuneModel.cuda() 897 | for epoch in range(args.epoch_fineTuning): 898 | print("epoch: " + str(epoch)) 899 | start_time = time.time() 900 | fineTuneModel.train() 901 | count = 0 902 | data_set_all = 0 903 | for i in range(1, iteration + 1): 904 | p = float(i + epoch * iteration) / args.epoch_preTraining / iteration 905 | m = 2. / (1. 
+ np.exp(-10 * p)) - 1 906 | batch_dict = defaultdict(list) 907 | data_dict = defaultdict(list) 908 | label_dict = defaultdict(list) 909 | label_data_dict = defaultdict(set) 910 | 911 | for j in range(len(source_iters2)): 912 | try: 913 | batch_dict[j] = next(source_iters2[j]) 914 | except: 915 | source_iters2[j] = iter(source_loader[j]) 916 | batch_dict[j] = next(source_iters2[j]) 917 | index = 0 918 | for o in batch_dict[j][1]: 919 | cur_label = o[0].item() 920 | data_dict[j].append(batch_dict[j][0][index]) 921 | label_dict[j].append(cur_label) 922 | index += 1 923 | for j in range(len(source_iters)): 924 | source_data, source_label = batch_dict[j] 925 | if cuda: 926 | source_data = source_data.cuda() 927 | source_label = source_label.cuda() 928 | data_set_all += len(source_label) 929 | optimizer_FineTuning.zero_grad() 930 | x_pred, x_logits, cls_loss = fineTuneModel(source_data, source_label) 931 | cls_loss.backward() 932 | optimizer_FineTuning.step() 933 | _, pred = torch.max(x_pred, dim=1) 934 | count += pred.eq(source_label.squeeze().data.view_as(pred)).sum() 935 | end_time = time.time() 936 | epoch_time = end_time - start_time 937 | print("The time required for one fine-tuning epoch is:", epoch_time, "second") 938 | print("data set amount: " + str(data_set_all)) 939 | acc = float(count) / data_set_all 940 | writer.add_scalars('subject: ' + str(one_subject + 1) + ' ' + 'train DMMR/loss', 941 | {'cls_loss': cls_loss.data}, epoch + 1) 942 | writer.add_scalar('subject: ' + str(one_subject + 1) + ' ' + 'train DMMR/train accuracy', acc, epoch + 1) 943 | print("acc: " + str(acc)) 944 | 945 | testModel = DMMRTestModel(fineTuneModel) 946 | acc_DMMR = testDMMR(data_loader_dict["test_loader"], testModel, cuda, args.batch_size) 947 | print("acc_DMMR: " + str(acc_DMMR)) 948 | writer.add_scalars('subject: ' + str(one_subject + 1) + ' ' + 'test DMMR/test acc', 949 | {'test acc': acc_DMMR}, epoch + 1) 950 | if acc_DMMR > acc_final: 951 | acc_final = acc_DMMR 952 | 
best_pretrain_model = copy.deepcopy(preTrainModel.state_dict()) 953 | best_tune_model = copy.deepcopy(fineTuneModel.state_dict()) 954 | best_test_model = copy.deepcopy(testModel.state_dict()) 955 | modelDir = "model/" + args.way + "/" + args.index + "/" 956 | try: 957 | os.makedirs(modelDir) 958 | except: 959 | pass 960 | torch.save(best_pretrain_model, modelDir + str(one_subject) + '_pretrain_model.pth') 961 | torch.save(best_tune_model, modelDir + str(one_subject) + '_tune_model.pth') 962 | torch.save(best_test_model, modelDir + str(one_subject) + '_test_model.pth') 963 | return acc_final 964 | 965 | def trainDMMR_Noise_ChannelsShuffling(data_loader_dict, optimizer_config, cuda, args, iteration, writer, one_subject): 966 | source_loader = data_loader_dict['source_loader'] 967 | # The pre-training phase 968 | preTrainModel = PreTrainingWithChannelsShuffling(cuda, number_of_source=len(source_loader), number_of_category=args.cls_classes, batch_size=args.batch_size, time_steps=args.time_steps) 969 | if cuda: 970 | preTrainModel = preTrainModel.cuda() 971 | source_iters = [] 972 | for i in range(len(source_loader)): 973 | source_iters.append(iter(source_loader[i])) 974 | optimizer_PreTraining = torch.optim.Adam(preTrainModel.parameters(), **optimizer_config) 975 | 976 | acc_final = 0 977 | for epoch in range(args.epoch_preTraining): 978 | print("epoch: "+str(epoch)) 979 | start_time_pretrain = time.time() 980 | preTrainModel.train() 981 | count = 0 982 | data_set_all = 0 983 | for i in range(1, iteration + 1): 984 | p = float(i + epoch * iteration) / args.epoch_preTraining / iteration 985 | m = 2. / (1. 
+ np.exp(-10 * p)) - 1 986 | batch_dict = defaultdict(list) 987 | data_dict = defaultdict(list) 988 | label_dict = defaultdict(list) 989 | label_data_dict = defaultdict(set) 990 | 991 | for j in range(len(source_iters)): 992 | try: 993 | batch_dict[j] = next(source_iters[j]) 994 | except: 995 | source_iters[j] = iter(source_loader[j]) 996 | batch_dict[j]= next(source_iters[j]) 997 | index = 0 998 | for o in batch_dict[j][1]: 999 | cur_label = o[0].item() 1000 | data_dict[j].append(batch_dict[j][0][index]) 1001 | label_dict[j].append(cur_label) 1002 | index+=1 1003 | 1004 | for j in range(len(source_iters)): 1005 | subject_id = torch.ones(args.batch_size) 1006 | subject_id = subject_id * j 1007 | subject_id = subject_id.long() 1008 | 1009 | source_data, source_label = batch_dict[j] 1010 | # prepare corresponding new batch, the new batch has same label with current batch 1011 | label_data_dict_list = [] 1012 | for one_index in range(args.source_subjects): 1013 | cur_data_list = data_dict[one_index] 1014 | cur_label_list = label_dict[one_index] 1015 | for one in range(args.batch_size): 1016 | label_data_dict[cur_label_list[one]].add(cur_data_list[one]) 1017 | label_data_dict_list.append(label_data_dict) 1018 | label_data_dict = defaultdict(set) 1019 | corres_batch_data = [] 1020 | for i in range(len(label_data_dict_list)): 1021 | for one in source_label: 1022 | label_cur = one[0].item() 1023 | corres_batch_data.append(random.choice(list(label_data_dict_list[i][label_cur]))) 1024 | corres_batch_data = torch.stack(corres_batch_data) 1025 | if cuda: 1026 | source_data = source_data.cuda() 1027 | source_label = source_label.cuda() 1028 | subject_id = subject_id.cuda() 1029 | corres_batch_data =corres_batch_data.cuda() 1030 | data_set_all+=len(source_label) 1031 | optimizer_PreTraining.zero_grad() 1032 | rec_loss, sim_loss = preTrainModel(source_data, corres_batch_data, subject_id, args, m, mark=j) 1033 | loss_pretrain = rec_loss + args.beta*sim_loss 1034 | 
loss_pretrain.backward() 1035 | optimizer_PreTraining.step() 1036 | print("data set amount: "+str(data_set_all)) 1037 | writer.add_scalars('subject: '+str(one_subject+1)+' '+'train DMMR/loss', 1038 | {'loss_pretrain':loss_pretrain.data,'rec_loss':rec_loss.data,'sim_loss':sim_loss.data}, epoch + 1) 1039 | end_time_pretrain = time.time() 1040 | pretrain_epoch_time = end_time_pretrain - start_time_pretrain 1041 | print("The time required for one pre-training epoch is:", pretrain_epoch_time, "second") 1042 | print("rec_loss: "+str(rec_loss)) 1043 | 1044 | # The fine-tuning phase 1045 | source_iters2 = [] 1046 | for i in range(len(source_loader)): 1047 | source_iters2.append(iter(source_loader[i])) 1048 | fineTuneModel = DMMRFineTuningModel(cuda, preTrainModel, number_of_source=len(source_loader), 1049 | number_of_category=args.cls_classes, batch_size=args.batch_size, 1050 | time_steps=args.time_steps) 1051 | 1052 | optimizer_FineTuning = torch.optim.Adam(fineTuneModel.parameters(), **optimizer_config) 1053 | if cuda: 1054 | fineTuneModel = fineTuneModel.cuda() 1055 | for epoch in range(args.epoch_fineTuning): 1056 | print("epoch: " + str(epoch)) 1057 | start_time = time.time() 1058 | fineTuneModel.train() 1059 | count = 0 1060 | data_set_all = 0 1061 | for i in range(1, iteration + 1): 1062 | p = float(i + epoch * iteration) / args.epoch_preTraining / iteration 1063 | m = 2. / (1. 
+ np.exp(-10 * p)) - 1 1064 | batch_dict = defaultdict(list) 1065 | data_dict = defaultdict(list) 1066 | label_dict = defaultdict(list) 1067 | label_data_dict = defaultdict(set) 1068 | 1069 | for j in range(len(source_iters2)): 1070 | try: 1071 | batch_dict[j] = next(source_iters2[j]) 1072 | except: 1073 | source_iters2[j] = iter(source_loader[j]) 1074 | batch_dict[j] = next(source_iters2[j]) 1075 | index = 0 1076 | for o in batch_dict[j][1]: 1077 | cur_label = o[0].item() 1078 | data_dict[j].append(batch_dict[j][0][index]) 1079 | label_dict[j].append(cur_label) 1080 | index += 1 1081 | for j in range(len(source_iters)): 1082 | source_data, source_label = batch_dict[j] 1083 | if cuda: 1084 | source_data = source_data.cuda() 1085 | source_label = source_label.cuda() 1086 | data_set_all += len(source_label) 1087 | optimizer_FineTuning.zero_grad() 1088 | x_pred, x_logits, cls_loss = fineTuneModel(source_data, source_label) 1089 | cls_loss.backward() 1090 | optimizer_FineTuning.step() 1091 | _, pred = torch.max(x_pred, dim=1) 1092 | count += pred.eq(source_label.squeeze().data.view_as(pred)).sum() 1093 | end_time = time.time() 1094 | epoch_time = end_time - start_time 1095 | print("The time required for one fine-tuning epoch is:", epoch_time, "second") 1096 | print("data set amount: " + str(data_set_all)) 1097 | acc = float(count) / data_set_all 1098 | writer.add_scalars('subject: ' + str(one_subject + 1) + ' ' + 'train DMMR/loss', 1099 | {'cls_loss': cls_loss.data}, epoch + 1) 1100 | writer.add_scalar('subject: ' + str(one_subject + 1) + ' ' + 'train DMMR/train accuracy', acc, epoch + 1) 1101 | print("acc: " + str(acc)) 1102 | 1103 | testModel = DMMRTestModel(fineTuneModel) 1104 | acc_DMMR = testDMMR(data_loader_dict["test_loader"], testModel, cuda, args.batch_size) 1105 | print("acc_DMMR: " + str(acc_DMMR)) 1106 | writer.add_scalars('subject: ' + str(one_subject + 1) + ' ' + 'test DMMR/test acc', 1107 | {'test acc': acc_DMMR}, epoch + 1) 1108 | if acc_DMMR > 
acc_final: 1109 | acc_final = acc_DMMR 1110 | best_pretrain_model = copy.deepcopy(preTrainModel.state_dict()) 1111 | best_tune_model = copy.deepcopy(fineTuneModel.state_dict()) 1112 | best_test_model = copy.deepcopy(testModel.state_dict()) 1113 | modelDir = "model/" + args.way + "/" + args.index + "/" 1114 | try: 1115 | os.makedirs(modelDir) 1116 | except: 1117 | pass 1118 | torch.save(best_pretrain_model, modelDir + str(one_subject) + '_pretrain_model.pth') 1119 | torch.save(best_tune_model, modelDir + str(one_subject) + '_tune_model.pth') 1120 | torch.save(best_test_model, modelDir + str(one_subject) + '_test_model.pth') 1121 | return acc_final 1122 | 1123 | def trainDMMR_Noise_Dropout(data_loader_dict, optimizer_config, cuda, args, iteration, writer, one_subject): 1124 | source_loader = data_loader_dict['source_loader'] 1125 | # The pre-training phase 1126 | preTrainModel = PreTrainingWithDropout(cuda, number_of_source=len(source_loader), number_of_category=args.cls_classes, batch_size=args.batch_size, time_steps=args.time_steps, dropout_rate=0.2) 1127 | if cuda: 1128 | preTrainModel = preTrainModel.cuda() 1129 | source_iters = [] 1130 | for i in range(len(source_loader)): 1131 | source_iters.append(iter(source_loader[i])) 1132 | optimizer_PreTraining = torch.optim.Adam(preTrainModel.parameters(), **optimizer_config) 1133 | 1134 | acc_final = 0 1135 | for epoch in range(args.epoch_preTraining): 1136 | print("epoch: "+str(epoch)) 1137 | start_time_pretrain = time.time() 1138 | preTrainModel.train() 1139 | count = 0 1140 | data_set_all = 0 1141 | for i in range(1, iteration + 1): 1142 | p = float(i + epoch * iteration) / args.epoch_preTraining / iteration 1143 | m = 2. / (1. 
+ np.exp(-10 * p)) - 1 1144 | batch_dict = defaultdict(list) 1145 | data_dict = defaultdict(list) 1146 | label_dict = defaultdict(list) 1147 | label_data_dict = defaultdict(set) 1148 | 1149 | for j in range(len(source_iters)): 1150 | try: 1151 | batch_dict[j] = next(source_iters[j]) 1152 | except: 1153 | source_iters[j] = iter(source_loader[j]) 1154 | batch_dict[j]= next(source_iters[j]) 1155 | index = 0 1156 | for o in batch_dict[j][1]: 1157 | cur_label = o[0].item() 1158 | data_dict[j].append(batch_dict[j][0][index]) 1159 | label_dict[j].append(cur_label) 1160 | index+=1 1161 | 1162 | for j in range(len(source_iters)): 1163 | subject_id = torch.ones(args.batch_size) 1164 | subject_id = subject_id * j 1165 | subject_id = subject_id.long() 1166 | 1167 | source_data, source_label = batch_dict[j] 1168 | # prepare corresponding new batch, the new batch has same label with current batch 1169 | label_data_dict_list = [] 1170 | for one_index in range(args.source_subjects): 1171 | cur_data_list = data_dict[one_index] 1172 | cur_label_list = label_dict[one_index] 1173 | for one in range(args.batch_size): 1174 | label_data_dict[cur_label_list[one]].add(cur_data_list[one]) 1175 | label_data_dict_list.append(label_data_dict) 1176 | label_data_dict = defaultdict(set) 1177 | corres_batch_data = [] 1178 | for i in range(len(label_data_dict_list)): 1179 | for one in source_label: 1180 | label_cur = one[0].item() 1181 | corres_batch_data.append(random.choice(list(label_data_dict_list[i][label_cur]))) 1182 | corres_batch_data = torch.stack(corres_batch_data) 1183 | if cuda: 1184 | source_data = source_data.cuda() 1185 | source_label = source_label.cuda() 1186 | subject_id = subject_id.cuda() 1187 | corres_batch_data =corres_batch_data.cuda() 1188 | data_set_all+=len(source_label) 1189 | optimizer_PreTraining.zero_grad() 1190 | rec_loss, sim_loss = preTrainModel(source_data, corres_batch_data, subject_id, args, m, mark=j) 1191 | loss_pretrain = rec_loss + args.beta*sim_loss 1192 | 
loss_pretrain.backward() 1193 | optimizer_PreTraining.step() 1194 | print("data set amount: "+str(data_set_all)) 1195 | writer.add_scalars('subject: '+str(one_subject+1)+' '+'train DMMR/loss', 1196 | {'loss_pretrain':loss_pretrain.data,'rec_loss':rec_loss.data,'sim_loss':sim_loss.data}, epoch + 1) 1197 | end_time_pretrain = time.time() 1198 | pretrain_epoch_time = end_time_pretrain - start_time_pretrain 1199 | print("The time required for one pre-training epoch is:", pretrain_epoch_time, "second") 1200 | print("rec_loss: "+str(rec_loss)) 1201 | 1202 | # The fine-tuning phase 1203 | source_iters2 = [] 1204 | for i in range(len(source_loader)): 1205 | source_iters2.append(iter(source_loader[i])) 1206 | fineTuneModel = DMMRFineTuningModel(cuda, preTrainModel, number_of_source=len(source_loader), 1207 | number_of_category=args.cls_classes, batch_size=args.batch_size, 1208 | time_steps=args.time_steps) 1209 | 1210 | optimizer_FineTuning = torch.optim.Adam(fineTuneModel.parameters(), **optimizer_config) 1211 | if cuda: 1212 | fineTuneModel = fineTuneModel.cuda() 1213 | for epoch in range(args.epoch_fineTuning): 1214 | print("epoch: " + str(epoch)) 1215 | start_time = time.time() 1216 | fineTuneModel.train() 1217 | count = 0 1218 | data_set_all = 0 1219 | for i in range(1, iteration + 1): 1220 | p = float(i + epoch * iteration) / args.epoch_preTraining / iteration 1221 | m = 2. / (1. 
+ np.exp(-10 * p)) - 1 1222 | batch_dict = defaultdict(list) 1223 | data_dict = defaultdict(list) 1224 | label_dict = defaultdict(list) 1225 | label_data_dict = defaultdict(set) 1226 | 1227 | for j in range(len(source_iters2)): 1228 | try: 1229 | batch_dict[j] = next(source_iters2[j]) 1230 | except: 1231 | source_iters2[j] = iter(source_loader[j]) 1232 | batch_dict[j] = next(source_iters2[j]) 1233 | index = 0 1234 | for o in batch_dict[j][1]: 1235 | cur_label = o[0].item() 1236 | data_dict[j].append(batch_dict[j][0][index]) 1237 | label_dict[j].append(cur_label) 1238 | index += 1 1239 | for j in range(len(source_iters)): 1240 | source_data, source_label = batch_dict[j] 1241 | if cuda: 1242 | source_data = source_data.cuda() 1243 | source_label = source_label.cuda() 1244 | data_set_all += len(source_label) 1245 | optimizer_FineTuning.zero_grad() 1246 | x_pred, x_logits, cls_loss = fineTuneModel(source_data, source_label) 1247 | cls_loss.backward() 1248 | optimizer_FineTuning.step() 1249 | _, pred = torch.max(x_pred, dim=1) 1250 | count += pred.eq(source_label.squeeze().data.view_as(pred)).sum() 1251 | end_time = time.time() 1252 | epoch_time = end_time - start_time 1253 | print("The time required for one fine-tuning epoch is:", epoch_time, "second") 1254 | print("data set amount: " + str(data_set_all)) 1255 | acc = float(count) / data_set_all 1256 | writer.add_scalars('subject: ' + str(one_subject + 1) + ' ' + 'train DMMR/loss', 1257 | {'cls_loss': cls_loss.data}, epoch + 1) 1258 | writer.add_scalar('subject: ' + str(one_subject + 1) + ' ' + 'train DMMR/train accuracy', acc, epoch + 1) 1259 | print("acc: " + str(acc)) 1260 | 1261 | testModel = DMMRTestModel(fineTuneModel) 1262 | acc_DMMR = testDMMR(data_loader_dict["test_loader"], testModel, cuda, args.batch_size) 1263 | print("acc_DMMR: " + str(acc_DMMR)) 1264 | writer.add_scalars('subject: ' + str(one_subject + 1) + ' ' + 'test DMMR/test acc', 1265 | {'test acc': acc_DMMR}, epoch + 1) 1266 | if acc_DMMR > 
acc_final: 1267 | acc_final = acc_DMMR 1268 | best_pretrain_model = copy.deepcopy(preTrainModel.state_dict()) 1269 | best_tune_model = copy.deepcopy(fineTuneModel.state_dict()) 1270 | best_test_model = copy.deepcopy(testModel.state_dict()) 1271 | modelDir = "model/" + args.way + "/" + args.index + "/" 1272 | try: 1273 | os.makedirs(modelDir) 1274 | except: 1275 | pass 1276 | torch.save(best_pretrain_model, modelDir + str(one_subject) + '_pretrain_model.pth') 1277 | torch.save(best_tune_model, modelDir + str(one_subject) + '_tune_model.pth') 1278 | torch.save(best_test_model, modelDir + str(one_subject) + '_test_model.pth') 1279 | return acc_final 1280 | 1281 | ############## T-SNE plots ############## 1282 | class FeatureVisualize(object): 1283 | ''' 1284 | Visualize features by TSNE 1285 | ''' 1286 | 1287 | def __init__(self, features, labels): 1288 | ''' 1289 | features: (m,n) 1290 | labels: (m,) 1291 | ''' 1292 | self.features = features 1293 | self.labels = labels 1294 | 1295 | def plot_tsne(self, save_filename, save_eps=False): 1296 | ''' Plot TSNE figure. Set save_eps=True if you want to save a .eps file. 
1297 | ''' 1298 | tsne = TSNE(n_components=2, init='pca', random_state=0) 1299 | features = tsne.fit_transform(self.features) 1300 | x_min, x_max = np.min(features, 0), np.max(features, 0) 1301 | data = (features - x_min) / (x_max - x_min) 1302 | del features 1303 | for i in range(data.shape[0]): 1304 | colors = plt.cm.tab20.colors 1305 | plt.scatter(data[i, 0], data[i, 1], color=colors[self.labels[i]]) 1306 | plt.colorbar() 1307 | plt.xticks([]) 1308 | plt.yticks([]) 1309 | plt.title('T-SNE') 1310 | if save_eps: 1311 | plt.savefig('tsne.eps', dpi=600, format='eps') 1312 | plt.savefig(save_filename, dpi=600) 1313 | plt.show() 1314 | def TSNEForDMMR(data_loader_dict, cuda, args): 1315 | source_loader = data_loader_dict['source_loader'] 1316 | # The pre-training phase 1317 | target_loader = data_loader_dict["test_loader"] 1318 | preTrainModel = DMMRPreTrainingModel(cuda, 1319 | number_of_source=len(source_loader), 1320 | number_of_category=args.cls_classes, 1321 | batch_size=args.batch_size, 1322 | time_steps=args.time_steps) 1323 | #load the pretrained model 1324 | preTrainModel.load_state_dict(torch.load("T-SNE/model/1_pretrain_model.pth", map_location='cpu')) 1325 | preTrainModel.eval() 1326 | pretrainReturnFeature = ModelReturnFeatures(preTrainModel, time_steps=args.time_steps) 1327 | fineTuneModel = DMMRFineTuningModel(cuda, preTrainModel, number_of_source=len(source_loader), 1328 | number_of_category=args.cls_classes, 1329 | batch_size=args.batch_size, 1330 | time_steps=args.time_steps) 1331 | # load the fine-tuned model 1332 | fineTuneModel.load_state_dict(torch.load("T-SNE/model/1_tune_model.pth", map_location='cpu')) 1333 | fineTuneModel.eval() 1334 | fineTuneModelReturnFeauters = ModelReturnFeatures(fineTuneModel, time_steps=args.time_steps) 1335 | fineTuneModelReturnFeauters.eval() 1336 | 1337 | source_iters = [] 1338 | for i in range(len(source_loader)): 1339 | source_iters.append(iter(source_loader[i])) 1340 | 1341 | origin_features_list = [] 1342 | 
origin_subject_id_list = [] 1343 | label_list = [] 1344 | pretrain_shared_features_list = [] 1345 | shared_features_list = [] 1346 | for i in range(1, 2): 1347 | for j in range(len(source_iters)): 1348 | try: 1349 | source_data, source_label = next(source_iters[j]) 1350 | except: 1351 | source_iters[j] = iter(source_loader[j]) 1352 | source_data, source_label = next(source_iters[j]) 1353 | subject_id = torch.ones(args.batch_size) 1354 | subject_id = subject_id * j 1355 | subject_id = subject_id.long() 1356 | 1357 | _, pretrain_shared_feature = pretrainReturnFeature(source_data) 1358 | _, shared_feature = fineTuneModelReturnFeauters(source_data) 1359 | 1360 | num_samples = 50 1361 | # 50 samples are taken from each individual subject data 1362 | source_data_narray = source_data.numpy() 1363 | label_data_narray = source_label.squeeze().numpy() 1364 | # Reshape for sampling 1365 | source_data_narray = source_data_narray.reshape(512, 30 * 310) 1366 | # Randomly select 50 samples from it to obtain a tensor of size (50, 310). 
1367 | random_indices = np.random.choice(source_data_narray.shape[0], num_samples, replace=False) 1368 | source_data_narray_50 = source_data_narray[random_indices] 1369 | subject_narray = np.full((num_samples,), j) 1370 | label_data_narray_50 = label_data_narray[random_indices] 1371 | #origin feature 1372 | origin_features_list.append(source_data_narray_50) 1373 | origin_subject_id_list.append(subject_narray) 1374 | label_list.append(label_data_narray_50) 1375 | 1376 | # pretrained feature 1377 | pretrain_shared_feature_narray = pretrain_shared_feature.detach().numpy() 1378 | pretrain_shared_feature_narray_50 = pretrain_shared_feature_narray[random_indices] 1379 | pretrain_shared_features_list.append(pretrain_shared_feature_narray_50) 1380 | #fine-tuned feature 1381 | shared_feature_narray = shared_feature.detach().numpy() 1382 | shared_feature_narray_50 = shared_feature_narray[random_indices] 1383 | shared_features_list.append(shared_feature_narray_50) 1384 | 1385 | #generate target data 1386 | target_data, target_label = next(iter(target_loader)) 1387 | _, target_pretrain_shared_feature = pretrainReturnFeature(target_data) 1388 | _, target_shared_feature = fineTuneModelReturnFeauters(target_data) 1389 | target_data_narray = target_data.numpy() 1390 | target_label = target_label.squeeze().numpy() 1391 | target_data_narray = target_data_narray.reshape(512, 30 * 310) 1392 | random_indices_target = np.random.choice(target_data_narray.shape[0], num_samples, replace=False) 1393 | target_data_narray_50 = target_data_narray[random_indices_target] 1394 | target_subject_id = np.full((num_samples,), 14) 1395 | target_label_narray_50 = target_label[random_indices] 1396 | 1397 | 1398 | #add target subject data 1399 | origin_features_list.append(target_data_narray_50) 1400 | origin_subject_id_list.append(target_subject_id) 1401 | label_list.append(target_label_narray_50) 1402 | 1403 | target_pretrain_shared_feature_narray = target_pretrain_shared_feature.detach().numpy() 1404 
| target_pretrain_shared_feature_narray_50 = target_pretrain_shared_feature_narray[random_indices] 1405 | pretrain_shared_features_list.append(target_pretrain_shared_feature_narray_50) 1406 | 1407 | target_shared_feature_narray = target_shared_feature.detach().numpy() 1408 | target_shared_feature_narray_50 = target_shared_feature_narray[random_indices] 1409 | shared_features_list.append(target_shared_feature_narray_50) 1410 | 1411 | 1412 | #concat for later norm 1413 | origin_stacked_feature = np.concatenate(origin_features_list, axis=0) 1414 | stacked_subject_id = np.concatenate(origin_subject_id_list, axis=0) 1415 | stacked_label = np.concatenate(label_list, axis=0) 1416 | 1417 | # T-SNE 1418 | #origin data 1419 | vis_pretrain_shared = FeatureVisualize(origin_stacked_feature, stacked_subject_id) 1420 | vis_pretrain_shared.plot_tsne('T-SNE/plot/origin_subject.jpg',save_eps=False) 1421 | vis_pretrain_shared = FeatureVisualize(origin_stacked_feature, stacked_label) 1422 | vis_pretrain_shared.plot_tsne("T-SNE/plot/origin_label.jpg",save_eps=False) 1423 | 1424 | # pretrained feature 1425 | pretrain_shared_stacked_feature = np.concatenate(pretrain_shared_features_list, axis=0) 1426 | vis_pretrain_shared = FeatureVisualize(pretrain_shared_stacked_feature, stacked_subject_id) 1427 | vis_pretrain_shared.plot_tsne('T-SNE/plot/pretrain_subject.jpg',save_eps=False) 1428 | vis_pretrain_shared = FeatureVisualize(pretrain_shared_stacked_feature, stacked_label) 1429 | vis_pretrain_shared.plot_tsne("T-SNE/plot/pretrain_label.jpg",save_eps=False) 1430 | # fine tuned data 1431 | shared_stacked_feature = np.concatenate(shared_features_list, axis=0) 1432 | vis_shared = FeatureVisualize(shared_stacked_feature, stacked_subject_id) 1433 | vis_shared.plot_tsne("T-SNE/plot/tune_subject.jpg",save_eps=False) 1434 | vis_shared_label = FeatureVisualize(shared_stacked_feature, stacked_label) 1435 | vis_shared_label.plot_tsne("T-SNE/plot/tune_label.jpg",save_eps=False) 1436 | return 0 
--------------------------------------------------------------------------------