├── .DS_Store
├── .idea
│   ├── .gitignore
│   ├── PPDA_PLUS.iml
│   ├── inspectionProfiles
│   │   └── Project_Default.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── GradientReverseLayer.py
├── LICENSE
├── Readme.md
├── T-SNE
│   ├── generatePlotByTSNE.py
│   ├── model
│   │   ├── 1_pretrain_model.pth
│   │   └── 1_tune_model.pth
│   └── plot
│       ├── origin_label.jpg
│       ├── origin_subject.jpg
│       ├── pretrain_label.jpg
│       ├── pretrain_subject.jpg
│       ├── tune_label.jpg
│       └── tune_subject.jpg
├── __pycache__
│   ├── GradientReverseLayer.cpython-36.pyc
│   ├── model.cpython-36.pyc
│   ├── preprocess.cpython-36.pyc
│   ├── test.cpython-36.pyc
│   └── train.cpython-36.pyc
├── ablation
│   ├── withoutBothMixAndNoise.py
│   ├── withoutNoise.py
│   └── witoutMix.py
├── main.py
├── model.py
├── noiseInjectionMethods
│   ├── Dropout.py
│   ├── channelsShuffling.py
│   ├── maskChannels.py
│   └── maskTimeSteps.py
├── preprocess.py
├── requirements.txt
├── test.py
└── train.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CodeBreathing/DMMR/5fd5e6ebf1a14d7afa1c07213a71aa97980b5b8b/.DS_Store
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/PPDA_PLUS.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/GradientReverseLayer.py:
--------------------------------------------------------------------------------
1 | from torch.autograd import Function
2 | 
3 | # gradient reversal
4 | class ReverseLayerF(Function):
5 |     @staticmethod
6 |     def forward(ctx, x, m):
7 |         ctx.m = m
8 |         return x.view_as(x)
9 | 
10 |     @staticmethod
11 |     def backward(ctx, grad_output):
12 |         output = grad_output.neg() * ctx.m
13 |         return output, None
--------------------------------------------------------------------------------
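
Note: `ReverseLayerF` is the standard gradient reversal layer (GRL) from DANN: the forward pass is the identity, while the backward pass negates the incoming gradient and scales it by `m`, which lets the shared encoder be trained adversarially against the subject classifier defined in model.py. A minimal sketch of its behavior (illustration only, not part of the repository):

```python
import torch
from GradientReverseLayer import ReverseLayerF

x = torch.ones(2, 3, requires_grad=True)
y = ReverseLayerF.apply(x, 0.5)  # forward pass: identity
y.sum().backward()
print(x.grad)                    # all entries are -0.5: gradient negated and scaled by m
```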
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/Readme.md:
--------------------------------------------------------------------------------
1 | # DMMR
2 | This is the official PyTorch implementation of our AAAI'24 paper "DMMR: Cross-Subject Domain Generalization for EEG-Based Emotion Recognition via Denoising Mixed Mutual Reconstruction".
3 | [Paper link](https://ojs.aaai.org/index.php/AAAI/article/view/27819)
4 |
5 | ## Datasets
6 | The publicly available datasets (SEED and SEED-IV) can be downloaded from https://bcmi.sjtu.edu.cn/home/seed/index.html
7 |
8 | To facilitate data retrieval, the data from the first session of all subjects is used for both datasets. The file structure of the datasets should look like:
9 | ```
10 | ExtractedFeatures/
11 | 1/
12 | eeg_feature_smooth/
13 | 1/
14 | ```
15 | Kindly change the file paths in main.py to point to your local copies of the datasets.
16 |
17 | ## Usage of DMMR
18 | Run `python main.py`; the results will be recorded in TensorBoard.
19 | Set the `dataset_name` argument to `seed3` for the SEED dataset or `seed4` for the SEED-IV dataset.
20 |
21 | ## Ablation Studies
22 | Run `python ablation/witoutMix.py`
23 | Run `python ablation/withoutNoise.py`
24 | Run `python ablation/withoutBothMixAndNoise.py`
25 |
26 | ## Other Noise Injection Methods
27 | Run `python noiseInjectionMethods/maskChannels.py`
28 | Run `python noiseInjectionMethods/maskTimeSteps.py`
29 | Run `python noiseInjectionMethods/channelsShuffling.py`
30 | Run `python noiseInjectionMethods/Dropout.py`
31 |
32 | ## Plot with TSNE
33 | Run `python T-SNE/generatePlotByTSNE.py`
34 |
35 | ## Citation
36 | If you find our work useful for your research, please cite:
37 | ```
38 | @inproceedings{wang2024dmmr,
39 | title={DMMR: Cross-Subject Domain Generalization for EEG-Based Emotion Recognition via Denoising Mixed Mutual Reconstruction},
40 | author={Wang, Yiming and Zhang, Bin and Tang, Yujiao},
41 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
42 | volume={38},
43 | number={1},
44 | pages={628--636},
45 | year={2024}
46 | }
47 | ```
48 | We thank the following repositories for providing helpful functions used in our work:
49 | [MS-MDA](https://github.com/VoiceBeer/MS-MDA)
50 | [DANN](https://github.com/fungtion/DANN)
51 |
52 |
53 |
--------------------------------------------------------------------------------
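
Note on "change the file paths": the dataset locations are hard-coded as relative paths in main.py (and duplicated in the ablation scripts), so point these two lines at your local copies:

```python
args.seed3_path = "../eeg_data/ExtractedFeatures/"   # SEED, used when --dataset_name seed3
args.seed4_path = "../eeg_data/eeg_feature_smooth/"  # SEED-IV, used when --dataset_name seed4
```

For example, `python main.py --dataset_name seed4` selects the SEED-IV configuration (4 classes, 10 time steps, batch size 256, 400 pre-training epochs).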
/T-SNE/generatePlotByTSNE.py:
--------------------------------------------------------------------------------
1 | from preprocess import getDataLoaders
2 | import math
3 | import argparse
4 | from train import *
5 | import random
6 | import os
7 | import numpy as np
8 | 
9 | os.environ["CUDA_VISIBLE_DEVICES"] = '3'
10 | 
11 | def set_seed(seed=3):
12 |     random.seed(seed)
13 |     np.random.seed(seed)
14 |     os.environ['PYTHONHASHSEED'] = str(seed)
15 |     torch.manual_seed(seed)
16 |     torch.cuda.manual_seed(seed)
17 |     torch.cuda.manual_seed_all(seed)
18 |     torch.backends.cudnn.deterministic = True
19 | 
20 | def main(data_loader_dict, args, cuda, seed=3):
21 |     set_seed(seed)
22 |     acc = TSNEForDMMR(data_loader_dict, cuda, args)
23 |     return acc
24 | 
25 | if __name__ == '__main__':
26 |     cuda = torch.cuda.is_available()
27 |     parser = argparse.ArgumentParser(description='DMMR')
28 | 
29 |     # config of experiment
30 |     parser.add_argument("--way", type=str, default='TSNE', help="name of current way")
31 |     parser.add_argument("--index", type=str, default='plot features', help="tensorboard index")
32 |     parser.add_argument("--num_workers_train", type=int, default=0, help="number of workers for the training DataLoader")
33 |     parser.add_argument("--num_workers_test", type=int, default=0, help="number of workers for the test DataLoader")
34 |     parser.add_argument("--path", type=str, default="../eeg_data/ExtractedFeatures/", help="path to the dataset")
35 | 
36 |     # config of dataset
37 |     parser.add_argument("--dataset_name", type=str, nargs='?', default='seed3', help="the dataset name, supporting seed3 and seed4")
38 |     parser.add_argument("--session", type=str, nargs='?', default='1', help="selected session")
39 |     parser.add_argument("--subjects", type=int, choices=[15], default=15, help="the number of all subjects")
40 |     parser.add_argument("--cls_classes", type=int, choices=[3], default=3, help="number of emotion classes")
41 |     parser.add_argument("--dim", type=int, default=310, help="dim of input")
42 | 
43 |     # config of net
44 |     parser.add_argument("--input_dim", type=int, default=310, help="input dim is the same as the sample's last dim")
45 |     parser.add_argument("--hid_dim", type=int, default=64, help="hidden dim of the LSTM")
46 |     parser.add_argument("--n_layers", type=int, default=2, help="num of layers of the LSTM")
47 |     parser.add_argument("--batch_size", type=int, default=512, help="batch size")
48 |     parser.add_argument("--time_steps", type=int, choices=[30], default=30, help="window size")
49 |     parser.add_argument("--epoch_preTraining", type=int, default=300, help="epochs of the pre-training phase")
50 |     parser.add_argument("--epoch_fineTuning", type=int, default=500, help="epochs of the fine-tuning phase")
51 |     parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
52 |     parser.add_argument("--weight_decay", type=float, default=0.0005, help="weight decay")
53 |     parser.add_argument("--beta", type=float, default=0.05, help="balancing hyperparameter in the loss of the pre-training phase")
54 | 
55 |     args = parser.parse_args()
56 |     args.source_subjects = args.subjects - 1
57 |     net_config = {"fts": args.dim, "cls": args.cls_classes}
58 |     optim_config = {"lr": args.lr, "weight_decay": args.weight_decay}
59 | 
60 |     acc_list = []
61 |     # use subject 1 as the target subject for testing
62 |     one_subject = 1
63 |     # 1. data preparation
64 |     source_loaders, test_loader = getDataLoaders(one_subject, args)
65 |     data_loader_dict = {"source_loader": source_loaders, "test_loader": test_loader}
66 |     # 2. call main
67 |     acc = main(data_loader_dict, args, cuda)
--------------------------------------------------------------------------------
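
Note: `TSNEForDMMR` is defined in train.py, which is not included in this listing. A minimal sketch of the kind of t-SNE projection it presumably performs on the 64-dimensional encoder features, using scikit-learn (illustrative only; the function and variable names here are hypothetical):

```python
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

def plot_tsne(features: np.ndarray, labels: np.ndarray, out_path: str) -> None:
    # features: (n_samples, 64) encoder outputs; labels: (n_samples,) integer ids
    emb = TSNE(n_components=2, random_state=3).fit_transform(features)
    plt.figure()
    plt.scatter(emb[:, 0], emb[:, 1], c=labels, s=5, cmap="tab10")
    plt.savefig(out_path)  # e.g. "plot/pretrain_label.jpg"
```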
/T-SNE/model/1_pretrain_model.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CodeBreathing/DMMR/5fd5e6ebf1a14d7afa1c07213a71aa97980b5b8b/T-SNE/model/1_pretrain_model.pth
--------------------------------------------------------------------------------
/T-SNE/model/1_tune_model.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CodeBreathing/DMMR/5fd5e6ebf1a14d7afa1c07213a71aa97980b5b8b/T-SNE/model/1_tune_model.pth
--------------------------------------------------------------------------------
/T-SNE/plot/origin_label.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CodeBreathing/DMMR/5fd5e6ebf1a14d7afa1c07213a71aa97980b5b8b/T-SNE/plot/origin_label.jpg
--------------------------------------------------------------------------------
/T-SNE/plot/origin_subject.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CodeBreathing/DMMR/5fd5e6ebf1a14d7afa1c07213a71aa97980b5b8b/T-SNE/plot/origin_subject.jpg
--------------------------------------------------------------------------------
/T-SNE/plot/pretrain_label.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CodeBreathing/DMMR/5fd5e6ebf1a14d7afa1c07213a71aa97980b5b8b/T-SNE/plot/pretrain_label.jpg
--------------------------------------------------------------------------------
/T-SNE/plot/pretrain_subject.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CodeBreathing/DMMR/5fd5e6ebf1a14d7afa1c07213a71aa97980b5b8b/T-SNE/plot/pretrain_subject.jpg
--------------------------------------------------------------------------------
/T-SNE/plot/tune_label.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CodeBreathing/DMMR/5fd5e6ebf1a14d7afa1c07213a71aa97980b5b8b/T-SNE/plot/tune_label.jpg
--------------------------------------------------------------------------------
/T-SNE/plot/tune_subject.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CodeBreathing/DMMR/5fd5e6ebf1a14d7afa1c07213a71aa97980b5b8b/T-SNE/plot/tune_subject.jpg
--------------------------------------------------------------------------------
/__pycache__/GradientReverseLayer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CodeBreathing/DMMR/5fd5e6ebf1a14d7afa1c07213a71aa97980b5b8b/__pycache__/GradientReverseLayer.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CodeBreathing/DMMR/5fd5e6ebf1a14d7afa1c07213a71aa97980b5b8b/__pycache__/model.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/preprocess.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CodeBreathing/DMMR/5fd5e6ebf1a14d7afa1c07213a71aa97980b5b8b/__pycache__/preprocess.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/test.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CodeBreathing/DMMR/5fd5e6ebf1a14d7afa1c07213a71aa97980b5b8b/__pycache__/test.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/train.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CodeBreathing/DMMR/5fd5e6ebf1a14d7afa1c07213a71aa97980b5b8b/__pycache__/train.cpython-36.pyc
--------------------------------------------------------------------------------
/ablation/withoutBothMixAndNoise.py:
--------------------------------------------------------------------------------
1 | from preprocess import getDataLoaders
2 | import math
3 | from torch.utils.tensorboard import SummaryWriter
4 | import argparse
5 | from train import *
6 | import random
7 | import os
8 | import numpy as np
9 | 
10 | os.environ["CUDA_VISIBLE_DEVICES"] = '3'
11 | 
12 | def set_seed(seed=3):
13 |     random.seed(seed)
14 |     np.random.seed(seed)
15 |     os.environ['PYTHONHASHSEED'] = str(seed)
16 |     torch.manual_seed(seed)
17 |     torch.cuda.manual_seed(seed)
18 |     torch.cuda.manual_seed_all(seed)
19 |     torch.backends.cudnn.deterministic = True
20 | 
21 | def main(data_loader_dict, args, optim_config, cuda, writer, one_subject, seed=3):
22 |     set_seed(seed)
23 |     if args.dataset_name == 'seed3':
24 |         iteration = 7
25 |     elif args.dataset_name == 'seed4':
26 |         iteration = 3
27 |     acc = trainDMMR_WithoutBothMixAndNoise(data_loader_dict, optim_config, cuda, args, iteration, writer, one_subject)
28 |     return acc
29 | 
30 | if __name__ == '__main__':
31 |     cuda = torch.cuda.is_available()
32 |     parser = argparse.ArgumentParser(description='DMMR')
33 | 
34 |     # config of experiment
35 |     parser.add_argument("--way", type=str, default='DMMR/seed3', help="name of current way")
36 |     parser.add_argument("--index", type=str, default='withoutBothMixAndNoise', help="tensorboard index")
37 | 
38 |     # config of dataset
39 |     parser.add_argument("--dataset_name", type=str, nargs='?', default='seed3', help="the dataset name, supporting seed3 and seed4")
40 |     parser.add_argument("--session", type=str, nargs='?', default='1', help="selected session")
41 |     parser.add_argument("--subjects", type=int, choices=[15], default=15, help="the number of all subjects")
42 |     parser.add_argument("--dim", type=int, default=310, help="dim of input")
43 | 
44 |     # config of DMMR
45 |     parser.add_argument("--input_dim", type=int, default=310, help="input dim is the same as the sample's last dim")
46 |     parser.add_argument("--hid_dim", type=int, default=64, help="hidden dim of the LSTM")
47 |     parser.add_argument("--n_layers", type=int, default=1, help="num of layers of the LSTM")
48 |     parser.add_argument("--epoch_fineTuning", type=int, default=500, help="epochs of the fine-tuning phase")
49 |     parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
50 |     parser.add_argument("--weight_decay", type=float, default=0.0005, help="weight decay")
51 |     parser.add_argument("--beta", type=float, default=0.05, help="balancing hyperparameter in the loss of the pre-training phase")
52 | 
53 | 
54 |     args = parser.parse_args()
55 |     args.source_subjects = args.subjects - 1
56 |     args.seed3_path = "../eeg_data/ExtractedFeatures/"
57 |     args.seed4_path = "../eeg_data/eeg_feature_smooth/"
58 |     if cuda:
59 |         args.num_workers_train = 4
60 |         args.num_workers_test = 2
61 |     else:
62 |         args.num_workers_train = 0
63 |         args.num_workers_test = 0
64 |     if args.dataset_name == "seed3":
65 |         args.path = args.seed3_path
66 |         args.cls_classes = 3
67 |         args.time_steps = 30
68 |         args.batch_size = 512  # batch size
69 |         args.epoch_preTraining = 300  # epochs of the pre-training phase
70 |     elif args.dataset_name == "seed4":
71 |         args.path = args.seed4_path
72 |         args.cls_classes = 4
73 |         args.time_steps = 10
74 |         args.batch_size = 256  # batch size
75 |         args.epoch_preTraining = 400  # epochs of the pre-training phase
76 |     else:
77 |         print("need to define the input dataset")
78 |     optim_config = {"lr": args.lr, "weight_decay": args.weight_decay}
79 |     # leave-one-subject-out cross-validation
80 |     acc_list = []
81 |     writer = SummaryWriter("data/session" + args.session + "/" + args.way + "/" + args.index)
82 |     for one_subject in range(0, args.subjects):
83 |         # 1. data preparation
84 |         source_loaders, test_loader = getDataLoaders(one_subject, args)
85 |         data_loader_dict = {"source_loader": source_loaders, "test_loader": test_loader}
86 |         # 2. main
87 |         acc = main(data_loader_dict, args, optim_config, cuda, writer, one_subject)
88 |         writer.add_scalars('single experiment acc: ',
89 |                            {'test acc': acc}, one_subject + 1)
90 |         writer.flush()
91 |         acc_list.append(acc)
92 |     writer.add_text('final acc avg', str(np.average(acc_list)))
93 |     writer.add_text('final acc std', str(np.std(acc_list)))
94 |     acc_list_str = [str(x) for x in acc_list]
95 |     writer.add_text('final each acc', ",".join(acc_list_str))
96 |     writer.add_scalars('final experiment acc scala: /avg',
97 |                        {'test acc': np.average(acc_list)})
98 |     writer.add_scalars('final experiment acc scala: /std',
99 |                        {'test acc': np.std(acc_list)})
100 |     writer.close()
101 | 
--------------------------------------------------------------------------------
/ablation/withoutNoise.py:
--------------------------------------------------------------------------------
1 | from preprocess import getDataLoaders
2 | import math
3 | from torch.utils.tensorboard import SummaryWriter
4 | import argparse
5 | from train import *
6 | import random
7 | import os
8 | import numpy as np
9 | 
10 | os.environ["CUDA_VISIBLE_DEVICES"] = '3'
11 | 
12 | def set_seed(seed=3):
13 |     random.seed(seed)
14 |     np.random.seed(seed)
15 |     os.environ['PYTHONHASHSEED'] = str(seed)
16 |     torch.manual_seed(seed)
17 |     torch.cuda.manual_seed(seed)
18 |     torch.cuda.manual_seed_all(seed)
19 |     torch.backends.cudnn.deterministic = True
20 | 
21 | def main(data_loader_dict, args, optim_config, cuda, writer, one_subject, seed=3):
22 |     set_seed(seed)
23 |     if args.dataset_name == 'seed3':
24 |         iteration = 7
25 |     elif args.dataset_name == 'seed4':
26 |         iteration = 3
27 |     acc = trainDMMR_WithoutNoise(data_loader_dict, optim_config, cuda, args, iteration, writer, one_subject)
28 |     return acc
29 | 
30 | if __name__ == '__main__':
31 |     cuda = torch.cuda.is_available()
32 |     parser = argparse.ArgumentParser(description='DMMR')
33 | 
34 |     # config of experiment
35 |     parser.add_argument("--way", type=str, default='DMMR/seed3', help="name of current way")
36 |     parser.add_argument("--index", type=str, default='0', help="tensorboard index")
37 | 
38 |     # config of dataset
39 |     parser.add_argument("--dataset_name", type=str, nargs='?', default='seed3', help="the dataset name, supporting seed3 and seed4")
40 |     parser.add_argument("--session", type=str, nargs='?', default='1', help="selected session")
41 |     parser.add_argument("--subjects", type=int, choices=[15], default=15, help="the number of all subjects")
42 |     parser.add_argument("--dim", type=int, default=310, help="dim of input")
43 | 
44 |     # config of DMMR
45 |     parser.add_argument("--input_dim", type=int, default=310, help="input dim is the same as the sample's last dim")
46 |     parser.add_argument("--hid_dim", type=int, default=64, help="hidden dim of the LSTM")
47 |     parser.add_argument("--n_layers", type=int, default=1, help="num of layers of the LSTM")
48 |     parser.add_argument("--epoch_fineTuning", type=int, default=500, help="epochs of the fine-tuning phase")
49 |     parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
50 |     parser.add_argument("--weight_decay", type=float, default=0.0005, help="weight decay")
51 |     parser.add_argument("--beta", type=float, default=0.05, help="balancing hyperparameter in the loss of the pre-training phase")
52 | 
53 | 
54 |     args = parser.parse_args()
55 |     args.source_subjects = args.subjects - 1
56 |     args.seed3_path = "../eeg_data/ExtractedFeatures/"
57 |     args.seed4_path = "../eeg_data/eeg_feature_smooth/"
58 |     if cuda:
59 |         args.num_workers_train = 4
60 |         args.num_workers_test = 2
61 |     else:
62 |         args.num_workers_train = 0
63 |         args.num_workers_test = 0
64 |     if args.dataset_name == "seed3":
65 |         args.path = args.seed3_path
66 |         args.cls_classes = 3
67 |         args.time_steps = 30
68 |         args.batch_size = 512  # batch size
69 |         args.epoch_preTraining = 300  # epochs of the pre-training phase
70 |     elif args.dataset_name == "seed4":
71 |         args.path = args.seed4_path
72 |         args.cls_classes = 4
73 |         args.time_steps = 10
74 |         args.batch_size = 256  # batch size
75 |         args.epoch_preTraining = 400  # epochs of the pre-training phase
76 |     else:
77 |         print("need to define the input dataset")
78 |     optim_config = {"lr": args.lr, "weight_decay": args.weight_decay}
79 |     # leave-one-subject-out cross-validation
80 |     acc_list = []
81 |     writer = SummaryWriter("data/session" + args.session + "/" + args.way + "/" + args.index)
82 |     for one_subject in range(0, args.subjects):
83 |         # 1. data preparation
84 |         source_loaders, test_loader = getDataLoaders(one_subject, args)
85 |         data_loader_dict = {"source_loader": source_loaders, "test_loader": test_loader}
86 |         # 2. main
87 |         acc = main(data_loader_dict, args, optim_config, cuda, writer, one_subject)
88 |         writer.add_scalars('single experiment acc: ',
89 |                            {'test acc': acc}, one_subject + 1)
90 |         writer.flush()
91 |         acc_list.append(acc)
92 |     writer.add_text('final acc avg', str(np.average(acc_list)))
93 |     writer.add_text('final acc std', str(np.std(acc_list)))
94 |     acc_list_str = [str(x) for x in acc_list]
95 |     writer.add_text('final each acc', ",".join(acc_list_str))
96 |     writer.add_scalars('final experiment acc scala: /avg',
97 |                        {'test acc': np.average(acc_list)})
98 |     writer.add_scalars('final experiment acc scala: /std',
99 |                        {'test acc': np.std(acc_list)})
100 |     writer.close()
101 | 
101 |
--------------------------------------------------------------------------------
/ablation/witoutMix.py:
--------------------------------------------------------------------------------
1 | from preprocess import getDataLoaders
2 | import math
3 | from torch.utils.tensorboard import SummaryWriter
4 | import argparse
5 | from train import *
6 | import random
7 | import os
8 | import numpy as np
9 | 
10 | os.environ["CUDA_VISIBLE_DEVICES"] = '3'
11 | 
12 | def set_seed(seed=3):
13 |     random.seed(seed)
14 |     np.random.seed(seed)
15 |     os.environ['PYTHONHASHSEED'] = str(seed)
16 |     torch.manual_seed(seed)
17 |     torch.cuda.manual_seed(seed)
18 |     torch.cuda.manual_seed_all(seed)
19 |     torch.backends.cudnn.deterministic = True
20 | 
21 | def main(data_loader_dict, args, optim_config, cuda, writer, one_subject, seed=3):
22 |     set_seed(seed)
23 |     if args.dataset_name == 'seed3':
24 |         iteration = 7
25 |     elif args.dataset_name == 'seed4':
26 |         iteration = 3
27 |     acc = trainDMMR_WithoutMix(data_loader_dict, optim_config, cuda, args, iteration, writer, one_subject)
28 |     return acc
29 | 
30 | if __name__ == '__main__':
31 |     cuda = torch.cuda.is_available()
32 |     parser = argparse.ArgumentParser(description='DMMR')
33 | 
34 |     # config of experiment
35 |     parser.add_argument("--way", type=str, default='DMMR/seed3', help="name of current way")
36 |     parser.add_argument("--index", type=str, default='withoutMix', help="tensorboard index")
37 | 
38 |     # config of dataset
39 |     parser.add_argument("--dataset_name", type=str, nargs='?', default='seed3', help="the dataset name, supporting seed3 and seed4")
40 |     parser.add_argument("--session", type=str, nargs='?', default='1', help="selected session")
41 |     parser.add_argument("--subjects", type=int, choices=[15], default=15, help="the number of all subjects")
42 |     parser.add_argument("--dim", type=int, default=310, help="dim of input")
43 | 
44 |     # config of DMMR
45 |     parser.add_argument("--input_dim", type=int, default=310, help="input dim is the same as the sample's last dim")
46 |     parser.add_argument("--hid_dim", type=int, default=64, help="hidden dim of the LSTM")
47 |     parser.add_argument("--n_layers", type=int, default=1, help="num of layers of the LSTM")
48 |     parser.add_argument("--epoch_fineTuning", type=int, default=500, help="epochs of the fine-tuning phase")
49 |     parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
50 |     parser.add_argument("--weight_decay", type=float, default=0.0005, help="weight decay")
51 |     parser.add_argument("--beta", type=float, default=0.05, help="balancing hyperparameter in the loss of the pre-training phase")
52 | 
53 | 
54 |     args = parser.parse_args()
55 |     args.source_subjects = args.subjects - 1
56 |     args.seed3_path = "../eeg_data/ExtractedFeatures/"
57 |     args.seed4_path = "../eeg_data/eeg_feature_smooth/"
58 |     if cuda:
59 |         args.num_workers_train = 4
60 |         args.num_workers_test = 2
61 |     else:
62 |         args.num_workers_train = 0
63 |         args.num_workers_test = 0
64 |     if args.dataset_name == "seed3":
65 |         args.path = args.seed3_path
66 |         args.cls_classes = 3
67 |         args.time_steps = 30
68 |         args.batch_size = 512  # batch size
69 |         args.epoch_preTraining = 300  # epochs of the pre-training phase
70 |     elif args.dataset_name == "seed4":
71 |         args.path = args.seed4_path
72 |         args.cls_classes = 4
73 |         args.time_steps = 10
74 |         args.batch_size = 256  # batch size
75 |         args.epoch_preTraining = 400  # epochs of the pre-training phase
76 |     else:
77 |         print("need to define the input dataset")
78 |     optim_config = {"lr": args.lr, "weight_decay": args.weight_decay}
79 |     # leave-one-subject-out cross-validation
80 |     acc_list = []
81 |     writer = SummaryWriter("data/session" + args.session + "/" + args.way + "/" + args.index)
82 |     for one_subject in range(0, args.subjects):
83 |         # 1. data preparation
84 |         source_loaders, test_loader = getDataLoaders(one_subject, args)
85 |         data_loader_dict = {"source_loader": source_loaders, "test_loader": test_loader}
86 |         # 2. main
87 |         acc = main(data_loader_dict, args, optim_config, cuda, writer, one_subject)
88 |         writer.add_scalars('single experiment acc: ',
89 |                            {'test acc': acc}, one_subject + 1)
90 |         writer.flush()
91 |         acc_list.append(acc)
92 |     writer.add_text('final acc avg', str(np.average(acc_list)))
93 |     writer.add_text('final acc std', str(np.std(acc_list)))
94 |     acc_list_str = [str(x) for x in acc_list]
95 |     writer.add_text('final each acc', ",".join(acc_list_str))
96 |     writer.add_scalars('final experiment acc scala: /avg',
97 |                        {'test acc': np.average(acc_list)})
98 |     writer.add_scalars('final experiment acc scala: /std',
99 |                        {'test acc': np.std(acc_list)})
100 |     writer.close()
101 | 
101 |
--------------------------------------------------------------------------------
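
Note: the three ablation scripts differ from main.py only in the training function they call (trainDMMR_WithoutMix, trainDMMR_WithoutNoise, trainDMMR_WithoutBothMixAndNoise, all defined in train.py, which is not shown here). Relative to the full model in model.py, "without noise" presumably drops the timeStepsShuffle(x) call in the pre-training forward pass, and "without mix" presumably skips summing the decoder outputs into mixSubjectFeature before the second encoding pass. A hypothetical sketch of the latter's reconstruction step, assuming the loss is then computed directly from the first-stage reconstructions (illustration only, names are ours):

```python
def reconstruction_loss_without_mix(decoders, mse, shared_last_out, shared_hn,
                                    shared_cn, targets, time_steps):
    # Hypothetical "without mix" variant: each subject-specific decoder
    # reconstructs directly from the first-stage encoding; no mixed feature
    # is built and no second encoding pass is run (cf. DMMRPreTrainingModel).
    rec_loss = 0
    for decoder, target in zip(decoders, targets):
        x_out, *_ = decoder(shared_last_out, shared_hn, shared_cn, time_steps)
        rec_loss += mse(x_out, target)
    return rec_loss
```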
/main.py:
--------------------------------------------------------------------------------
1 | from preprocess import getDataLoaders
2 | import math
3 | from torch.utils.tensorboard import SummaryWriter
4 | import argparse
5 | from train import *
6 | import random
7 | import os
8 | import numpy as np
9 | 
10 | os.environ["CUDA_VISIBLE_DEVICES"] = '3'
11 | 
12 | def set_seed(seed=3):
13 |     random.seed(seed)
14 |     np.random.seed(seed)
15 |     os.environ['PYTHONHASHSEED'] = str(seed)
16 |     torch.manual_seed(seed)
17 |     torch.cuda.manual_seed(seed)
18 |     torch.cuda.manual_seed_all(seed)
19 |     torch.backends.cudnn.deterministic = True
20 | 
21 | def main(data_loader_dict, args, optim_config, cuda, writer, one_subject, seed=3):
22 |     set_seed(seed)
23 |     if args.dataset_name == 'seed3':
24 |         iteration = 7
25 |     elif args.dataset_name == 'seed4':
26 |         iteration = 3
27 |     acc = trainDMMR(data_loader_dict, optim_config, cuda, args, iteration, writer, one_subject)
28 |     return acc
29 | 
30 | if __name__ == '__main__':
31 |     cuda = torch.cuda.is_available()
32 |     parser = argparse.ArgumentParser(description='DMMR')
33 | 
34 |     # config of experiment
35 |     parser.add_argument("--way", type=str, default='DMMR/seed3', help="name of current way")
36 |     parser.add_argument("--index", type=str, default='0', help="tensorboard index")
37 | 
38 |     # config of dataset
39 |     parser.add_argument("--dataset_name", type=str, nargs='?', default='seed3', help="the dataset name, supporting seed3 and seed4")
40 |     parser.add_argument("--session", type=str, nargs='?', default='1', help="selected session")
41 |     parser.add_argument("--subjects", type=int, choices=[15], default=15, help="the number of all subjects")
42 |     parser.add_argument("--dim", type=int, default=310, help="dim of input")
43 | 
44 |     # config of DMMR
45 |     parser.add_argument("--input_dim", type=int, default=310, help="input dim is the same as the sample's last dim")
46 |     parser.add_argument("--hid_dim", type=int, default=64, help="hidden dim of the LSTM")
47 |     parser.add_argument("--n_layers", type=int, default=1, help="num of layers of the LSTM")
48 |     parser.add_argument("--epoch_fineTuning", type=int, default=500, help="epochs of the fine-tuning phase")
49 |     parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
50 |     parser.add_argument("--weight_decay", type=float, default=0.0005, help="weight decay")
51 |     parser.add_argument("--beta", type=float, default=0.05, help="balancing hyperparameter in the loss of the pre-training phase")
52 | 
53 | 
54 |     args = parser.parse_args()
55 |     args.source_subjects = args.subjects - 1
56 |     args.seed3_path = "../eeg_data/ExtractedFeatures/"
57 |     args.seed4_path = "../eeg_data/eeg_feature_smooth/"
58 |     if cuda:
59 |         args.num_workers_train = 4
60 |         args.num_workers_test = 2
61 |     else:
62 |         args.num_workers_train = 0
63 |         args.num_workers_test = 0
64 |     if args.dataset_name == "seed3":
65 |         args.path = args.seed3_path
66 |         args.cls_classes = 3
67 |         args.time_steps = 30
68 |         args.batch_size = 512  # batch size
69 |         args.epoch_preTraining = 300  # epochs of the pre-training phase
70 |     elif args.dataset_name == "seed4":
71 |         args.path = args.seed4_path
72 |         args.cls_classes = 4
73 |         args.time_steps = 10
74 |         args.batch_size = 256  # batch size
75 |         args.epoch_preTraining = 400  # epochs of the pre-training phase
76 |     else:
77 |         print("need to define the input dataset")
78 |     optim_config = {"lr": args.lr, "weight_decay": args.weight_decay}
79 |     # leave-one-subject-out cross-validation
80 |     acc_list = []
81 |     writer = SummaryWriter("data/session" + args.session + "/" + args.way + "/" + args.index)
82 |     for one_subject in range(0, args.subjects):
83 |         # 1. data preparation
84 |         source_loaders, test_loader = getDataLoaders(one_subject, args)
85 |         data_loader_dict = {"source_loader": source_loaders, "test_loader": test_loader}
86 |         # 2. main
87 |         acc = main(data_loader_dict, args, optim_config, cuda, writer, one_subject)
88 |         writer.add_scalars('single experiment acc: ',
89 |                            {'test acc': acc}, one_subject + 1)
90 |         writer.flush()
91 |         acc_list.append(acc)
92 |     writer.add_text('final acc avg', str(np.average(acc_list)))
93 |     writer.add_text('final acc std', str(np.std(acc_list)))
94 |     acc_list_str = [str(x) for x in acc_list]
95 |     writer.add_text('final each acc', ",".join(acc_list_str))
96 |     writer.add_scalars('final experiment acc scala: /avg',
97 |                        {'test acc': np.average(acc_list)})
98 |     writer.add_scalars('final experiment acc scala: /std',
99 |                        {'test acc': np.std(acc_list)})
100 |     writer.close()
101 | 
101 |
--------------------------------------------------------------------------------
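
Note: `trainDMMR` lives in train.py, which is not included in this listing. Judging from model.py below and the saved checkpoints under T-SNE/model/ (1_pretrain_model.pth, 1_tune_model.pth), it presumably drives a two-phase pipeline per held-out subject; a minimal sketch for orientation (hypothetical structure, actual training loops omitted):

```python
from model import DMMRPreTrainingModel, DMMRFineTuningModel, DMMRTestModel

def train_dmmr_sketch(args, cuda):
    # Phase 1: pre-training with the reconstruction loss (rec_loss) and the
    # subject-adversarial loss (sim_loss), balanced by args.beta
    pre = DMMRPreTrainingModel(cuda, number_of_source=args.source_subjects,
                               number_of_category=args.cls_classes,
                               batch_size=args.batch_size, time_steps=args.time_steps)
    # ... optimize for args.epoch_preTraining epochs ...
    # Phase 2: fine-tuning a fresh emotion classifier on top of the shared encoder
    tune = DMMRFineTuningModel(cuda, pre, number_of_source=args.source_subjects,
                               number_of_category=args.cls_classes,
                               batch_size=args.batch_size, time_steps=args.time_steps)
    # ... optimize cls_loss for args.epoch_fineTuning epochs ...
    return DMMRTestModel(tune)  # inference wrapper evaluated on the held-out subject
```

Results land under data/session<session>/<way>/<index> and can be inspected with `tensorboard --logdir data`.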
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from GradientReverseLayer import ReverseLayerF
5 | import random
6 | import copy
7 | 
8 | # The ABP module
9 | class Attention(nn.Module):
10 |     def __init__(self, cuda, input_dim):
11 |         super(Attention, self).__init__()
12 |         self.input_dim = input_dim
13 |         if cuda:
14 |             self.w_linear = nn.Parameter(torch.randn(input_dim, input_dim).cuda())
15 |             self.u_linear = nn.Parameter(torch.randn(input_dim).cuda())
16 |         else:
17 |             self.w_linear = nn.Parameter(torch.randn(input_dim, input_dim))
18 |             self.u_linear = nn.Parameter(torch.randn(input_dim))
19 | 
20 |     def forward(self, x, batch_size, time_steps):
21 |         x_reshape = torch.Tensor.reshape(x, [-1, self.input_dim])
22 |         attn_softmax = F.softmax(torch.mm(x_reshape, self.w_linear) + self.u_linear, 1)
23 |         res = torch.mul(attn_softmax, x_reshape)
24 |         res = torch.Tensor.reshape(res, [batch_size, time_steps, self.input_dim])
25 |         return res
26 |
27 | class LSTM(nn.Module):
28 |     def __init__(self, input_dim=310, output_dim=64, layers=2, location=-1):
29 |         super(LSTM, self).__init__()
30 |         self.lstm = nn.LSTM(input_dim, output_dim, num_layers=layers, batch_first=True)
31 |         self.location = location
32 |     def forward(self, x):
33 |         # self.lstm.flatten_parameters()
34 |         feature, (hn, cn) = self.lstm(x)
35 |         return feature[:, self.location, :], hn, cn
36 | 
37 | class Encoder(nn.Module):
38 |     def __init__(self, input_dim=310, hid_dim=64, n_layers=2):
39 |         super(Encoder, self).__init__()
40 |         self.theta = LSTM(input_dim, hid_dim, n_layers)
41 |     def forward(self, x):
42 |         x_h = self.theta(x)
43 |         return x_h
44 | 
45 | class Decoder(nn.Module):
46 |     def __init__(self, input_dim=310, hid_dim=64, n_layers=2, output_dim=310):
47 |         super(Decoder, self).__init__()
48 |         self.rnn = nn.LSTM(input_dim, hid_dim, n_layers)
49 |         self.fc_out = nn.Linear(hid_dim, output_dim)
50 |     def forward(self, input, hidden, cell, time_steps):
51 |         out = []
52 |         out1 = self.fc_out(input)
53 |         out.append(out1)
54 |         out1 = out1.unsqueeze(0)  # [batch size, output dim] to [1, batch size, output dim]
55 |         for i in range(time_steps - 1):
56 |             output, (hidden, cell) = self.rnn(out1,
57 |                                               (hidden, cell))  # output = [seq len, batch size, hid dim * n directions]
58 |             out_cur = self.fc_out(output.squeeze(0))  # prediction = [batch size, output dim]
59 |             out.append(out_cur)
60 |             out1 = out_cur.unsqueeze(0)
61 |         out.reverse()
62 |         out = torch.stack(out)
63 |         out = out.transpose(1, 0)  # batch first
64 |         return out, hidden, cell
65 |
66 |
67 | # namely the Subject Classifier (SD)
68 | class DomainClassifier(nn.Module):
69 |     def __init__(self, input_dim=64, output_dim=14):
70 |         super(DomainClassifier, self).__init__()
71 |         self.classifier = nn.Linear(input_dim, output_dim)
72 | 
73 |     def forward(self, x):
74 |         x = self.classifier(x)
75 |         return x
76 | 
77 | # The MSE loss
78 | class MSE(nn.Module):
79 |     def __init__(self):
80 |         super(MSE, self).__init__()
81 | 
82 |     def forward(self, pred, real):
83 |         diffs = torch.add(real, -pred)
84 |         n = torch.numel(diffs.data)
85 |         mse = torch.sum(diffs.pow(2)) / n
86 |         return mse
87 |
88 |
89 | def timeStepsShuffle(source_data):
90 |     source_data_1 = source_data.clone()
91 |     # retain the last time step
92 |     curTimeStep_1 = source_data_1[:, -1, :]
93 |     # get data of the other time steps
94 |     dim_size = source_data[:, :-1, :].size(1)
95 |     # generate a random sequence
96 |     idxs_1 = list(range(dim_size))
97 |     # shuffle the sequence
98 |     random.shuffle(idxs_1)
99 |     # get data corresponding to the shuffled sequence
100 |     else_1 = source_data_1[:, idxs_1, :]
101 |     # append the original last time step
102 |     result_1 = torch.cat([else_1, curTimeStep_1.unsqueeze(1)], dim=1)
103 |     return result_1
104 | 
105 | def maskTimeSteps(source_data, rate):
106 |     source_data_1 = source_data.clone()
107 |     num_zeros = int(source_data.size(1) * rate)
108 |     # mask a certain rate of time steps, ignoring the last one
109 |     zero_indices_1 = torch.randperm(source_data_1.size(1) - 1)[:num_zeros]
110 |     source_data_1[:, zero_indices_1, :] = 0
111 |     return source_data_1
112 | 
113 | def maskChannels(source_data, args, rate):
114 |     # reshape to expose the channel dimension
115 |     source_data_reshaped = source_data.reshape(args.batch_size, args.time_steps, 5, 62)
116 |     source_data_reshaped_1 = source_data_reshaped.clone()
117 |     num_zeros = int(source_data_reshaped.size(-1) * rate)
118 |     # mask a certain rate of channels
119 |     zero_indices_1 = torch.randperm(source_data_reshaped_1.size(-1))[:num_zeros]
120 |     source_data_reshaped_1[..., zero_indices_1] = 0
121 |     source_data_reshaped_1 = source_data_reshaped_1.reshape(args.batch_size, args.time_steps, 310)
122 |     return source_data_reshaped_1
123 | 
124 | def shuffleChannels(source_data, args):
125 |     # reshape to expose the channel dimension
126 |     source_data_reshaped = source_data.reshape(args.batch_size, args.time_steps, 5, 62)
127 |     source_data_reshaped_1 = source_data_reshaped.clone()
128 |     dim_size = source_data_reshaped[..., :].size(-1)
129 |     # generate a random sequence
130 |     idxs_1 = list(range(dim_size))
131 |     random.shuffle(idxs_1)
132 |     # shuffle the channels
133 |     source_data_reshaped_1 = source_data_reshaped_1[..., idxs_1]
134 |     result_1 = source_data_reshaped_1.reshape(args.batch_size, args.time_steps, 310)
135 |     return result_1
136 |
137 | # proposed DMMR model
138 | class DMMRPreTrainingModel(nn.Module):
139 |     def __init__(self, cuda, number_of_source=14, number_of_category=3, batch_size=10, time_steps=15):
140 |         super(DMMRPreTrainingModel, self).__init__()
141 |         self.batch_size = batch_size
142 |         self.time_steps = time_steps
143 |         self.number_of_source = number_of_source
144 |         self.attentionLayer = Attention(cuda, input_dim=310)
145 |         self.sharedEncoder = Encoder(input_dim=310, hid_dim=64, n_layers=1)
146 |         self.mse = MSE()
147 |         self.domainClassifier = DomainClassifier(input_dim=64, output_dim=14)
148 |         for i in range(number_of_source):
149 |             exec('self.decoder' + str(i) + '=Decoder(input_dim=310, hid_dim=64, n_layers=1, output_dim=310)')
150 |     def forward(self, x, corres, subject_id, m=0, mark=0):
151 |         # Noise injection with the proposed Time Steps Shuffling method
152 |         x = timeStepsShuffle(x)
153 |         # The ABP module
154 |         x = self.attentionLayer(x, x.shape[0], self.time_steps)
155 |         # Encode the weighted features with a one-layer LSTM
156 |         shared_last_out, shared_hn, shared_cn = self.sharedEncoder(x)
157 |         # The DG_DANN module
158 |         # The GRL layer in the first stage
159 |         reverse_feature = ReverseLayerF.apply(shared_last_out, m)
160 |         # The Subject Discriminator
161 |         subject_predict = self.domainClassifier(reverse_feature)
162 |         subject_predict = F.log_softmax(subject_predict, dim=1)
163 |         # The domain adversarial loss
164 |         sim_loss = F.nll_loss(subject_predict, subject_id)
165 | 
166 |         # Build supervision for the decoders; the targets are also weighted
167 |         corres = self.attentionLayer(corres, corres.shape[0], self.time_steps)
168 |         splitted_tensors = torch.chunk(corres, self.number_of_source, dim=0)
169 |         rec_loss = 0
170 |         mixSubjectFeature = 0
171 |         for i in range(self.number_of_source):
172 |             # Reconstruct features in the first stage
173 |             x_out, *_ = eval('self.decoder' + str(i))(shared_last_out, shared_hn, shared_cn, self.time_steps)
174 |             # The proposed mix method for data augmentation
175 |             mixSubjectFeature += x_out
176 |         shared_last_out_2, shared_hn_2, shared_cn_2 = self.sharedEncoder(mixSubjectFeature)
177 |         for i in range(self.number_of_source):
178 |             # Reconstruct features in the second stage
179 |             x_out, *_ = eval('self.decoder' + str(i))(shared_last_out_2, shared_hn_2, shared_cn_2, self.time_steps)
180 |             # Compute the reconstruction loss in the second stage only
181 |             rec_loss += self.mse(x_out, splitted_tensors[i])
182 |         return rec_loss, sim_loss
183 | class DMMRFineTuningModel(nn.Module):
184 |     def __init__(self, cuda, baseModel, number_of_source=14, number_of_category=3, batch_size=10, time_steps=15):
185 |         super(DMMRFineTuningModel, self).__init__()
186 |         self.baseModel = copy.deepcopy(baseModel)
187 |         self.batch_size = batch_size
188 |         self.time_steps = time_steps
189 |         self.number_of_source = number_of_source
190 |         # The ABP module and sharedEncoder come from the pre-trained model
191 |         self.attentionLayer = self.baseModel.attentionLayer
192 |         self.sharedEncoder = self.baseModel.sharedEncoder
193 |         # Add a new emotion classifier for emotion recognition
194 |         self.cls_fc = nn.Sequential(nn.Linear(64, 64, bias=False), nn.BatchNorm1d(64),
195 |                                     nn.ReLU(inplace=True), nn.Linear(64, number_of_category, bias=True))
196 |         self.mse = MSE()
197 |         for i in range(number_of_source):
198 |             exec('self.decoder' + str(i) + '=Decoder(input_dim=310, hid_dim=64, n_layers=1, output_dim=310)')
199 |     def forward(self, x, label_src=0):
200 |         x = self.attentionLayer(x, x.shape[0], self.time_steps)
201 |         shared_last_out, shared_hn, shared_cn = self.sharedEncoder(x)
202 |         x_logits = self.cls_fc(shared_last_out)
203 |         x_pred = F.log_softmax(x_logits, dim=1)
204 |         cls_loss = F.nll_loss(x_pred, label_src.squeeze())
205 |         return x_pred, x_logits, cls_loss
206 | 
207 | class DMMRTestModel(nn.Module):
208 |     def __init__(self, baseModel):
209 |         super(DMMRTestModel, self).__init__()
210 |         self.baseModel = copy.deepcopy(baseModel)
211 |     def forward(self, x):
212 |         x = self.baseModel.attentionLayer(x, self.baseModel.batch_size, self.baseModel.time_steps)
213 |         shared_last_out, shared_hn, shared_cn = self.baseModel.sharedEncoder(x)
214 |         x_shared_logits = self.baseModel.cls_fc(shared_last_out)
215 |         return x_shared_logits
216 |
217 | ############## other noise injection methods ##############
218 | class PreTrainingWithMaskTimeSteps(nn.Module):
219 |     def __init__(self, cuda, number_of_source=14, number_of_category=3, batch_size=10, time_steps=15):
220 |         super(PreTrainingWithMaskTimeSteps, self).__init__()
221 |         self.batch_size = batch_size
222 |         self.time_steps = time_steps
223 |         self.number_of_source = number_of_source
224 |         self.attentionLayer = Attention(cuda, input_dim=310)
225 |         self.sharedEncoder = Encoder(input_dim=310, hid_dim=64, n_layers=1)
226 |         self.mse = MSE()
227 |         self.domainClassifier = DomainClassifier(input_dim=64, output_dim=14)
228 |         for i in range(number_of_source):
229 |             exec('self.decoder' + str(i) + '=Decoder(input_dim=310, hid_dim=64, n_layers=1, output_dim=310)')
230 |     def forward(self, x, corres, subject_id, args, m=0, mark=0):
231 |         x = maskTimeSteps(x, 0.2)
232 |         x = self.attentionLayer(x, x.shape[0], self.time_steps)
233 |         shared_last_out, shared_hn, shared_cn = self.sharedEncoder(x)
234 | 
235 |         reverse_feature = ReverseLayerF.apply(shared_last_out, m)
236 |         subject_predict = self.domainClassifier(reverse_feature)
237 |         subject_predict = F.log_softmax(subject_predict, dim=1)
238 |         sim_loss = F.nll_loss(subject_predict, subject_id)
239 | 
240 |         corres = self.attentionLayer(corres, corres.shape[0], self.time_steps)
241 |         splitted_tensors = torch.chunk(corres, self.number_of_source, dim=0)
242 |         rec_loss = 0
243 |         mixSubjectFeature = 0
244 |         for i in range(self.number_of_source):
245 |             x_out, *_ = eval('self.decoder' + str(i))(shared_last_out, shared_hn, shared_cn, self.time_steps)
246 |             mixSubjectFeature += x_out
247 |         shared_last_out_2, shared_hn_2, shared_cn_2 = self.sharedEncoder(mixSubjectFeature)
248 |         for i in range(self.number_of_source):
249 |             x_out, *_ = eval('self.decoder' + str(i))(shared_last_out_2, shared_hn_2, shared_cn_2, self.time_steps)
250 |             rec_loss += self.mse(x_out, splitted_tensors[i])
251 |         return rec_loss, sim_loss
252 |
253 | class PreTrainingWithMaskChannels(nn.Module):
254 |     def __init__(self, cuda, number_of_source=14, number_of_category=3, batch_size=10, time_steps=15):
255 |         super(PreTrainingWithMaskChannels, self).__init__()
256 |         self.batch_size = batch_size
257 |         self.time_steps = time_steps
258 |         self.number_of_source = number_of_source
259 |         self.attentionLayer = Attention(cuda, input_dim=310)
260 |         self.sharedEncoder = Encoder(input_dim=310, hid_dim=64, n_layers=1)
261 |         self.mse = MSE()
262 |         self.domainClassifier = DomainClassifier(input_dim=64, output_dim=14)
263 |         for i in range(number_of_source):
264 |             exec('self.decoder' + str(i) + '=Decoder(input_dim=310, hid_dim=64, n_layers=1, output_dim=310)')
265 |     def forward(self, x, corres, subject_id, args, m=0, mark=0):
266 |         x = maskChannels(x, args, 0.2)
267 |         x = self.attentionLayer(x, x.shape[0], self.time_steps)
268 |         shared_last_out, shared_hn, shared_cn = self.sharedEncoder(x)
269 | 
270 |         reverse_feature = ReverseLayerF.apply(shared_last_out, m)
271 |         subject_predict = self.domainClassifier(reverse_feature)
272 |         subject_predict = F.log_softmax(subject_predict, dim=1)
273 |         sim_loss = F.nll_loss(subject_predict, subject_id)
274 | 
275 |         corres = self.attentionLayer(corres, corres.shape[0], self.time_steps)
276 |         splitted_tensors = torch.chunk(corres, self.number_of_source, dim=0)
277 |         rec_loss = 0
278 |         mixSubjectFeature = 0
279 |         for i in range(self.number_of_source):
280 |             x_out, *_ = eval('self.decoder' + str(i))(shared_last_out, shared_hn, shared_cn, self.time_steps)
281 |             mixSubjectFeature += x_out
282 |         shared_last_out_2, shared_hn_2, shared_cn_2 = self.sharedEncoder(mixSubjectFeature)
283 |         for i in range(self.number_of_source):
284 |             x_out, *_ = eval('self.decoder' + str(i))(shared_last_out_2, shared_hn_2, shared_cn_2, self.time_steps)
285 |             rec_loss += self.mse(x_out, splitted_tensors[i])
286 |         return rec_loss, sim_loss
287 |
288 | class PreTrainingWithChannelsShuffling(nn.Module):
289 | def __init__(self, cuda, number_of_source=14, number_of_category=3, batch_size=10, time_steps=15):
290 | super(PreTrainingWithChannelsShuffling, self).__init__()
291 | self.batch_size = batch_size
292 | self.time_steps = time_steps
293 | self.number_of_source = number_of_source
294 | self.attentionLayer = Attention(cuda, input_dim=310)
295 | self.sharedEncoder = Encoder(input_dim=310, hid_dim=64, n_layers=1)
296 | self.mse = MSE()
297 | self.domainClassifier = DomainClassifier(input_dim=64, output_dim=14)
298 | for i in range(number_of_source):
299 | setattr(self, 'decoder' + str(i), Decoder(input_dim=310, hid_dim=64, n_layers=1, output_dim=310))
300 | def forward(self, x, corres, subject_id, args, m=0, mark=0):
301 | x = shuffleChannels(x, args)
302 | x = self.attentionLayer(x, x.shape[0], self.time_steps)
303 | shared_last_out, shared_hn, shared_cn = self.sharedEncoder(x)
304 |
305 | reverse_feature = ReverseLayerF.apply(shared_last_out, m)
306 | subject_predict = self.domainClassifier(reverse_feature)
307 | subject_predict = F.log_softmax(subject_predict,dim=1)
308 | sim_loss = F.nll_loss(subject_predict, subject_id)
309 |
310 | corres = self.attentionLayer(corres, corres.shape[0], self.time_steps)
311 | splitted_tensors = torch.chunk(corres, self.number_of_source, dim=0)
312 | rec_loss = 0
313 | mixSubjectFeature = 0
314 | for i in range(self.number_of_source):
315 | x_out, *_ = getattr(self, 'decoder' + str(i))(shared_last_out, shared_hn, shared_cn, self.time_steps)
316 | mixSubjectFeature += x_out
317 | shared_last_out_2, shared_hn_2, shared_cn_2 = self.sharedEncoder(mixSubjectFeature)
318 | for i in range(self.number_of_source):
319 | x_out, *_ = getattr(self, 'decoder' + str(i))(shared_last_out_2, shared_hn_2, shared_cn_2, self.time_steps)
320 | rec_loss += self.mse(x_out, splitted_tensors[i])
321 | return rec_loss, sim_loss
322 |
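# (editor's note) shuffleChannels permutes EEG channels as the noise source
# for the class above. A minimal sketch, assuming the 310-dim feature is laid
# out as 62 channels x 5 frequency bands (as in preprocess.window_slice) and
# that one permutation is shared across the batch (illustrative only):
#
#   def shuffleChannels(x, args):
#       # x: (batch, time_steps, 310) -> view as (batch, time_steps, 62, 5)
#       b, t, _ = x.shape
#       x = x.reshape(b, t, 62, 5)
#       perm = torch.randperm(62, device=x.device)
#       return x[:, :, perm, :].reshape(b, t, 310)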
323 | class PreTrainingWithDropout(nn.Module):
324 | def __init__(self, cuda, number_of_source=14, number_of_category=3, batch_size=10, time_steps=15, dropout_rate=0.2):
325 | super(PreTrainingWithDropout, self).__init__()
326 | self.batch_size = batch_size
327 | self.time_steps = time_steps
328 | self.number_of_source = number_of_source
329 | self.attentionLayer = Attention(cuda, input_dim=310)
330 | self.sharedEncoder = Encoder(input_dim=310, hid_dim=64, n_layers=1)
331 | self.mse = MSE()
332 | self.domainClassifier = DomainClassifier(input_dim=64, output_dim=14)
333 | self.dropout = nn.Dropout(dropout_rate) # noise
334 | for i in range(number_of_source):
335 | setattr(self, 'decoder' + str(i), Decoder(input_dim=310, hid_dim=64, n_layers=1, output_dim=310))
336 | def forward(self, x, corres, subject_id, args, m=0, mark=0):
337 | x = self.dropout(x)
338 | x = self.attentionLayer(x, x.shape[0], self.time_steps)
339 | shared_last_out, shared_hn, shared_cn = self.sharedEncoder(x)
340 |
341 | reverse_feature = ReverseLayerF.apply(shared_last_out, m)
342 | subject_predict = self.domainClassifier(reverse_feature)
343 | subject_predict = F.log_softmax(subject_predict,dim=1)
344 | sim_loss = F.nll_loss(subject_predict, subject_id)
345 |
346 | corres = self.attentionLayer(corres, corres.shape[0], self.time_steps)
347 | splitted_tensors = torch.chunk(corres, self.number_of_source, dim=0)
348 | rec_loss = 0
349 | mixSubjectFeature = 0
350 | for i in range(self.number_of_source):
351 | x_out, *_ = getattr(self, 'decoder' + str(i))(shared_last_out, shared_hn, shared_cn, self.time_steps)
352 | mixSubjectFeature += x_out
353 | shared_last_out_2, shared_hn_2, shared_cn_2 = self.sharedEncoder(mixSubjectFeature)
354 | for i in range(self.number_of_source):
355 | x_out, *_ = getattr(self, 'decoder' + str(i))(shared_last_out_2, shared_hn_2, shared_cn_2, self.time_steps)
356 | rec_loss += self.mse(x_out, splitted_tensors[i])
357 | return rec_loss, sim_loss
358 |
359 |
360 | ############## Ablation studies ##############
361 | #w/o mix
362 | class PreTrainingWithoutMix(nn.Module):
363 | def __init__(self, cuda, number_of_source=14, number_of_category=3, batch_size=10, time_steps=15):
364 | super(PreTrainingWithoutMix, self).__init__()
365 | self.batch_size = batch_size
366 | self.time_steps = time_steps
367 | self.number_of_source = number_of_source
368 | self.attentionLayer = Attention(cuda, input_dim=310)
369 | self.sharedEncoder = Encoder(input_dim=310, hid_dim=64, n_layers=1)
370 | self.mse = MSE()
371 | self.domainClassifier = DomainClassifier(input_dim=64, output_dim=14)
372 | for i in range(number_of_source):
373 | setattr(self, 'decoder' + str(i), Decoder(input_dim=310, hid_dim=64, n_layers=1, output_dim=310))
374 | def forward(self, x, corres, subject_id, m=0, mark=0):
375 | x = timeStepsShuffle(x)
376 | x = self.attentionLayer(x, x.shape[0], self.time_steps)
377 | shared_last_out, shared_hn, shared_cn = self.sharedEncoder(x)
378 |
379 | reverse_feature = ReverseLayerF.apply(shared_last_out, m)
380 | subject_predict = self.domainClassifier(reverse_feature)
381 | subject_predict = F.log_softmax(subject_predict,dim=1)
382 | sim_loss = F.nll_loss(subject_predict, subject_id)
383 |
384 | corres = self.attentionLayer(corres, corres.shape[0], self.time_steps)
385 | splitted_tensors = torch.chunk(corres, self.number_of_source, dim=0)
386 | rec_loss = 0
387 | for i in range(self.number_of_source):
388 | x_out, *_ = getattr(self, 'decoder' + str(i))(shared_last_out, shared_hn, shared_cn, self.time_steps)
389 | rec_loss += self.mse(x_out, splitted_tensors[i])
390 | return rec_loss, sim_loss
391 | #w/o noise
392 | class PreTrainingWithoutNoise(nn.Module):
393 | def __init__(self, cuda, number_of_source=14, number_of_category=3, batch_size=10, time_steps=15):
394 | super(PreTrainingWithoutNoise, self).__init__()
395 | self.batch_size = batch_size
396 | self.time_steps = time_steps
397 | self.number_of_source = number_of_source
398 | self.attentionLayer = Attention(cuda, input_dim=310)
399 | self.sharedEncoder = Encoder(input_dim=310, hid_dim=64, n_layers=1)
400 | self.mse = MSE()
401 | self.domainClassifier = DomainClassifier(input_dim=64, output_dim=14)
402 | for i in range(number_of_source):
403 | setattr(self, 'decoder' + str(i), Decoder(input_dim=310, hid_dim=64, n_layers=1, output_dim=310))
404 | def forward(self, x, corres, subject_id, m=0, mark=0):
405 | x = self.attentionLayer(x, x.shape[0], self.time_steps)
406 | shared_last_out, shared_hn, shared_cn = self.sharedEncoder(x)
407 |
408 | reverse_feature = ReverseLayerF.apply(shared_last_out, m)
409 | subject_predict = self.domainClassifier(reverse_feature)
410 | subject_predict = F.log_softmax(subject_predict,dim=1)
411 | sim_loss = F.nll_loss(subject_predict, subject_id)
412 |
413 | corres = self.attentionLayer(corres, corres.shape[0], self.time_steps)
414 | splitted_tensors = torch.chunk(corres, self.number_of_source, dim=0)
415 | rec_loss = 0
416 | mixSubjectFeature = 0
417 | for i in range(self.number_of_source):
418 | x_out, *_ = getattr(self, 'decoder' + str(i))(shared_last_out, shared_hn, shared_cn, self.time_steps)
419 | mixSubjectFeature += x_out
420 | shared_last_out_2, shared_hn_2, shared_cn_2 = self.sharedEncoder(mixSubjectFeature)
421 | for i in range(self.number_of_source):
422 | x_out, *_ = getattr(self, 'decoder' + str(i))(shared_last_out_2, shared_hn_2, shared_cn_2, self.time_steps)
423 | rec_loss += self.mse(x_out, splitted_tensors[i])
424 | return rec_loss, sim_loss
425 | #w/o both
426 | class PreTrainingWithoutBothMixAndNoise(nn.Module):
427 | def __init__(self, cuda, number_of_source=14, number_of_category=3, batch_size=10, time_steps=15):
428 | super(PreTrainingWithoutBothMixAndNoise, self).__init__()
429 | self.batch_size = batch_size
430 | self.time_steps = time_steps
431 | self.number_of_source = number_of_source
432 | self.attentionLayer = Attention(cuda, input_dim=310)
433 | self.sharedEncoder = Encoder(input_dim=310, hid_dim=64, n_layers=1)
434 | self.mse = MSE()
435 | self.domainClassifier = DomainClassifier(input_dim=64, output_dim=14)
436 | for i in range(number_of_source):
437 | setattr(self, 'decoder' + str(i), Decoder(input_dim=310, hid_dim=64, n_layers=1, output_dim=310))
438 | def forward(self, x, corres, subject_id, m=0, mark=0):
439 | x = self.attentionLayer(x, x.shape[0], self.time_steps)
440 | shared_last_out, shared_hn, shared_cn = self.sharedEncoder(x)
441 |
442 | reverse_feature = ReverseLayerF.apply(shared_last_out, m)
443 | subject_predict = self.domainClassifier(reverse_feature)
444 | subject_predict = F.log_softmax(subject_predict,dim=1)
445 | sim_loss = F.nll_loss(subject_predict, subject_id)
446 |
447 | corres = self.attentionLayer(corres, corres.shape[0], self.time_steps)
448 | splitted_tensors = torch.chunk(corres, self.number_of_source, dim=0)
449 | rec_loss = 0
450 | for i in range(self.number_of_source):
451 | x_out, *_ = getattr(self, 'decoder' + str(i))(shared_last_out, shared_hn, shared_cn, self.time_steps)
452 | rec_loss += self.mse(x_out, splitted_tensors[i])
453 | return rec_loss, sim_loss
454 |
455 | # Return the attention output and the shared encoder feature for T-SNE plots
456 | class ModelReturnFeatures(nn.Module):
457 | def __init__(self, baseModel, time_steps=15):
458 | super(ModelReturnFeatures, self).__init__()
459 | self.baseModel = baseModel
460 | self.time_steps = time_steps
461 | def forward(self, x):
462 | x = self.baseModel.attentionLayer(x, x.shape[0], self.time_steps)
463 | shared_last_out, shared_hn, shared_cn = self.baseModel.sharedEncoder(x)
464 | return x, shared_last_out
465 |
466 |
467 |
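# (editor's note) Usage sketch for the T-SNE feature extraction, assuming a
# trained pre-training or fine-tuning model `base` and a batch `x` of shape
# (batch, time_steps, 310):
#
#   featureModel = ModelReturnFeatures(base, time_steps=15)
#   attn_out, shared_feat = featureModel(x)   # shared_feat: (batch, 64)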
--------------------------------------------------------------------------------
/noiseInjectionMethods/Dropout.py:
--------------------------------------------------------------------------------
1 | from preprocess import getDataLoaders
2 | import math
3 | from torch.utils.tensorboard import SummaryWriter
4 | import argparse
5 | from train import *
6 | import random
7 | import os
8 | import numpy as np
9 |
10 | os.environ["CUDA_VISIBLE_DEVICES"] = '3'
11 |
12 | def set_seed(seed=3):
13 | random.seed(seed)
14 | np.random.seed(seed)
15 | os.environ['PYTHONHASHSEED'] = str(seed)
16 | torch.manual_seed(seed)
17 | torch.cuda.manual_seed(seed)
18 | torch.cuda.manual_seed_all(seed)
19 | torch.backends.cudnn.deterministic = True
20 |
21 | def main(data_loader_dict, args, optim_config, cuda, writer, one_subject, seed=3):
22 | set_seed(seed)
23 | if args.dataset_name == 'seed3':
24 | iteration = 7
25 | elif args.dataset_name == 'seed4':
26 | iteration = 3
27 | acc = trainDMMR_Noise_Dropout(data_loader_dict, optim_config, cuda, args, iteration, writer, one_subject)
28 | return acc
29 |
30 | if __name__ == '__main__':
31 | cuda = torch.cuda.is_available()
32 | parser = argparse.ArgumentParser(description='DMMR')
33 |
34 | #config of experiment
35 | parser.add_argument("--way", type=str, default='DMMR/seed3', help="name of current way")
36 | parser.add_argument("--index", type=str, default='Dropout', help="tensorboard index")
37 |
38 | #config of dataset
39 | parser.add_argument("--dataset_name", type=str, nargs='?', default='seed3', help="the dataset name, supporting seed3 and seed4")
40 | parser.add_argument("--session", type=str, nargs='?', default='1', help="selected session")
41 | parser.add_argument("--subjects", type=int, choices=[15], default=15, help="the total number of subjects")
42 | parser.add_argument("--dim", type=int, default=310, help="dim of input")
43 |
44 | #config of DMMR
45 | parser.add_argument("--input_dim", type=int, default=310, help="input dim is the same with sample's last dim")
46 | parser.add_argument("--hid_dim", type=int, default=64, help="hid dim is for hidden layer of lstm")
47 | parser.add_argument("--n_layers", type=int, default=1, help="num of layers of lstm")
48 | parser.add_argument("--epoch_fineTuning", type=int, default=500, help="epoch of the fine-tuning phase")
49 | parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
50 | parser.add_argument("--weight_decay", type=float, default=0.0005, help="weight decay")
51 | parser.add_argument("--beta", type=float, default=0.05, help="balancing hyperparameter in the loss of pretraining phase")
52 |
53 |
54 | args = parser.parse_args()
55 | args.source_subjects = args.subjects-1
56 | args.seed3_path = "../eeg_data/ExtractedFeatures/"
57 | args.seed4_path = "../eeg_data/eeg_feature_smooth/"
58 | if cuda:
59 | args.num_workers_train = 4
60 | args.num_workers_test = 2
61 | else:
62 | args.num_workers_train = 0
63 | args.num_workers_test = 0
64 | if args.dataset_name == "seed3":
65 | args.path = args.seed3_path
66 | args.cls_classes = 3
67 | args.time_steps = 30
68 | args.batch_size = 512 #batch_size
69 | args.epoch_preTraining = 300 #epoch of the pre-training phase
70 | elif args.dataset_name == "seed4":
71 | args.path = args.seed4_path
72 | args.cls_classes = 4
73 | args.time_steps = 10
74 | args.batch_size = 256 #batch_size
75 | args.epoch_preTraining = 400 #epoch of the pre-training phase
76 | else:
77 | print("need to define the input dataset")
78 | optim_config = {"lr": args.lr, "weight_decay": args.weight_decay}
79 | # leave-one-subject-out cross-validation
80 | acc_list=[]
81 | writer = SummaryWriter("data/session"+args.session+"/"+args.way+"/" + args.index)
82 | for one_subject in range(0, args.subjects):
83 | # 1.data preparation
84 | source_loaders, test_loader = getDataLoaders(one_subject, args)
85 | data_loader_dict = {"source_loader": source_loaders, "test_loader":test_loader}
86 | # 2. main
87 | acc = main(data_loader_dict, args, optim_config, cuda, writer, one_subject)
88 | writer.add_scalars('single experiment acc: ',
89 | {'test acc': acc}, one_subject + 1)
90 | writer.flush()
91 | acc_list.append(acc)
92 | writer.add_text('final acc avg', str(np.average(acc_list)))
93 | writer.add_text('final acc std', str(np.std(acc_list)))
94 | acc_list_str = [str(x) for x in acc_list]
95 | writer.add_text('final each acc', ",".join(acc_list_str))
96 | writer.add_scalars('final experiment acc scalar: /avg',
97 | {'test acc': np.average(acc_list)})
98 | writer.add_scalars('final experiment acc scalar: /std',
99 | {'test acc': np.std(acc_list)})
100 | writer.close()
101 |
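# (editor's note) Usage sketch: the script imports `preprocess` and `train`
# from the repository root, so run it with the root on PYTHONPATH, e.g.
#   PYTHONPATH=. python noiseInjectionMethods/Dropout.py --dataset_name seed3 --session 1
# The sibling scripts (channelsShuffling.py, maskChannels.py,
# maskTimeSteps.py) are run the same way.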
--------------------------------------------------------------------------------
/noiseInjectionMethods/channelsShuffling.py:
--------------------------------------------------------------------------------
1 | from preprocess import getDataLoaders
2 | import math
3 | from torch.utils.tensorboard import SummaryWriter
4 | import argparse
5 | from train import *
6 | import random
7 | import os
8 | import numpy as np
9 |
10 | os.environ["CUDA_VISIBLE_DEVICES"] = '3'
11 |
12 | def set_seed(seed=3):
13 | random.seed(seed)
14 | np.random.seed(seed)
15 | os.environ['PYTHONHASHSEED'] = str(seed)
16 | torch.manual_seed(seed)
17 | torch.cuda.manual_seed(seed)
18 | torch.cuda.manual_seed_all(seed)
19 | torch.backends.cudnn.deterministic = True
20 |
21 | def main(data_loader_dict, args, optim_config, cuda, writer, one_subject, seed=3):
22 | set_seed(seed)
23 | if args.dataset_name == 'seed3':
24 | iteration = 7
25 | elif args.dataset_name == 'seed4':
26 | iteration = 3
27 | acc = trainDMMR_Noise_ChannelsShuffling(data_loader_dict, optim_config, cuda, args, iteration, writer, one_subject)
28 | return acc
29 |
30 | if __name__ == '__main__':
31 | cuda = torch.cuda.is_available()
32 | parser = argparse.ArgumentParser(description='DMMR')
33 |
34 | #config of experiment
35 | parser.add_argument("--way", type=str, default='DMMR/seed3', help="name of current way")
36 | parser.add_argument("--index", type=str, default='channelsShuffling', help="tensorboard index")
37 |
38 | #config of dataset
39 | parser.add_argument("--dataset_name", type=str, nargs='?', default='seed3', help="the dataset name, supporting seed3 and seed4")
40 | parser.add_argument("--session", type=str, nargs='?', default='1', help="selected session")
41 | parser.add_argument("--subjects", type=int, choices=[15], default=15, help="the total number of subjects")
42 | parser.add_argument("--dim", type=int, default=310, help="dim of input")
43 |
44 | #config of DMMR
45 | parser.add_argument("--input_dim", type=int, default=310, help="input dim is the same with sample's last dim")
46 | parser.add_argument("--hid_dim", type=int, default=64, help="hid dim is for hidden layer of lstm")
47 | parser.add_argument("--n_layers", type=int, default=1, help="num of layers of lstm")
48 | parser.add_argument("--epoch_fineTuning", type=int, default=500, help="epoch of the fine-tuning phase")
49 | parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
50 | parser.add_argument("--weight_decay", type=float, default=0.0005, help="weight decay")
51 | parser.add_argument("--beta", type=float, default=0.05, help="balancing hyperparameter in the loss of pretraining phase")
52 |
53 |
54 | args = parser.parse_args()
55 | args.source_subjects = args.subjects-1
56 | args.seed3_path = "../eeg_data/ExtractedFeatures/"
57 | args.seed4_path = "../eeg_data/eeg_feature_smooth/"
58 | if cuda:
59 | args.num_workers_train = 4
60 | args.num_workers_test = 2
61 | else:
62 | args.num_workers_train = 0
63 | args.num_workers_test = 0
64 | if args.dataset_name == "seed3":
65 | args.path = args.seed3_path
66 | args.cls_classes = 3
67 | args.time_steps = 30
68 | args.batch_size = 512 #batch_size
69 | args.epoch_preTraining = 300 #epoch of the pre-training phase
70 | elif args.dataset_name == "seed4":
71 | args.path = args.seed4_path
72 | args.cls_classes = 4
73 | args.time_steps = 10
74 | args.batch_size = 256 #batch_size
75 | args.epoch_preTraining = 400 #epoch of the pre-training phase
76 | else:
77 | print("need to define the input dataset")
78 | optim_config = {"lr": args.lr, "weight_decay": args.weight_decay}
79 | # leave-one-subject-out cross-validation
80 | acc_list=[]
81 | writer = SummaryWriter("data/session"+args.session+"/"+args.way+"/" + args.index)
82 | for one_subject in range(0, args.subjects):
83 | # 1.data preparation
84 | source_loaders, test_loader = getDataLoaders(one_subject, args)
85 | data_loader_dict = {"source_loader": source_loaders, "test_loader":test_loader}
86 | # 2. main
87 | acc = main(data_loader_dict, args, optim_config, cuda, writer, one_subject)
88 | writer.add_scalars('single experiment acc: ',
89 | {'test acc': acc}, one_subject + 1)
90 | writer.flush()
91 | acc_list.append(acc)
92 | writer.add_text('final acc avg', str(np.average(acc_list)))
93 | writer.add_text('final acc std', str(np.std(acc_list)))
94 | acc_list_str = [str(x) for x in acc_list]
95 | writer.add_text('final each acc', ",".join(acc_list_str))
96 | writer.add_scalars('final experiment acc scalar: /avg',
97 | {'test acc': np.average(acc_list)})
98 | writer.add_scalars('final experiment acc scalar: /std',
99 | {'test acc': np.std(acc_list)})
100 | writer.close()
101 |
--------------------------------------------------------------------------------
/noiseInjectionMethods/maskChannels.py:
--------------------------------------------------------------------------------
1 | from preprocess import getDataLoaders
2 | import math
3 | from torch.utils.tensorboard import SummaryWriter
4 | import argparse
5 | from train import *
6 | import random
7 | import os
8 | import numpy as np
9 |
10 | os.environ["CUDA_VISIBLE_DEVICES"] = '3'
11 |
12 | def set_seed(seed=3):
13 | random.seed(seed)
14 | np.random.seed(seed)
15 | os.environ['PYTHONHASHSEED'] = str(seed)
16 | torch.manual_seed(seed)
17 | torch.cuda.manual_seed(seed)
18 | torch.cuda.manual_seed_all(seed)
19 | torch.backends.cudnn.deterministic = True
20 |
21 | def main(data_loader_dict, args, optim_config, cuda, writer, one_subject, seed=3):
22 | set_seed(seed)
23 | if args.dataset_name == 'seed3':
24 | iteration = 7
25 | elif args.dataset_name == 'seed4':
26 | iteration = 3
27 | acc = trainDMMR_Noise_MaskChannels(data_loader_dict, optim_config, cuda, args, iteration, writer, one_subject)
28 | return acc
29 |
30 | if __name__ == '__main__':
31 | cuda = torch.cuda.is_available()
32 | parser = argparse.ArgumentParser(description='DMMR')
33 |
34 | #config of experiment
35 | parser.add_argument("--way", type=str, default='DMMR/seed3', help="name of current way")
36 | parser.add_argument("--index", type=str, default='maskChannels', help="tensorboard index")
37 |
38 | #config of dataset
39 | parser.add_argument("--dataset_name", type=str, nargs='?', default='seed3', help="the dataset name, supporting seed3 and seed4")
40 | parser.add_argument("--session", type=str, nargs='?', default='1', help="selected session")
41 | parser.add_argument("--subjects", type=int, choices=[15], default=15, help="the total number of subjects")
42 | parser.add_argument("--dim", type=int, default=310, help="dim of input")
43 |
44 | #config of DMMR
45 | parser.add_argument("--input_dim", type=int, default=310, help="input dim is the same with sample's last dim")
46 | parser.add_argument("--hid_dim", type=int, default=64, help="hid dim is for hidden layer of lstm")
47 | parser.add_argument("--n_layers", type=int, default=1, help="num of layers of lstm")
48 | parser.add_argument("--epoch_fineTuning", type=int, default=500, help="epoch of the fine-tuning phase")
49 | parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
50 | parser.add_argument("--weight_decay", type=float, default=0.0005, help="weight decay")
51 | parser.add_argument("--beta", type=float, default=0.05, help="balancing hyperparameter in the loss of pretraining phase")
52 |
53 |
54 | args = parser.parse_args()
55 | args.source_subjects = args.subjects-1
56 | args.seed3_path = "../eeg_data/ExtractedFeatures/"
57 | args.seed4_path = "../eeg_data/eeg_feature_smooth/"
58 | if cuda:
59 | args.num_workers_train = 4
60 | args.num_workers_test = 2
61 | else:
62 | args.num_workers_train = 0
63 | args.num_workers_test = 0
64 | if args.dataset_name == "seed3":
65 | args.path = args.seed3_path
66 | args.cls_classes = 3
67 | args.time_steps = 30
68 | args.batch_size = 512 #batch_size
69 | args.epoch_preTraining = 300 #epoch of the pre-training phase
70 | elif args.dataset_name == "seed4":
71 | args.path = args.seed4_path
72 | args.cls_classes = 4
73 | args.time_steps = 10
74 | args.batch_size = 256 #batch_size
75 | args.epoch_preTraining = 400 #epoch of the pre-training phase
76 | else:
77 | print("need to define the input dataset")
78 | optim_config = {"lr": args.lr, "weight_decay": args.weight_decay}
79 | # leave-one-subject-out cross-validation
80 | acc_list=[]
81 | writer = SummaryWriter("data/session"+args.session+"/"+args.way+"/" + args.index)
82 | for one_subject in range(0, args.subjects):
83 | # 1.data preparation
84 | source_loaders, test_loader = getDataLoaders(one_subject, args)
85 | data_loader_dict = {"source_loader": source_loaders, "test_loader":test_loader}
86 | # 2. main
87 | acc = main(data_loader_dict, args, optim_config, cuda, writer, one_subject)
88 | writer.add_scalars('single experiment acc: ',
89 | {'test acc': acc}, one_subject + 1)
90 | writer.flush()
91 | acc_list.append(acc)
92 | writer.add_text('final acc avg', str(np.average(acc_list)))
93 | writer.add_text('final acc std', str(np.std(acc_list)))
94 | acc_list_str = [str(x) for x in acc_list]
95 | writer.add_text('final each acc', ",".join(acc_list_str))
96 | writer.add_scalars('final experiment acc scalar: /avg',
97 | {'test acc': np.average(acc_list)})
98 | writer.add_scalars('final experiment acc scalar: /std',
99 | {'test acc': np.std(acc_list)})
100 | writer.close()
101 |
--------------------------------------------------------------------------------
/noiseInjectionMethods/maskTimeSteps.py:
--------------------------------------------------------------------------------
1 | from preprocess import getDataLoaders
2 | import math
3 | from torch.utils.tensorboard import SummaryWriter
4 | import argparse
5 | from train import *
6 | import random
7 | import os
8 | import numpy as np
9 |
10 | os.environ["CUDA_VISIBLE_DEVICES"] = '3'
11 |
12 | def set_seed(seed=3):
13 | random.seed(seed)
14 | np.random.seed(seed)
15 | os.environ['PYTHONHASHSEED'] = str(seed)
16 | torch.manual_seed(seed)
17 | torch.cuda.manual_seed(seed)
18 | torch.cuda.manual_seed_all(seed)
19 | torch.backends.cudnn.deterministic = True
20 |
21 | def main(data_loader_dict, args, optim_config, cuda, writer, one_subject, seed=3):
22 | set_seed(seed)
23 | if args.dataset_name == 'seed3':
24 | iteration = 7
25 | elif args.dataset_name == 'seed4':
26 | iteration = 3
27 | acc = trainDMMR_Noise_MaskTimeSteps(data_loader_dict, optim_config, cuda, args, iteration, writer, one_subject)
28 | return acc
29 |
30 | if __name__ == '__main__':
31 | cuda = torch.cuda.is_available()
32 | parser = argparse.ArgumentParser(description='DMMR')
33 |
34 | #config of experiment
35 | parser.add_argument("--way", type=str, default='DMMR/seed3', help="name of current way")
36 | parser.add_argument("--index", type=str, default='maskTimeSteps', help="tensorboard index")
37 |
38 | #config of dataset
39 | parser.add_argument("--dataset_name", type=str, nargs='?', default='seed3', help="the dataset name, supporting seed3 and seed4")
40 | parser.add_argument("--session", type=str, nargs='?', default='1', help="selected session")
41 | parser.add_argument("--subjects", type=int, choices=[15], default=15, help="the total number of subjects")
42 | parser.add_argument("--dim", type=int, default=310, help="dim of input")
43 |
44 | #config of DMMR
45 | parser.add_argument("--input_dim", type=int, default=310, help="input dim is the same with sample's last dim")
46 | parser.add_argument("--hid_dim", type=int, default=64, help="hid dim is for hidden layer of lstm")
47 | parser.add_argument("--n_layers", type=int, default=1, help="num of layers of lstm")
48 | parser.add_argument("--epoch_fineTuning", type=int, default=500, help="epoch of the fine-tuning phase")
49 | parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
50 | parser.add_argument("--weight_decay", type=float, default=0.0005, help="weight decay")
51 | parser.add_argument("--beta", type=float, default=0.05, help="balancing hyperparameter in the loss of pretraining phase")
52 |
53 |
54 | args = parser.parse_args()
55 | args.source_subjects = args.subjects-1
56 | args.seed3_path = "../eeg_data/ExtractedFeatures/"
57 | args.seed4_path = "../eeg_data/eeg_feature_smooth/"
58 | if cuda:
59 | args.num_workers_train = 4
60 | args.num_workers_test = 2
61 | else:
62 | args.num_workers_train = 0
63 | args.num_workers_test = 0
64 | if args.dataset_name == "seed3":
65 | args.path = args.seed3_path
66 | args.cls_classes = 3
67 | args.time_steps = 30
68 | args.batch_size = 512 #batch_size
69 | args.epoch_preTraining = 300 #epoch of the pre-training phase
70 | elif args.dataset_name == "seed4":
71 | args.path = args.seed4_path
72 | args.cls_classes = 4
73 | args.time_steps = 10
74 | args.batch_size = 256 #batch_size
75 | args.epoch_preTraining = 400 #epoch of the pre-training phase
76 | else:
77 | print("need to define the input dataset")
78 | optim_config = {"lr": args.lr, "weight_decay": args.weight_decay}
79 | # leave-one-subject-out cross-validation
80 | acc_list=[]
81 | writer = SummaryWriter("data/session"+args.session+"/"+args.way+"/" + args.index)
82 | for one_subject in range(0, args.subjects):
83 | # 1.data preparation
84 | source_loaders, test_loader = getDataLoaders(one_subject, args)
85 | data_loader_dict = {"source_loader": source_loaders, "test_loader":test_loader}
86 | # 2. main
87 | acc = main(data_loader_dict, args, optim_config, cuda, writer, one_subject)
88 | writer.add_scalars('single experiment acc: ',
89 | {'test acc': acc}, one_subject + 1)
90 | writer.flush()
91 | acc_list.append(acc)
92 | writer.add_text('final acc avg', str(np.average(acc_list)))
93 | writer.add_text('final acc std', str(np.std(acc_list)))
94 | acc_list_str = [str(x) for x in acc_list]
95 | writer.add_text('final each acc', ",".join(acc_list_str))
96 | writer.add_scalars('final experiment acc scalar: /avg',
97 | {'test acc': np.average(acc_list)})
98 | writer.add_scalars('final experiment acc scalar: /std',
99 | {'test acc': np.std(acc_list)})
100 | writer.close()
101 |
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import scipy.io as scio
4 | import os
5 |
6 | def get_data_path(file_path):
7 | data_path = []
8 | for f in os.listdir(file_path):
9 | if f.startswith("."):
10 | continue
11 | else:
12 | data_path.append(os.path.join(file_path, f))
13 | return data_path
14 |
15 | # Time-domain sliding window of the DE feature in an experiment
16 | def window_slice(data, time_steps):
17 | data = np.transpose(data, (1, 0, 2)).reshape(-1, 310)
18 | xs = []
19 | for i in range(data.shape[0] - time_steps + 1):
20 | xs.append(data[i: i + time_steps])
21 | xs = np.concatenate(xs).reshape((len(xs), -1, 310))
22 | return xs
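# (editor's note) Shape walk-through, assuming SEED DE features of shape
# (channels=62, T, bands=5): the transpose/reshape above yields (T, 310), and
# the sliding window produces (T - time_steps + 1, time_steps, 310)
# overlapping windows.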
23 |
24 |
25 | def get_number_of_label_n_trial(dataset_name):
26 | '''
27 | description: get the number of categories, trial number and the corresponding labels
28 | param {type}
29 | return {type}:
30 | trial: int
31 | label: int
32 | label_xxx: list 3*15
33 | '''
34 | # global variables
35 | label_seed4 = [[1,2,3,0,2,0,0,1,0,1,2,1,1,1,2,3,2,2,3,3,0,3,0,3],
36 | [2,1,3,0,0,2,0,2,3,3,2,3,2,0,1,1,2,1,0,3,0,1,3,1],
37 | [1,2,2,1,3,3,3,1,1,2,1,0,2,3,3,0,2,3,0,0,2,0,1,0]]
38 | label_seed3 = [[2,1,0,0,1,2,0,1,2,2,1,0,1,2,0],
39 | [2,1,0,0,1,2,0,1,2,2,1,0,1,2,0],
40 | [2,1,0,0,1,2,0,1,2,2,1,0,1,2,0]]
41 | if dataset_name == 'seed3':
42 | label = 3
43 | trial = 15
44 | return trial, label, label_seed3
45 | elif dataset_name == 'seed4':
46 | label = 4
47 | trial = 24
48 | return trial, label, label_seed4
49 | else:
50 | print('Unexpected dataset name')
51 |
52 | #The return is a list of x and y, where the list contains data for each subject
53 | def load_trained_data(samples_path_list, args):
54 | # load the label data
55 | _, _, labels = get_number_of_label_n_trial(args.dataset_name)
56 | label = labels[int(args.session)-1]
57 | if args.dataset_name=="seed3":
58 | label = np.resize(label, (15,))
59 | label = np.reshape(label, (1, 15))
60 | elif args.dataset_name=="seed4":
61 | label = np.resize(label, (24,))
62 | label = np.reshape(label, (1, 24))
63 | X_train_all = []
64 | Y_train_all = []
65 | #Iterate through each subject (there are 14 source subjects in both datasets)
66 | for path in samples_path_list:
67 | # load the sample data
68 | sample = scio.loadmat(path, verify_compressed_data_integrity=False)
69 | flag = 0
70 | X_train = []
71 | y_train = []
72 | for key, val in sample.items():
73 | if key.startswith("de_LDS"):
74 | X_train.append(window_slice(val, args.time_steps))
75 | train_label = np.full((X_train[-1].shape[0], 1), label[0, flag])
76 | y_train.append(train_label)
77 | flag += 1
78 | X_train_one_subject=np.concatenate(X_train)
79 | y_train_one_subject=np.concatenate(y_train)
80 | X_train_all.append(X_train_one_subject)
81 | Y_train_all.append(y_train_one_subject)
82 | return X_train_all, Y_train_all
83 |
84 | def normalize(features, select_dim=0):
85 | features_min, _ = torch.min(features, dim=select_dim)
86 | features_max, _ = torch.max(features, dim=select_dim)
87 | features_min = features_min.unsqueeze(select_dim)
88 | features_max = features_max.unsqueeze(select_dim)
89 | return (features - features_min)/(features_max - features_min)
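# (editor's note) Example: for one subject's tensor of shape
# (N, time_steps, 310) and select_dim=0, every (time_step, feature) position
# is min-max scaled to [0, 1] across that subject's N samples.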
90 |
91 | # Load the data, return a list of samples and labels, including tensors of data for each subject
92 | def load4train(samples_path_list, args):
93 | """
94 | load the SEED data set
95 | """
96 | train_sample, train_label = load_trained_data(samples_path_list, args)
97 | sample_res = []
98 | label_res = []
99 | for subject_index in range(len(train_sample)):
100 | # transfer from ndarray to tensor
101 | one_subject_samples = torch.from_numpy(train_sample[subject_index]).type(torch.FloatTensor)
102 | one_subject_labels = torch.from_numpy(train_label[subject_index]).type(torch.LongTensor)
103 | # normalize tensor
104 | one_subject_samples = normalize(one_subject_samples)
105 | sample_res.append(one_subject_samples)
106 | label_res.append(one_subject_labels)
107 | return sample_res, label_res
108 |
109 | # Build DataLoaders for the selected session: the chosen subject is held out as the test (target) set and the remaining subjects are sources.
110 | def getDataLoaders(one_subject, args):
111 | pre_path=args.path
112 | config_path = {"file_path": pre_path + args.session + "/",
113 | "label_path": pre_path+"label.mat"}
114 | path_list = get_data_path(config_path["file_path"])
115 | try:
116 | target_path_list = [i for i in path_list if(i.startswith(config_path["file_path"] + str(int(one_subject+1))+"_"))]
117 | target_path=target_path_list[0]
118 | except IndexError:
119 | print("target data does not exist")
120 | path_list.remove(target_path)
121 | source_path_list = path_list
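# (editor's note) SEED feature files are named like "<subject>_<date>.mat",
# so the prefix match above selects the held-out target subject's file; the
# remaining files become the source subjects.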
122 |
123 | # read from DE feature
124 | sources_sample, sources_label = load4train(source_path_list, args)
125 | targets_sample, targets_label = load4train(target_path_list, args)
126 |
127 | if(len(targets_label)==1):
128 | target_sample = targets_sample[0]
129 | target_label = targets_label[0]
130 |
131 | # Generate Data loaders
132 | source_dsets = []
133 | for i in range(len(sources_sample)):
134 | source_dsets.append(torch.utils.data.TensorDataset(sources_sample[i], sources_label[i]))
135 | target_dset = torch.utils.data.TensorDataset(target_sample, target_label)
136 |
137 | source_loaders = []
138 | for j in range(len(source_dsets)):
139 | source_loaders.append(torch.utils.data.DataLoader(source_dsets[j], args.batch_size, shuffle=True, num_workers=args.num_workers_train, drop_last=True))
140 | test_loader = torch.utils.data.DataLoader(target_dset, args.batch_size, shuffle=False, num_workers=args.num_workers_test, drop_last=True)
141 | return source_loaders, test_loader
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | python==3.6.13
2 | numpy==1.19.2
3 | scipy==1.5.2
4 | torch==1.10.1
5 | torchvision==0.11.2
6 | tensorboard==1.15.0
7 | pillow==8.3.1
8 | matplotlib==3.3.4
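# note: "python==3.6.13" records the interpreter version (e.g. for a conda
# environment); pip itself cannot install it.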
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 |
4 | def testDMMR(dataLoader, DMMRTestModel, cuda, batch_size):
5 | print("testing DMMR")
6 | index = 0
7 | count = 0
8 | data_set_all = 0
9 | if cuda:
10 | DMMRTestModel = DMMRTestModel.cuda()
11 | DMMRTestModel.eval()
12 | with torch.no_grad():
13 | for _, (test_input, label) in enumerate(dataLoader):
14 | if cuda:
15 | test_input, label = test_input.cuda(), label.cuda()
16 | test_input, label = Variable(test_input), Variable(label)
17 | data_set_all += len(label)
18 | x_shared_pred = DMMRTestModel(test_input)
19 | _, pred = torch.max(x_shared_pred, dim=1)
20 | count += pred.eq(label.squeeze().data.view_as(pred)).sum()
21 | index += batch_size
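# (editor's note) The test loader is created with drop_last=True, so
# data_set_all counts only full batches and acc is computed over those samples.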
22 | acc = float(count) / data_set_all
23 | return acc
24 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from model import *
4 | import numpy as np
5 | from test import *
6 | from collections import defaultdict
7 | import random
8 | from sklearn.manifold import TSNE
9 | import matplotlib.pyplot as plt
10 |
11 | def trainDMMR(data_loader_dict, optimizer_config, cuda, args, iteration, writer, one_subject):
12 | # data of source subjects, which is used as the training set
13 | source_loader = data_loader_dict['source_loader']
14 | # The pre-training phase
15 | preTrainModel = DMMRPreTrainingModel(cuda, number_of_source=len(source_loader), number_of_category=args.cls_classes, batch_size=args.batch_size, time_steps=args.time_steps)
16 | if cuda:
17 | preTrainModel = preTrainModel.cuda()
18 | source_iters = []
19 | for i in range(len(source_loader)):
20 | source_iters.append(iter(source_loader[i]))
21 | optimizer_PreTraining = torch.optim.Adam(preTrainModel.parameters(), **optimizer_config)
22 |
23 | acc_final = 0
24 | for epoch in range(args.epoch_preTraining):
25 | print("epoch: "+str(epoch))
26 | start_time_pretrain = time.time()
27 | preTrainModel.train()
28 | data_set_all = 0
29 | for i in range(1, iteration + 1):
30 | p = float(i + epoch * iteration) / args.epoch_preTraining / iteration
31 | m = 2. / (1. + np.exp(-10 * p)) - 1 # for the gradient reverse layer (GRL)
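# (editor's note) DANN-style schedule: with p the training progress in [0, 1],
# m = 2 / (1 + e^(-10 p)) - 1 ramps the GRL coefficient smoothly from 0 to ~1
# over the course of pre-training.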
32 | batch_dict = defaultdict(list) #Pre-fetch a batch of data for each subject in advance and store them in this dictionary.
33 | data_dict = defaultdict(list) #Store the data of each subject in the current batch
34 | label_dict = defaultdict(list) #Store the labels corresponding to the data of each subject in the current batch
35 | label_data_dict = defaultdict(set)
36 | for j in range(len(source_iters)):
37 | try:
38 | batch_dict[j] = next(source_iters[j])
39 | except StopIteration:
40 | source_iters[j] = iter(source_loader[j])
41 | batch_dict[j]= next(source_iters[j])
42 | index = 0
43 | for o in batch_dict[j][1]:
44 | cur_label = o[0].item()
45 | data_dict[j].append(batch_dict[j][0][index])
46 | label_dict[j].append(cur_label)
47 | index+=1
48 |
49 | for j in range(len(source_iters)):
50 | # Assign a unique ID to each source subject
51 | subject_id = torch.ones(args.batch_size)
52 | subject_id = subject_id * j
53 | subject_id = subject_id.long()
54 | #the input of the model
55 | source_data, source_label = batch_dict[j]
56 | # Prepare a corresponding new batch for each subject; the new batch has the same labels as the current batch.
57 | label_data_dict_list = []
58 | for one_index in range(args.source_subjects):
59 | cur_data_list = data_dict[one_index]
60 | cur_label_list = label_dict[one_index]
61 | for one in range(args.batch_size):
62 | label_data_dict[cur_label_list[one]].add(cur_data_list[one])
63 | label_data_dict_list.append(label_data_dict)
64 | label_data_dict = defaultdict(set)
65 | # Store the corresponding new batch of each subject, providing the supervision for different decoders
66 | corres_batch_data = []
67 | for i in range(len(label_data_dict_list)):
68 | for one in source_label:
69 | label_cur = one[0].item()
70 | corres_batch_data.append(random.choice(list(label_data_dict_list[i][label_cur])))
71 | corres_batch_data = torch.stack(corres_batch_data)
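# (editor's note) corres_batch_data now holds, for every source subject in
# turn, one randomly drawn same-label sample per element of the current batch;
# the pre-training model later splits it back per subject with torch.chunk to
# supervise the corresponding decoders.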
72 | if cuda:
73 | source_data = source_data.cuda()
74 | source_label = source_label.cuda()
75 | subject_id = subject_id.cuda()
76 | corres_batch_data = corres_batch_data.cuda()
77 | data_set_all += len(source_label)
78 | optimizer_PreTraining.zero_grad()
79 | # Call the pretraining model
80 | rec_loss, sim_loss = preTrainModel(source_data, corres_batch_data, subject_id, m, mark=j)
81 | # The loss of the pre-training phase, beta is the balancing hyperparameter
82 | loss_pretrain = rec_loss + args.beta * sim_loss
83 | loss_pretrain.backward()
84 | optimizer_PreTraining.step()
85 | print("data set amount: "+str(data_set_all))
86 | writer.add_scalars('subject: '+str(one_subject+1)+' '+'train DMMR/loss',
87 | {'loss_pretrain':loss_pretrain.data,'rec_loss':rec_loss.data,'sim_loss':sim_loss.data}, epoch + 1)
88 | end_time_pretrain = time.time()
89 | pretrain_epoch_time = end_time_pretrain - start_time_pretrain
90 | print("The time required for one pre-training epoch is:", pretrain_epoch_time, "seconds")
91 | print("rec_loss: "+str(rec_loss))
92 |
93 | # The fine-tuning phase
94 | source_iters2 = []
95 | for i in range(len(source_loader)):
96 | source_iters2.append(iter(source_loader[i]))
97 | #Load the ABP module, the encoder from pretrained model and build a new model for the fine-tuning phase
98 | fineTuneModel = DMMRFineTuningModel(cuda, preTrainModel, number_of_source=len(source_loader),
99 | number_of_category=args.cls_classes, batch_size=args.batch_size,
100 | time_steps=args.time_steps)
101 |
102 | optimizer_FineTuning = torch.optim.Adam(fineTuneModel.parameters(), **optimizer_config)
103 | if cuda:
104 | fineTuneModel = fineTuneModel.cuda()
105 | for epoch in range(args.epoch_fineTuning):
106 | print("epoch: " + str(epoch))
107 | start_time = time.time()
108 | fineTuneModel.train()
109 | count = 0
110 | data_set_all = 0
111 | for i in range(1, iteration + 1):
112 | batch_dict = defaultdict(list)
113 | data_dict = defaultdict(list)
114 | label_dict = defaultdict(list)
115 |
116 | for j in range(len(source_iters2)):
117 | try:
118 | batch_dict[j] = next(source_iters2[j])
119 | except StopIteration:
120 | source_iters2[j] = iter(source_loader[j])
121 | batch_dict[j] = next(source_iters2[j])
122 | index = 0
123 | for o in batch_dict[j][1]:
124 | cur_label = o[0].item()
125 | data_dict[j].append(batch_dict[j][0][index])
126 | label_dict[j].append(cur_label)
127 | index += 1
128 | for j in range(len(source_iters2)):
129 | source_data, source_label = batch_dict[j]
130 | if cuda:
131 | source_data = source_data.cuda()
132 | source_label = source_label.cuda()
133 | data_set_all += len(source_label)
134 | optimizer_FineTuning.zero_grad()
135 | # Call the fine-tuning model
136 | x_pred, x_logits, cls_loss = fineTuneModel(source_data, source_label)
137 | cls_loss.backward()
138 | optimizer_FineTuning.step()
139 | _, pred = torch.max(x_pred, dim=1)
140 | count += pred.eq(source_label.squeeze().data.view_as(pred)).sum()
141 | end_time = time.time()
142 | epoch_time = end_time - start_time
143 | print("The time required for one fine-tuning epoch is:", epoch_time, "seconds")
144 | print("data set amount: " + str(data_set_all))
145 | acc = float(count) / data_set_all
146 | writer.add_scalars('subject: ' + str(one_subject + 1) + ' ' + 'train DMMR/loss',
147 | {'cls_loss': cls_loss.data}, epoch + 1)
148 | writer.add_scalar('subject: ' + str(one_subject + 1) + ' ' + 'train DMMR/train accuracy', acc, epoch + 1)
149 | print("acc: " + str(acc))
150 | # test the fine-tuned model with the data of unseen target subject
151 | testModel = DMMRTestModel(fineTuneModel)
152 | acc_DMMR = testDMMR(data_loader_dict["test_loader"], testModel, cuda, args.batch_size)
153 | print("acc_DMMR: " + str(acc_DMMR))
154 | writer.add_scalars('subject: ' + str(one_subject + 1) + ' ' + 'test DMMR/test acc',
155 | {'test acc': acc_DMMR}, epoch + 1)
156 | if acc_DMMR > acc_final:
157 | acc_final = acc_DMMR
158 | best_pretrain_model = copy.deepcopy(preTrainModel.state_dict())
159 | best_tune_model = copy.deepcopy(fineTuneModel.state_dict())
160 | best_test_model = copy.deepcopy(testModel.state_dict())
161 | modelDir = "model/" + args.way + "/" + args.index + "/"
162 | try:
163 | os.makedirs(modelDir)
164 | except FileExistsError:
165 | pass
166 | # save models
167 | torch.save(best_pretrain_model, modelDir + str(one_subject) + '_pretrain_model.pth')
168 | torch.save(best_tune_model, modelDir + str(one_subject) + '_tune_model.pth')
169 | torch.save(best_test_model, modelDir + str(one_subject) + '_test_model.pth')
170 | return acc_final
171 |
172 | ############## Ablation studies ##############
173 | # w/o mix
174 | def trainDMMR_WithoutMix(data_loader_dict, optimizer_config, cuda, args, iteration, writer, one_subject):
175 | source_loader = data_loader_dict['source_loader']
176 | # The pre-training phase
177 | preTrainModel = PreTrainingWithoutMix(cuda, number_of_source=len(source_loader), number_of_category=args.cls_classes, batch_size=args.batch_size, time_steps=args.time_steps)
178 | if cuda:
179 | preTrainModel = preTrainModel.cuda()
180 | source_iters = []
181 | for i in range(len(source_loader)):
182 | source_iters.append(iter(source_loader[i]))
183 | optimizer_PreTraining = torch.optim.Adam(preTrainModel.parameters(), **optimizer_config)
184 |
185 | acc_final = 0
186 | for epoch in range(args.epoch_preTraining):
187 | print("epoch: "+str(epoch))
188 | start_time_pretrain = time.time()
189 | preTrainModel.train()
190 | count = 0
191 | data_set_all = 0
192 | for i in range(1, iteration + 1):
193 | p = float(i + epoch * iteration) / args.epoch_preTraining / iteration
194 | m = 2. / (1. + np.exp(-10 * p)) - 1
195 | batch_dict = defaultdict(list)
196 | data_dict = defaultdict(list)
197 | label_dict = defaultdict(list)
198 | label_data_dict = defaultdict(set)
199 |
200 | for j in range(len(source_iters)):
201 | try:
202 | batch_dict[j] = next(source_iters[j])
203 | except StopIteration:
204 | source_iters[j] = iter(source_loader[j])
205 | batch_dict[j]= next(source_iters[j])
206 | index = 0
207 | for o in batch_dict[j][1]:
208 | cur_label = o[0].item()
209 | data_dict[j].append(batch_dict[j][0][index])
210 | label_dict[j].append(cur_label)
211 | index+=1
212 |
213 | for j in range(len(source_iters)):
214 | subject_id = torch.ones(args.batch_size)
215 | subject_id = subject_id * j
216 | subject_id = subject_id.long()
217 |
218 | source_data, source_label = batch_dict[j]
219 | # prepare a corresponding new batch; it has the same labels as the current batch
220 | label_data_dict_list = []
221 | for one_index in range(args.source_subjects):
222 | cur_data_list = data_dict[one_index]
223 | cur_label_list = label_dict[one_index]
224 | for one in range(args.batch_size):
225 | label_data_dict[cur_label_list[one]].add(cur_data_list[one])
226 | label_data_dict_list.append(label_data_dict)
227 | label_data_dict = defaultdict(set)
228 | corres_batch_data = []
229 | for i in range(len(label_data_dict_list)):
230 | for one in source_label:
231 | label_cur = one[0].item()
232 | corres_batch_data.append(random.choice(list(label_data_dict_list[i][label_cur])))
233 | corres_batch_data = torch.stack(corres_batch_data)
234 | if cuda:
235 | source_data = source_data.cuda()
236 | source_label = source_label.cuda()
237 | subject_id = subject_id.cuda()
238 | corres_batch_data =corres_batch_data.cuda()
239 | data_set_all+=len(source_label)
240 | optimizer_PreTraining.zero_grad()
241 | rec_loss, sim_loss = preTrainModel(source_data, corres_batch_data, subject_id, m, mark=j)
242 | loss_pretrain = rec_loss + args.beta*sim_loss
243 | loss_pretrain.backward()
244 | optimizer_PreTraining.step()
245 | print("data set amount: "+str(data_set_all))
246 | writer.add_scalars('subject: '+str(one_subject+1)+' '+'train DMMR/loss',
247 | {'loss_pretrain':loss_pretrain.data,'rec_loss':rec_loss.data,'sim_loss':sim_loss.data}, epoch + 1)
248 | end_time_pretrain = time.time()
249 | pretrain_epoch_time = end_time_pretrain - start_time_pretrain
250 | print("The time required for one pre-training epoch is:", pretrain_epoch_time, "seconds")
251 | print("rec_loss: "+str(rec_loss))
252 |
253 | # The fine-tuning phase
254 | source_iters2 = []
255 | for i in range(len(source_loader)):
256 | source_iters2.append(iter(source_loader[i]))
257 | fineTuneModel = DMMRFineTuningModel(cuda, preTrainModel, number_of_source=len(source_loader),
258 | number_of_category=args.cls_classes, batch_size=args.batch_size,
259 | time_steps=args.time_steps)
260 |
261 | optimizer_FineTuning = torch.optim.Adam(fineTuneModel.parameters(), **optimizer_config)
262 | if cuda:
263 | fineTuneModel = fineTuneModel.cuda()
264 | for epoch in range(args.epoch_fineTuning):
265 | print("epoch: " + str(epoch))
266 | start_time = time.time()
267 | fineTuneModel.train()
268 | count = 0
269 | data_set_all = 0
270 | for i in range(1, iteration + 1):
271 | p = float(i + epoch * iteration) / args.epoch_preTraining / iteration
272 | m = 2. / (1. + np.exp(-10 * p)) - 1
273 | batch_dict = defaultdict(list)
274 | data_dict = defaultdict(list)
275 | label_dict = defaultdict(list)
276 | label_data_dict = defaultdict(set)
277 |
278 | for j in range(len(source_iters2)):
279 | try:
280 | batch_dict[j] = next(source_iters2[j])
281 | except StopIteration:
282 | source_iters2[j] = iter(source_loader[j])
283 | batch_dict[j] = next(source_iters2[j])
284 | index = 0
285 | for o in batch_dict[j][1]:
286 | cur_label = o[0].item()
287 | data_dict[j].append(batch_dict[j][0][index])
288 | label_dict[j].append(cur_label)
289 | index += 1
290 | for j in range(len(source_iters2)):
291 | source_data, source_label = batch_dict[j]
292 | if cuda:
293 | source_data = source_data.cuda()
294 | source_label = source_label.cuda()
295 | data_set_all += len(source_label)
296 | optimizer_FineTuning.zero_grad()
297 | x_pred, x_logits, cls_loss = fineTuneModel(source_data, source_label)
298 | cls_loss.backward()
299 | optimizer_FineTuning.step()
300 | _, pred = torch.max(x_pred, dim=1)
301 | count += pred.eq(source_label.squeeze().data.view_as(pred)).sum()
302 | end_time = time.time()
303 | epoch_time = end_time - start_time
304 | print("The time required for one fine-tuning epoch is:", epoch_time, "seconds")
305 | print("data set amount: " + str(data_set_all))
306 | acc = float(count) / data_set_all
307 | writer.add_scalars('subject: ' + str(one_subject + 1) + ' ' + 'train DMMR/loss',
308 | {'cls_loss': cls_loss.data}, epoch + 1)
309 | writer.add_scalar('subject: ' + str(one_subject + 1) + ' ' + 'train DMMR/train accuracy', acc, epoch + 1)
310 | print("acc: " + str(acc))
311 |
312 | testModel = DMMRTestModel(fineTuneModel)
313 | acc_DMMR = testDMMR(data_loader_dict["test_loader"], testModel, cuda, args.batch_size)
314 | print("acc_DMMR: " + str(acc_DMMR))
315 | writer.add_scalars('subject: ' + str(one_subject + 1) + ' ' + 'test DMMR/test acc',
316 | {'test acc': acc_DMMR}, epoch + 1)
317 | if acc_DMMR > acc_final:
318 | acc_final = acc_DMMR
319 | best_pretrain_model = copy.deepcopy(preTrainModel.state_dict())
320 | best_tune_model = copy.deepcopy(fineTuneModel.state_dict())
321 | best_test_model = copy.deepcopy(testModel.state_dict())
322 | modelDir = "model/" + args.way + "/" + args.index + "/"
323 | try:
324 | os.makedirs(modelDir)
325 | except FileExistsError:
326 | pass
327 | torch.save(best_pretrain_model, modelDir + str(one_subject) + '_pretrain_model.pth')
328 | torch.save(best_tune_model, modelDir + str(one_subject) + '_tune_model.pth')
329 | torch.save(best_test_model, modelDir + str(one_subject) + '_test_model.pth')
330 | return acc_final
331 | # w/o noise
332 | def trainDMMR_WithoutNoise(data_loader_dict, optimizer_config, cuda, args, iteration, writer, one_subject):
333 | source_loader = data_loader_dict['source_loader']
334 | # The pre-training phase
335 | preTrainModel = PreTrainingWithoutNoise(cuda, number_of_source=len(source_loader), number_of_category=args.cls_classes, batch_size=args.batch_size, time_steps=args.time_steps)
336 | if cuda:
337 | preTrainModel = preTrainModel.cuda()
338 | source_iters = []
339 | for i in range(len(source_loader)):
340 | source_iters.append(iter(source_loader[i]))
341 | optimizer_PreTraining = torch.optim.Adam(preTrainModel.parameters(), **optimizer_config)
342 |
343 | acc_final = 0
344 | for epoch in range(args.epoch_preTraining):
345 | print("epoch: "+str(epoch))
346 | start_time_pretrain = time.time()
347 | preTrainModel.train()
348 | count = 0
349 | data_set_all = 0
350 | for i in range(1, iteration + 1):
351 | p = float(i + epoch * iteration) / args.epoch_preTraining / iteration
352 | m = 2. / (1. + np.exp(-10 * p)) - 1
353 | batch_dict = defaultdict(list)
354 | data_dict = defaultdict(list)
355 | label_dict = defaultdict(list)
356 | label_data_dict = defaultdict(set)
357 |
358 | for j in range(len(source_iters)):
359 | try:
360 | batch_dict[j] = next(source_iters[j])
361 | except StopIteration:
362 | source_iters[j] = iter(source_loader[j])
363 | batch_dict[j]= next(source_iters[j])
364 | index = 0
365 | for o in batch_dict[j][1]:
366 | cur_label = o[0].item()
367 | data_dict[j].append(batch_dict[j][0][index])
368 | label_dict[j].append(cur_label)
369 | index+=1
370 |
371 | for j in range(len(source_iters)):
372 | subject_id = torch.ones(args.batch_size)
373 | subject_id = subject_id * j
374 | subject_id = subject_id.long()
375 |
376 | source_data, source_label = batch_dict[j]
377 | # prepare a corresponding new batch; it has the same labels as the current batch
378 | label_data_dict_list = []
379 | for one_index in range(args.source_subjects):
380 | cur_data_list = data_dict[one_index]
381 | cur_label_list = label_dict[one_index]
382 | for one in range(args.batch_size):
383 | label_data_dict[cur_label_list[one]].add(cur_data_list[one])
384 | label_data_dict_list.append(label_data_dict)
385 | label_data_dict = defaultdict(set)
386 | corres_batch_data = []
387 | for i in range(len(label_data_dict_list)):
388 | for one in source_label:
389 | label_cur = one[0].item()
390 | corres_batch_data.append(random.choice(list(label_data_dict_list[i][label_cur])))
391 | corres_batch_data = torch.stack(corres_batch_data)
392 | if cuda:
393 | source_data = source_data.cuda()
394 | source_label = source_label.cuda()
395 | subject_id = subject_id.cuda()
396 | corres_batch_data =corres_batch_data.cuda()
397 | data_set_all+=len(source_label)
398 | optimizer_PreTraining.zero_grad()
399 | rec_loss, sim_loss = preTrainModel(source_data, corres_batch_data, subject_id, m, mark=j)
400 | loss_pretrain = rec_loss + args.beta*sim_loss
401 | loss_pretrain.backward()
402 | optimizer_PreTraining.step()
403 | print("data set amount: "+str(data_set_all))
404 | writer.add_scalars('subject: '+str(one_subject+1)+' '+'train DMMR/loss',
405 | {'loss_pretrain':loss_pretrain.data,'rec_loss':rec_loss.data,'sim_loss':sim_loss.data}, epoch + 1)
406 | end_time_pretrain = time.time()
407 | pretrain_epoch_time = end_time_pretrain - start_time_pretrain
408 | print("The time required for one pre-training epoch is:", pretrain_epoch_time, "seconds")
409 | print("rec_loss: "+str(rec_loss))
410 |
411 | # The fine-tuning phase
412 | source_iters2 = []
413 | for i in range(len(source_loader)):
414 | source_iters2.append(iter(source_loader[i]))
415 | fineTuneModel = DMMRFineTuningModel(cuda, preTrainModel, number_of_source=len(source_loader),
416 | number_of_category=args.cls_classes, batch_size=args.batch_size,
417 | time_steps=args.time_steps)
418 |
419 | optimizer_FineTuning = torch.optim.Adam(fineTuneModel.parameters(), **optimizer_config)
420 | if cuda:
421 | fineTuneModel = fineTuneModel.cuda()
422 | for epoch in range(args.epoch_fineTuning):
423 | print("epoch: " + str(epoch))
424 | start_time = time.time()
425 | fineTuneModel.train()
426 | count = 0
427 | data_set_all = 0
428 | for i in range(1, iteration + 1):
429 | p = float(i + epoch * iteration) / args.epoch_preTraining / iteration
430 | m = 2. / (1. + np.exp(-10 * p)) - 1
431 | batch_dict = defaultdict(list)
432 | data_dict = defaultdict(list)
433 | label_dict = defaultdict(list)
434 | label_data_dict = defaultdict(set)
435 |
436 | for j in range(len(source_iters2)):
437 | try:
438 | batch_dict[j] = next(source_iters2[j])
439 | except StopIteration:
440 | source_iters2[j] = iter(source_loader[j])
441 | batch_dict[j] = next(source_iters2[j])
442 | index = 0
443 | for o in batch_dict[j][1]:
444 | cur_label = o[0].item()
445 | data_dict[j].append(batch_dict[j][0][index])
446 | label_dict[j].append(cur_label)
447 | index += 1
448 | for j in range(len(source_iters2)):
449 | source_data, source_label = batch_dict[j]
450 | if cuda:
451 | source_data = source_data.cuda()
452 | source_label = source_label.cuda()
453 | data_set_all += len(source_label)
454 | optimizer_FineTuning.zero_grad()
455 | x_pred, x_logits, cls_loss = fineTuneModel(source_data, source_label)
456 | cls_loss.backward()
457 | optimizer_FineTuning.step()
458 | _, pred = torch.max(x_pred, dim=1)
459 | count += pred.eq(source_label.squeeze().data.view_as(pred)).sum()
460 | end_time = time.time()
461 | epoch_time = end_time - start_time
462 | print("The time required for one fine-tuning epoch is:", epoch_time, "seconds")
463 | print("data set amount: " + str(data_set_all))
464 | acc = float(count) / data_set_all
465 | writer.add_scalars('subject: ' + str(one_subject + 1) + ' ' + 'train DMMR/loss',
466 | {'cls_loss': cls_loss.data}, epoch + 1)
467 | writer.add_scalar('subject: ' + str(one_subject + 1) + ' ' + 'train DMMR/train accuracy', acc, epoch + 1)
468 | print("acc: " + str(acc))
469 |
470 | testModel = DMMRTestModel(fineTuneModel)
471 | acc_DMMR = testDMMR(data_loader_dict["test_loader"], testModel, cuda, args.batch_size)
472 | print("acc_DMMR: " + str(acc_DMMR))
473 | writer.add_scalars('subject: ' + str(one_subject + 1) + ' ' + 'test DMMR/test acc',
474 | {'test acc': acc_DMMR}, epoch + 1)
475 | if acc_DMMR > acc_final:
476 | acc_final = acc_DMMR
477 | best_pretrain_model = copy.deepcopy(preTrainModel.state_dict())
478 | best_tune_model = copy.deepcopy(fineTuneModel.state_dict())
479 | best_test_model = copy.deepcopy(testModel.state_dict())
480 | modelDir = "model/"+args.way+"/"+args.index+"/"
481 | try:
482 | os.makedirs(modelDir)
483 | except FileExistsError:
484 | pass
485 | torch.save(best_pretrain_model, modelDir+str(one_subject)+'_pretrain_model.pth')
486 | torch.save(best_tune_model, modelDir+str(one_subject)+'_tune_model.pth')
487 | torch.save(best_test_model, modelDir+str(one_subject)+'_test_model.pth')
488 | return acc_final
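
# --- Editor's note (illustrative sketch; names are ours, not the repo's) ---
# Every training loop in this file anneals the gradient-reversal coefficient with
# the DANN-style schedule m = 2 / (1 + exp(-10 * p)) - 1, where p grows linearly
# from ~0 to 1 over all pre-training iterations, so m ramps smoothly from 0 toward 1.
# A minimal, self-contained sketch of that schedule:
def _grl_coefficient_schedule(total_epochs, iterations_per_epoch):
    """Yield the GRL coefficient m for every iteration, mirroring the loops above."""
    import numpy as np  # local import keeps the sketch self-contained
    for epoch in range(total_epochs):
        for i in range(1, iterations_per_epoch + 1):
            p = float(i + epoch * iterations_per_epoch) / total_epochs / iterations_per_epoch
            yield 2. / (1. + np.exp(-10 * p)) - 1
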
489 | # Ablation: trained without both the data mixing and the noise injection steps
490 | def trainDMMR_WithoutBothMixAndNoise(data_loader_dict, optimizer_config, cuda, args, iteration, writer, one_subject):
491 | source_loader = data_loader_dict['source_loader']
492 | # The pre-training phase
493 | preTrainModel = PreTrainingWithoutBothMixAndNoise(cuda, number_of_source=len(source_loader), number_of_category=args.cls_classes, batch_size=args.batch_size, time_steps=args.time_steps)
494 | if cuda:
495 | preTrainModel = preTrainModel.cuda()
496 | source_iters = []
497 | for i in range(len(source_loader)):
498 | source_iters.append(iter(source_loader[i]))
499 | optimizer_PreTraining = torch.optim.Adam(preTrainModel.parameters(), **optimizer_config)
500 |
501 | acc_final = 0
502 | for epoch in range(args.epoch_preTraining):
503 | print("epoch: "+str(epoch))
504 | start_time_pretrain = time.time()
505 | preTrainModel.train()
506 | count = 0
507 | data_set_all = 0
508 | for i in range(1, iteration + 1):
509 | p = float(i + epoch * iteration) / args.epoch_preTraining / iteration
510 | m = 2. / (1. + np.exp(-10 * p)) - 1
511 | batch_dict = defaultdict(list)
512 | data_dict = defaultdict(list)
513 | label_dict = defaultdict(list)
514 | label_data_dict = defaultdict(set)
515 |
516 | for j in range(len(source_iters)):
517 | try:
518 | batch_dict[j] = next(source_iters[j])
519 | except StopIteration:  # this subject's loader is exhausted; restart it
520 | source_iters[j] = iter(source_loader[j])
521 | batch_dict[j] = next(source_iters[j])
522 | index = 0
523 | for o in batch_dict[j][1]:
524 | cur_label = o[0].item()
525 | data_dict[j].append(batch_dict[j][0][index])
526 | label_dict[j].append(cur_label)
527 | index += 1
528 |
529 | for j in range(len(source_iters)):
530 | subject_id = torch.ones(args.batch_size)
531 | subject_id = subject_id * j
532 | subject_id = subject_id.long()
533 |
534 | source_data, source_label = batch_dict[j]
535 | # prepare a corresponding batch that shares the labels of the current batch, drawn from every source subject
536 | label_data_dict_list = []
537 | for one_index in range(args.source_subjects):
538 | cur_data_list = data_dict[one_index]
539 | cur_label_list = label_dict[one_index]
540 | for one in range(args.batch_size):
541 | label_data_dict[cur_label_list[one]].add(cur_data_list[one])
542 | label_data_dict_list.append(label_data_dict)
543 | label_data_dict = defaultdict(set)
544 | corres_batch_data = []
545 | for k in range(len(label_data_dict_list)):  # k, not i: the outer loop already uses i
546 | for one in source_label:
547 | label_cur = one[0].item()
548 | corres_batch_data.append(random.choice(list(label_data_dict_list[k][label_cur])))
549 | corres_batch_data = torch.stack(corres_batch_data)
550 | if cuda:
551 | source_data = source_data.cuda()
552 | source_label = source_label.cuda()
553 | subject_id = subject_id.cuda()
554 | corres_batch_data = corres_batch_data.cuda()
555 | data_set_all += len(source_label)
556 | optimizer_PreTraining.zero_grad()
557 | rec_loss, sim_loss = preTrainModel(source_data, corres_batch_data, subject_id, m, mark=j)
558 | loss_pretrain = rec_loss + args.beta*sim_loss
559 | loss_pretrain.backward()
560 | optimizer_PreTraining.step()
561 | print("data set amount: "+str(data_set_all))
562 | writer.add_scalars('subject: '+str(one_subject+1)+' '+'train DMMR/loss',
563 | {'loss_pretrain':loss_pretrain.data,'rec_loss':rec_loss.data,'sim_loss':sim_loss.data}, epoch + 1)
564 | end_time_pretrain = time.time()
565 | pretrain_epoch_time = end_time_pretrain - start_time_pretrain
566 | print("The time required for one pre-training epoch is:", pretrain_epoch_time, "second")
567 | print("rec_loss: "+str(rec_loss))
568 |
569 | # The fine-tuning phase
570 | source_iters2 = []
571 | for i in range(len(source_loader)):
572 | source_iters2.append(iter(source_loader[i]))
573 | fineTuneModel = DMMRFineTuningModel(cuda, preTrainModel, number_of_source=len(source_loader),
574 | number_of_category=args.cls_classes, batch_size=args.batch_size,
575 | time_steps=args.time_steps)
576 |
577 | optimizer_FineTuning = torch.optim.Adam(fineTuneModel.parameters(), **optimizer_config)
578 | if cuda:
579 | fineTuneModel = fineTuneModel.cuda()
580 | for epoch in range(args.epoch_fineTuning):
581 | print("epoch: " + str(epoch))
582 | start_time = time.time()
583 | fineTuneModel.train()
584 | count = 0
585 | data_set_all = 0
586 | for i in range(1, iteration + 1):
587 | p = float(i + epoch * iteration) / args.epoch_preTraining / iteration
588 | m = 2. / (1. + np.exp(-10 * p)) - 1  # GRL coefficient; kept for symmetry with pre-training but unused in fine-tuning
589 | batch_dict = defaultdict(list)
590 | data_dict = defaultdict(list)
591 | label_dict = defaultdict(list)
592 | label_data_dict = defaultdict(set)
593 |
594 | for j in range(len(source_iters2)):
595 | try:
596 | batch_dict[j] = next(source_iters2[j])
597 | except StopIteration:  # this subject's loader is exhausted; restart it
598 | source_iters2[j] = iter(source_loader[j])
599 | batch_dict[j] = next(source_iters2[j])
600 | index = 0
601 | for o in batch_dict[j][1]:
602 | cur_label = o[0].item()
603 | data_dict[j].append(batch_dict[j][0][index])
604 | label_dict[j].append(cur_label)
605 | index += 1
606 | for j in range(len(source_iters2)):  # the fine-tuning iterators, not the pre-training ones
607 | source_data, source_label = batch_dict[j]
608 | if cuda:
609 | source_data = source_data.cuda()
610 | source_label = source_label.cuda()
611 | data_set_all += len(source_label)
612 | optimizer_FineTuning.zero_grad()
613 | x_pred, x_logits, cls_loss = fineTuneModel(source_data, source_label)
614 | cls_loss.backward()
615 | optimizer_FineTuning.step()
616 | _, pred = torch.max(x_pred, dim=1)
617 | count += pred.eq(source_label.squeeze().data.view_as(pred)).sum()
618 | end_time = time.time()
619 | epoch_time = end_time - start_time
620 | print("The time required for one fine-tuning epoch is:", epoch_time, "second")
621 | print("data set amount: " + str(data_set_all))
622 | acc = float(count) / data_set_all
623 | writer.add_scalars('subject: ' + str(one_subject + 1) + ' ' + 'train DMMR/loss',
624 | {'cls_loss': cls_loss.data}, epoch + 1)
625 | writer.add_scalar('subject: ' + str(one_subject + 1) + ' ' + 'train DMMR/train accuracy', acc, epoch + 1)
626 | print("acc: " + str(acc))
627 |
628 | testModel = DMMRTestModel(fineTuneModel)
629 | acc_DMMR = testDMMR(data_loader_dict["test_loader"], testModel, cuda, args.batch_size)
630 | print("acc_DMMR: " + str(acc_DMMR))
631 | writer.add_scalars('subject: ' + str(one_subject + 1) + ' ' + 'test DMMR/test acc',
632 | {'test acc': acc_DMMR}, epoch + 1)
633 | if acc_DMMR > acc_final:
634 | acc_final = acc_DMMR
635 | best_pretrain_model = copy.deepcopy(preTrainModel.state_dict())
636 | best_tune_model = copy.deepcopy(fineTuneModel.state_dict())
637 | best_test_model = copy.deepcopy(testModel.state_dict())
638 | modelDir = "model/" + args.way + "/" + args.index + "/"
639 | try:
640 | os.makedirs(modelDir)
641 | except OSError:  # the directory may already exist
642 | pass
643 | torch.save(best_pretrain_model, modelDir + str(one_subject) + '_pretrain_model.pth')
644 | torch.save(best_tune_model, modelDir + str(one_subject) + '_tune_model.pth')
645 | torch.save(best_test_model, modelDir + str(one_subject) + '_test_model.pth')
646 | return acc_final
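
# --- Editor's note (illustrative sketch; the helper name is ours) ---
# The pre-training loops above pair every sample with randomly drawn same-label
# samples from each source subject to build "corres_batch_data". The logic,
# condensed into one helper; like the original, it assumes every label in
# source_label also occurs in each subject's current batch:
def _build_same_label_batch(data_dict, label_dict, source_label, num_subjects, batch_size):
    """Return a stacked batch whose labels follow source_label, drawn per subject."""
    import random
    from collections import defaultdict
    import torch
    per_subject = []
    for s in range(num_subjects):
        by_label = defaultdict(list)  # label -> this subject's samples with that label
        for n in range(batch_size):
            by_label[label_dict[s][n]].append(data_dict[s][n])
        per_subject.append(by_label)
    corres = []
    for s in range(num_subjects):
        for one in source_label:  # keep the label order of the current batch
            corres.append(random.choice(per_subject[s][one[0].item()]))
    return torch.stack(corres)
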
647 |
648 | ############## Other noise injection methods ##############
649 | def trainDMMR_Noise_MaskChannels(data_loader_dict, optimizer_config, cuda, args, iteration, writer, one_subject):
650 | source_loader = data_loader_dict['source_loader']
651 | # The pre-training phase
652 | preTrainModel = PreTrainingWithMaskChannels(cuda, number_of_source=len(source_loader), number_of_category=args.cls_classes, batch_size=args.batch_size, time_steps=args.time_steps)
653 | if cuda:
654 | preTrainModel = preTrainModel.cuda()
655 | source_iters = []
656 | for i in range(len(source_loader)):
657 | source_iters.append(iter(source_loader[i]))
658 | optimizer_PreTraining = torch.optim.Adam(preTrainModel.parameters(), **optimizer_config)
659 |
660 | acc_final = 0
661 | for epoch in range(args.epoch_preTraining):
662 | print("epoch: "+str(epoch))
663 | start_time_pretrain = time.time()
664 | preTrainModel.train()
665 | count = 0
666 | data_set_all = 0
667 | for i in range(1, iteration + 1):
668 | p = float(i + epoch * iteration) / args.epoch_preTraining / iteration
669 | m = 2. / (1. + np.exp(-10 * p)) - 1
670 | batch_dict = defaultdict(list)
671 | data_dict = defaultdict(list)
672 | label_dict = defaultdict(list)
673 | label_data_dict = defaultdict(set)
674 |
675 | for j in range(len(source_iters)):
676 | try:
677 | batch_dict[j] = next(source_iters[j])
678 | except StopIteration:  # this subject's loader is exhausted; restart it
679 | source_iters[j] = iter(source_loader[j])
680 | batch_dict[j] = next(source_iters[j])
681 | index = 0
682 | for o in batch_dict[j][1]:
683 | cur_label = o[0].item()
684 | data_dict[j].append(batch_dict[j][0][index])
685 | label_dict[j].append(cur_label)
686 | index += 1
687 |
688 | for j in range(len(source_iters)):
689 | subject_id = torch.ones(args.batch_size)
690 | subject_id = subject_id * j
691 | subject_id = subject_id.long()
692 |
693 | source_data, source_label = batch_dict[j]
694 | # prepare a corresponding batch that shares the labels of the current batch, drawn from every source subject
695 | label_data_dict_list = []
696 | for one_index in range(args.source_subjects):
697 | cur_data_list = data_dict[one_index]
698 | cur_label_list = label_dict[one_index]
699 | for one in range(args.batch_size):
700 | label_data_dict[cur_label_list[one]].add(cur_data_list[one])
701 | label_data_dict_list.append(label_data_dict)
702 | label_data_dict = defaultdict(set)
703 | corres_batch_data = []
704 | for k in range(len(label_data_dict_list)):  # k, not i: the outer loop already uses i
705 | for one in source_label:
706 | label_cur = one[0].item()
707 | corres_batch_data.append(random.choice(list(label_data_dict_list[k][label_cur])))
708 | corres_batch_data = torch.stack(corres_batch_data)
709 | if cuda:
710 | source_data = source_data.cuda()
711 | source_label = source_label.cuda()
712 | subject_id = subject_id.cuda()
713 | corres_batch_data = corres_batch_data.cuda()
714 | data_set_all += len(source_label)
715 | optimizer_PreTraining.zero_grad()
716 | rec_loss, sim_loss = preTrainModel(source_data, corres_batch_data, subject_id, args, m, mark=j)
717 | loss_pretrain = rec_loss + args.beta*sim_loss
718 | loss_pretrain.backward()
719 | optimizer_PreTraining.step()
720 | print("data set amount: "+str(data_set_all))
721 | writer.add_scalars('subject: '+str(one_subject+1)+' '+'train DMMR/loss',
722 | {'loss_pretrain':loss_pretrain.data,'rec_loss':rec_loss.data,'sim_loss':sim_loss.data}, epoch + 1)
723 | end_time_pretrain = time.time()
724 | pretrain_epoch_time = end_time_pretrain - start_time_pretrain
725 | print("The time required for one pre-training epoch is:", pretrain_epoch_time, "second")
726 | print("rec_loss: "+str(rec_loss))
727 |
728 | # The fine-tuning phase
729 | source_iters2 = []
730 | for i in range(len(source_loader)):
731 | source_iters2.append(iter(source_loader[i]))
732 | fineTuneModel = DMMRFineTuningModel(cuda, preTrainModel, number_of_source=len(source_loader),
733 | number_of_category=args.cls_classes, batch_size=args.batch_size,
734 | time_steps=args.time_steps)
735 |
736 | optimizer_FineTuning = torch.optim.Adam(fineTuneModel.parameters(), **optimizer_config)
737 | if cuda:
738 | fineTuneModel = fineTuneModel.cuda()
739 | for epoch in range(args.epoch_fineTuning):
740 | print("epoch: " + str(epoch))
741 | start_time = time.time()
742 | fineTuneModel.train()
743 | count = 0
744 | data_set_all = 0
745 | for i in range(1, iteration + 1):
746 | p = float(i + epoch * iteration) / args.epoch_preTraining / iteration
747 | m = 2. / (1. + np.exp(-10 * p)) - 1  # GRL coefficient; kept for symmetry with pre-training but unused in fine-tuning
748 | batch_dict = defaultdict(list)
749 | data_dict = defaultdict(list)
750 | label_dict = defaultdict(list)
751 | label_data_dict = defaultdict(set)
752 |
753 | for j in range(len(source_iters2)):
754 | try:
755 | batch_dict[j] = next(source_iters2[j])
756 | except StopIteration:  # this subject's loader is exhausted; restart it
757 | source_iters2[j] = iter(source_loader[j])
758 | batch_dict[j] = next(source_iters2[j])
759 | index = 0
760 | for o in batch_dict[j][1]:
761 | cur_label = o[0].item()
762 | data_dict[j].append(batch_dict[j][0][index])
763 | label_dict[j].append(cur_label)
764 | index += 1
765 | for j in range(len(source_iters2)):  # the fine-tuning iterators, not the pre-training ones
766 | source_data, source_label = batch_dict[j]
767 | if cuda:
768 | source_data = source_data.cuda()
769 | source_label = source_label.cuda()
770 | data_set_all += len(source_label)
771 | optimizer_FineTuning.zero_grad()
772 | x_pred, x_logits, cls_loss = fineTuneModel(source_data, source_label)
773 | cls_loss.backward()
774 | optimizer_FineTuning.step()
775 | _, pred = torch.max(x_pred, dim=1)
776 | count += pred.eq(source_label.squeeze().data.view_as(pred)).sum()
777 | end_time = time.time()
778 | epoch_time = end_time - start_time
779 | print("The time required for one fine-tuning epoch is:", epoch_time, "second")
780 | print("data set amount: " + str(data_set_all))
781 | acc = float(count) / data_set_all
782 | writer.add_scalars('subject: ' + str(one_subject + 1) + ' ' + 'train DMMR/loss',
783 | {'cls_loss': cls_loss.data}, epoch + 1)
784 | writer.add_scalar('subject: ' + str(one_subject + 1) + ' ' + 'train DMMR/train accuracy', acc, epoch + 1)
785 | print("acc: " + str(acc))
786 |
787 | testModel = DMMRTestModel(fineTuneModel)
788 | acc_DMMR = testDMMR(data_loader_dict["test_loader"], testModel, cuda, args.batch_size)
789 | print("acc_DMMR: " + str(acc_DMMR))
790 | writer.add_scalars('subject: ' + str(one_subject + 1) + ' ' + 'test DMMR/test acc',
791 | {'test acc': acc_DMMR}, epoch + 1)
792 | if acc_DMMR > acc_final:
793 | acc_final = acc_DMMR
794 | best_pretrain_model = copy.deepcopy(preTrainModel.state_dict())
795 | best_tune_model = copy.deepcopy(fineTuneModel.state_dict())
796 | best_test_model = copy.deepcopy(testModel.state_dict())
797 | modelDir = "model/" + args.way + "/" + args.index + "/"
798 | try:
799 | os.makedirs(modelDir)
800 | except:
801 | pass
802 | torch.save(best_pretrain_model, modelDir + str(one_subject) + '_pretrain_model.pth')
803 | torch.save(best_tune_model, modelDir + str(one_subject) + '_tune_model.pth')
804 | torch.save(best_test_model, modelDir + str(one_subject) + '_test_model.pth')
805 | return acc_final
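
# --- Editor's note (illustrative sketch) ---
# PreTrainingWithMaskChannels injects noise by masking channels; its actual
# implementation lives in noiseInjectionMethods/maskChannels.py and is not shown
# here. A minimal sketch of the idea (the mask ratio and the
# (batch, time, channels) layout are our assumptions):
def _mask_random_channels(x, mask_ratio=0.1):
    """Zero out a random subset of channels, shared across the whole batch."""
    import torch
    num_channels = x.size(-1)
    num_masked = max(1, int(num_channels * mask_ratio))
    idx = torch.randperm(num_channels)[:num_masked]
    x = x.clone()      # avoid mutating the caller's tensor
    x[..., idx] = 0.0  # same channels are masked for every sample and time step
    return x
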
806 |
807 | def trainDMMR_Noise_MaskTimeSteps(data_loader_dict, optimizer_config, cuda, args, iteration, writer, one_subject):
808 | source_loader = data_loader_dict['source_loader']
809 | # The pre-training phase
810 | preTrainModel = PreTrainingWithMaskTimeSteps(cuda, number_of_source=len(source_loader), number_of_category=args.cls_classes, batch_size=args.batch_size, time_steps=args.time_steps)
811 | if cuda:
812 | preTrainModel = preTrainModel.cuda()
813 | source_iters = []
814 | for i in range(len(source_loader)):
815 | source_iters.append(iter(source_loader[i]))
816 | optimizer_PreTraining = torch.optim.Adam(preTrainModel.parameters(), **optimizer_config)
817 |
818 | acc_final = 0
819 | for epoch in range(args.epoch_preTraining):
820 | print("epoch: "+str(epoch))
821 | start_time_pretrain = time.time()
822 | preTrainModel.train()
823 | count = 0
824 | data_set_all = 0
825 | for i in range(1, iteration + 1):
826 | p = float(i + epoch * iteration) / args.epoch_preTraining / iteration
827 | m = 2. / (1. + np.exp(-10 * p)) - 1
828 | batch_dict = defaultdict(list)
829 | data_dict = defaultdict(list)
830 | label_dict = defaultdict(list)
831 | label_data_dict = defaultdict(set)
832 |
833 | for j in range(len(source_iters)):
834 | try:
835 | batch_dict[j] = next(source_iters[j])
836 | except StopIteration:  # this subject's loader is exhausted; restart it
837 | source_iters[j] = iter(source_loader[j])
838 | batch_dict[j] = next(source_iters[j])
839 | index = 0
840 | for o in batch_dict[j][1]:
841 | cur_label = o[0].item()
842 | data_dict[j].append(batch_dict[j][0][index])
843 | label_dict[j].append(cur_label)
844 | index += 1
845 |
846 | for j in range(len(source_iters)):
847 | subject_id = torch.ones(args.batch_size)
848 | subject_id = subject_id * j
849 | subject_id = subject_id.long()
850 |
851 | source_data, source_label = batch_dict[j]
852 | # prepare a corresponding batch that shares the labels of the current batch, drawn from every source subject
853 | label_data_dict_list = []
854 | for one_index in range(args.source_subjects):
855 | cur_data_list = data_dict[one_index]
856 | cur_label_list = label_dict[one_index]
857 | for one in range(args.batch_size):
858 | label_data_dict[cur_label_list[one]].add(cur_data_list[one])
859 | label_data_dict_list.append(label_data_dict)
860 | label_data_dict = defaultdict(set)
861 | corres_batch_data = []
862 | for k in range(len(label_data_dict_list)):  # k, not i: the outer loop already uses i
863 | for one in source_label:
864 | label_cur = one[0].item()
865 | corres_batch_data.append(random.choice(list(label_data_dict_list[k][label_cur])))
866 | corres_batch_data = torch.stack(corres_batch_data)
867 | if cuda:
868 | source_data = source_data.cuda()
869 | source_label = source_label.cuda()
870 | subject_id = subject_id.cuda()
871 | corres_batch_data = corres_batch_data.cuda()
872 | data_set_all += len(source_label)
873 | optimizer_PreTraining.zero_grad()
874 | rec_loss, sim_loss = preTrainModel(source_data, corres_batch_data, subject_id, args, m, mark=j)
875 | loss_pretrain = rec_loss + args.beta*sim_loss
876 | loss_pretrain.backward()
877 | optimizer_PreTraining.step()
878 | print("data set amount: "+str(data_set_all))
879 | writer.add_scalars('subject: '+str(one_subject+1)+' '+'train DMMR/loss',
880 | {'loss_pretrain':loss_pretrain.data,'rec_loss':rec_loss.data,'sim_loss':sim_loss.data}, epoch + 1)
881 | end_time_pretrain = time.time()
882 | pretrain_epoch_time = end_time_pretrain - start_time_pretrain
883 | print("The time required for one pre-training epoch is:", pretrain_epoch_time, "second")
884 | print("rec_loss: "+str(rec_loss))
885 |
886 | # The fine-tuning phase
887 | source_iters2 = []
888 | for i in range(len(source_loader)):
889 | source_iters2.append(iter(source_loader[i]))
890 | fineTuneModel = DMMRFineTuningModel(cuda, preTrainModel, number_of_source=len(source_loader),
891 | number_of_category=args.cls_classes, batch_size=args.batch_size,
892 | time_steps=args.time_steps)
893 |
894 | optimizer_FineTuning = torch.optim.Adam(fineTuneModel.parameters(), **optimizer_config)
895 | if cuda:
896 | fineTuneModel = fineTuneModel.cuda()
897 | for epoch in range(args.epoch_fineTuning):
898 | print("epoch: " + str(epoch))
899 | start_time = time.time()
900 | fineTuneModel.train()
901 | count = 0
902 | data_set_all = 0
903 | for i in range(1, iteration + 1):
904 | p = float(i + epoch * iteration) / args.epoch_preTraining / iteration
905 | m = 2. / (1. + np.exp(-10 * p)) - 1  # GRL coefficient; kept for symmetry with pre-training but unused in fine-tuning
906 | batch_dict = defaultdict(list)
907 | data_dict = defaultdict(list)
908 | label_dict = defaultdict(list)
909 | label_data_dict = defaultdict(set)
910 |
911 | for j in range(len(source_iters2)):
912 | try:
913 | batch_dict[j] = next(source_iters2[j])
914 | except StopIteration:  # this subject's loader is exhausted; restart it
915 | source_iters2[j] = iter(source_loader[j])
916 | batch_dict[j] = next(source_iters2[j])
917 | index = 0
918 | for o in batch_dict[j][1]:
919 | cur_label = o[0].item()
920 | data_dict[j].append(batch_dict[j][0][index])
921 | label_dict[j].append(cur_label)
922 | index += 1
923 | for j in range(len(source_iters2)):  # the fine-tuning iterators, not the pre-training ones
924 | source_data, source_label = batch_dict[j]
925 | if cuda:
926 | source_data = source_data.cuda()
927 | source_label = source_label.cuda()
928 | data_set_all += len(source_label)
929 | optimizer_FineTuning.zero_grad()
930 | x_pred, x_logits, cls_loss = fineTuneModel(source_data, source_label)
931 | cls_loss.backward()
932 | optimizer_FineTuning.step()
933 | _, pred = torch.max(x_pred, dim=1)
934 | count += pred.eq(source_label.squeeze().data.view_as(pred)).sum()
935 | end_time = time.time()
936 | epoch_time = end_time - start_time
937 | print("The time required for one fine-tuning epoch is:", epoch_time, "second")
938 | print("data set amount: " + str(data_set_all))
939 | acc = float(count) / data_set_all
940 | writer.add_scalars('subject: ' + str(one_subject + 1) + ' ' + 'train DMMR/loss',
941 | {'cls_loss': cls_loss.data}, epoch + 1)
942 | writer.add_scalar('subject: ' + str(one_subject + 1) + ' ' + 'train DMMR/train accuracy', acc, epoch + 1)
943 | print("acc: " + str(acc))
944 |
945 | testModel = DMMRTestModel(fineTuneModel)
946 | acc_DMMR = testDMMR(data_loader_dict["test_loader"], testModel, cuda, args.batch_size)
947 | print("acc_DMMR: " + str(acc_DMMR))
948 | writer.add_scalars('subject: ' + str(one_subject + 1) + ' ' + 'test DMMR/test acc',
949 | {'test acc': acc_DMMR}, epoch + 1)
950 | if acc_DMMR > acc_final:
951 | acc_final = acc_DMMR
952 | best_pretrain_model = copy.deepcopy(preTrainModel.state_dict())
953 | best_tune_model = copy.deepcopy(fineTuneModel.state_dict())
954 | best_test_model = copy.deepcopy(testModel.state_dict())
955 | modelDir = "model/" + args.way + "/" + args.index + "/"
956 | try:
957 | os.makedirs(modelDir)
958 | except:
959 | pass
960 | torch.save(best_pretrain_model, modelDir + str(one_subject) + '_pretrain_model.pth')
961 | torch.save(best_tune_model, modelDir + str(one_subject) + '_tune_model.pth')
962 | torch.save(best_test_model, modelDir + str(one_subject) + '_test_model.pth')
963 | return acc_final
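
# --- Editor's note (illustrative sketch) ---
# PreTrainingWithMaskTimeSteps masks whole time steps instead of channels (see
# noiseInjectionMethods/maskTimeSteps.py for the real implementation). A sketch
# under the same assumptions as the channel-masking one above:
def _mask_random_time_steps(x, mask_ratio=0.1):
    """Zero out a random subset of time steps in a (batch, time, features) tensor."""
    import torch
    time_steps = x.size(1)
    num_masked = max(1, int(time_steps * mask_ratio))
    idx = torch.randperm(time_steps)[:num_masked]
    x = x.clone()       # avoid mutating the caller's tensor
    x[:, idx, :] = 0.0  # the selected time steps are zeroed for every sample
    return x
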
964 |
965 | def trainDMMR_Noise_ChannelsShuffling(data_loader_dict, optimizer_config, cuda, args, iteration, writer, one_subject):
966 | source_loader = data_loader_dict['source_loader']
967 | # The pre-training phase
968 | preTrainModel = PreTrainingWithChannelsShuffling(cuda, number_of_source=len(source_loader), number_of_category=args.cls_classes, batch_size=args.batch_size, time_steps=args.time_steps)
969 | if cuda:
970 | preTrainModel = preTrainModel.cuda()
971 | source_iters = []
972 | for i in range(len(source_loader)):
973 | source_iters.append(iter(source_loader[i]))
974 | optimizer_PreTraining = torch.optim.Adam(preTrainModel.parameters(), **optimizer_config)
975 |
976 | acc_final = 0
977 | for epoch in range(args.epoch_preTraining):
978 | print("epoch: "+str(epoch))
979 | start_time_pretrain = time.time()
980 | preTrainModel.train()
981 | count = 0
982 | data_set_all = 0
983 | for i in range(1, iteration + 1):
984 | p = float(i + epoch * iteration) / args.epoch_preTraining / iteration
985 | m = 2. / (1. + np.exp(-10 * p)) - 1
986 | batch_dict = defaultdict(list)
987 | data_dict = defaultdict(list)
988 | label_dict = defaultdict(list)
989 | label_data_dict = defaultdict(set)
990 |
991 | for j in range(len(source_iters)):
992 | try:
993 | batch_dict[j] = next(source_iters[j])
994 | except StopIteration:  # this subject's loader is exhausted; restart it
995 | source_iters[j] = iter(source_loader[j])
996 | batch_dict[j] = next(source_iters[j])
997 | index = 0
998 | for o in batch_dict[j][1]:
999 | cur_label = o[0].item()
1000 | data_dict[j].append(batch_dict[j][0][index])
1001 | label_dict[j].append(cur_label)
1002 | index += 1
1003 |
1004 | for j in range(len(source_iters)):
1005 | subject_id = torch.ones(args.batch_size)
1006 | subject_id = subject_id * j
1007 | subject_id = subject_id.long()
1008 |
1009 | source_data, source_label = batch_dict[j]
1010 | # prepare a corresponding batch that shares the labels of the current batch, drawn from every source subject
1011 | label_data_dict_list = []
1012 | for one_index in range(args.source_subjects):
1013 | cur_data_list = data_dict[one_index]
1014 | cur_label_list = label_dict[one_index]
1015 | for one in range(args.batch_size):
1016 | label_data_dict[cur_label_list[one]].add(cur_data_list[one])
1017 | label_data_dict_list.append(label_data_dict)
1018 | label_data_dict = defaultdict(set)
1019 | corres_batch_data = []
1020 | for k in range(len(label_data_dict_list)):  # k, not i: the outer loop already uses i
1021 | for one in source_label:
1022 | label_cur = one[0].item()
1023 | corres_batch_data.append(random.choice(list(label_data_dict_list[k][label_cur])))
1024 | corres_batch_data = torch.stack(corres_batch_data)
1025 | if cuda:
1026 | source_data = source_data.cuda()
1027 | source_label = source_label.cuda()
1028 | subject_id = subject_id.cuda()
1029 | corres_batch_data = corres_batch_data.cuda()
1030 | data_set_all += len(source_label)
1031 | optimizer_PreTraining.zero_grad()
1032 | rec_loss, sim_loss = preTrainModel(source_data, corres_batch_data, subject_id, args, m, mark=j)
1033 | loss_pretrain = rec_loss + args.beta*sim_loss
1034 | loss_pretrain.backward()
1035 | optimizer_PreTraining.step()
1036 | print("data set amount: "+str(data_set_all))
1037 | writer.add_scalars('subject: '+str(one_subject+1)+' '+'train DMMR/loss',
1038 | {'loss_pretrain':loss_pretrain.data,'rec_loss':rec_loss.data,'sim_loss':sim_loss.data}, epoch + 1)
1039 | end_time_pretrain = time.time()
1040 | pretrain_epoch_time = end_time_pretrain - start_time_pretrain
1041 | print("The time required for one pre-training epoch is:", pretrain_epoch_time, "second")
1042 | print("rec_loss: "+str(rec_loss))
1043 |
1044 | # The fine-tuning phase
1045 | source_iters2 = []
1046 | for i in range(len(source_loader)):
1047 | source_iters2.append(iter(source_loader[i]))
1048 | fineTuneModel = DMMRFineTuningModel(cuda, preTrainModel, number_of_source=len(source_loader),
1049 | number_of_category=args.cls_classes, batch_size=args.batch_size,
1050 | time_steps=args.time_steps)
1051 |
1052 | optimizer_FineTuning = torch.optim.Adam(fineTuneModel.parameters(), **optimizer_config)
1053 | if cuda:
1054 | fineTuneModel = fineTuneModel.cuda()
1055 | for epoch in range(args.epoch_fineTuning):
1056 | print("epoch: " + str(epoch))
1057 | start_time = time.time()
1058 | fineTuneModel.train()
1059 | count = 0
1060 | data_set_all = 0
1061 | for i in range(1, iteration + 1):
1062 | p = float(i + epoch * iteration) / args.epoch_preTraining / iteration
1063 | m = 2. / (1. + np.exp(-10 * p)) - 1  # GRL coefficient; kept for symmetry with pre-training but unused in fine-tuning
1064 | batch_dict = defaultdict(list)
1065 | data_dict = defaultdict(list)
1066 | label_dict = defaultdict(list)
1067 | label_data_dict = defaultdict(set)
1068 |
1069 | for j in range(len(source_iters2)):
1070 | try:
1071 | batch_dict[j] = next(source_iters2[j])
1072 | except StopIteration:  # this subject's loader is exhausted; restart it
1073 | source_iters2[j] = iter(source_loader[j])
1074 | batch_dict[j] = next(source_iters2[j])
1075 | index = 0
1076 | for o in batch_dict[j][1]:
1077 | cur_label = o[0].item()
1078 | data_dict[j].append(batch_dict[j][0][index])
1079 | label_dict[j].append(cur_label)
1080 | index += 1
1081 | for j in range(len(source_iters2)):  # the fine-tuning iterators, not the pre-training ones
1082 | source_data, source_label = batch_dict[j]
1083 | if cuda:
1084 | source_data = source_data.cuda()
1085 | source_label = source_label.cuda()
1086 | data_set_all += len(source_label)
1087 | optimizer_FineTuning.zero_grad()
1088 | x_pred, x_logits, cls_loss = fineTuneModel(source_data, source_label)
1089 | cls_loss.backward()
1090 | optimizer_FineTuning.step()
1091 | _, pred = torch.max(x_pred, dim=1)
1092 | count += pred.eq(source_label.squeeze().data.view_as(pred)).sum()
1093 | end_time = time.time()
1094 | epoch_time = end_time - start_time
1095 | print("The time required for one fine-tuning epoch is:", epoch_time, "second")
1096 | print("data set amount: " + str(data_set_all))
1097 | acc = float(count) / data_set_all
1098 | writer.add_scalars('subject: ' + str(one_subject + 1) + ' ' + 'train DMMR/loss',
1099 | {'cls_loss': cls_loss.data}, epoch + 1)
1100 | writer.add_scalar('subject: ' + str(one_subject + 1) + ' ' + 'train DMMR/train accuracy', acc, epoch + 1)
1101 | print("acc: " + str(acc))
1102 |
1103 | testModel = DMMRTestModel(fineTuneModel)
1104 | acc_DMMR = testDMMR(data_loader_dict["test_loader"], testModel, cuda, args.batch_size)
1105 | print("acc_DMMR: " + str(acc_DMMR))
1106 | writer.add_scalars('subject: ' + str(one_subject + 1) + ' ' + 'test DMMR/test acc',
1107 | {'test acc': acc_DMMR}, epoch + 1)
1108 | if acc_DMMR > acc_final:
1109 | acc_final = acc_DMMR
1110 | best_pretrain_model = copy.deepcopy(preTrainModel.state_dict())
1111 | best_tune_model = copy.deepcopy(fineTuneModel.state_dict())
1112 | best_test_model = copy.deepcopy(testModel.state_dict())
1113 | modelDir = "model/" + args.way + "/" + args.index + "/"
1114 | try:
1115 | os.makedirs(modelDir)
1116 | except:
1117 | pass
1118 | torch.save(best_pretrain_model, modelDir + str(one_subject) + '_pretrain_model.pth')
1119 | torch.save(best_tune_model, modelDir + str(one_subject) + '_tune_model.pth')
1120 | torch.save(best_test_model, modelDir + str(one_subject) + '_test_model.pth')
1121 | return acc_final
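
# --- Editor's note (illustrative sketch) ---
# PreTrainingWithChannelsShuffling perturbs inputs by permuting channels (see
# noiseInjectionMethods/channelsShuffling.py). A one-line sketch, assuming the
# channel dimension comes last:
def _shuffle_channels(x):
    """Apply one random channel permutation to every sample in the batch."""
    import torch
    perm = torch.randperm(x.size(-1))
    return x[..., perm]
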
1122 |
1123 | def trainDMMR_Noise_Dropout(data_loader_dict, optimizer_config, cuda, args, iteration, writer, one_subject):
1124 | source_loader = data_loader_dict['source_loader']
1125 | # The pre-training phase
1126 | preTrainModel = PreTrainingWithDropout(cuda, number_of_source=len(source_loader), number_of_category=args.cls_classes, batch_size=args.batch_size, time_steps=args.time_steps, dropout_rate=0.2)
1127 | if cuda:
1128 | preTrainModel = preTrainModel.cuda()
1129 | source_iters = []
1130 | for i in range(len(source_loader)):
1131 | source_iters.append(iter(source_loader[i]))
1132 | optimizer_PreTraining = torch.optim.Adam(preTrainModel.parameters(), **optimizer_config)
1133 |
1134 | acc_final = 0
1135 | for epoch in range(args.epoch_preTraining):
1136 | print("epoch: "+str(epoch))
1137 | start_time_pretrain = time.time()
1138 | preTrainModel.train()
1139 | count = 0
1140 | data_set_all = 0
1141 | for i in range(1, iteration + 1):
1142 | p = float(i + epoch * iteration) / args.epoch_preTraining / iteration
1143 | m = 2. / (1. + np.exp(-10 * p)) - 1
1144 | batch_dict = defaultdict(list)
1145 | data_dict = defaultdict(list)
1146 | label_dict = defaultdict(list)
1147 | label_data_dict = defaultdict(set)
1148 |
1149 | for j in range(len(source_iters)):
1150 | try:
1151 | batch_dict[j] = next(source_iters[j])
1152 | except StopIteration:  # this subject's loader is exhausted; restart it
1153 | source_iters[j] = iter(source_loader[j])
1154 | batch_dict[j] = next(source_iters[j])
1155 | index = 0
1156 | for o in batch_dict[j][1]:
1157 | cur_label = o[0].item()
1158 | data_dict[j].append(batch_dict[j][0][index])
1159 | label_dict[j].append(cur_label)
1160 | index += 1
1161 |
1162 | for j in range(len(source_iters)):
1163 | subject_id = torch.ones(args.batch_size)
1164 | subject_id = subject_id * j
1165 | subject_id = subject_id.long()
1166 |
1167 | source_data, source_label = batch_dict[j]
1168 | # prepare a corresponding batch that shares the labels of the current batch, drawn from every source subject
1169 | label_data_dict_list = []
1170 | for one_index in range(args.source_subjects):
1171 | cur_data_list = data_dict[one_index]
1172 | cur_label_list = label_dict[one_index]
1173 | for one in range(args.batch_size):
1174 | label_data_dict[cur_label_list[one]].add(cur_data_list[one])
1175 | label_data_dict_list.append(label_data_dict)
1176 | label_data_dict = defaultdict(set)
1177 | corres_batch_data = []
1178 | for k in range(len(label_data_dict_list)):  # k, not i: the outer loop already uses i
1179 | for one in source_label:
1180 | label_cur = one[0].item()
1181 | corres_batch_data.append(random.choice(list(label_data_dict_list[k][label_cur])))
1182 | corres_batch_data = torch.stack(corres_batch_data)
1183 | if cuda:
1184 | source_data = source_data.cuda()
1185 | source_label = source_label.cuda()
1186 | subject_id = subject_id.cuda()
1187 | corres_batch_data = corres_batch_data.cuda()
1188 | data_set_all += len(source_label)
1189 | optimizer_PreTraining.zero_grad()
1190 | rec_loss, sim_loss = preTrainModel(source_data, corres_batch_data, subject_id, args, m, mark=j)
1191 | loss_pretrain = rec_loss + args.beta*sim_loss
1192 | loss_pretrain.backward()
1193 | optimizer_PreTraining.step()
1194 | print("data set amount: "+str(data_set_all))
1195 | writer.add_scalars('subject: '+str(one_subject+1)+' '+'train DMMR/loss',
1196 | {'loss_pretrain':loss_pretrain.data,'rec_loss':rec_loss.data,'sim_loss':sim_loss.data}, epoch + 1)
1197 | end_time_pretrain = time.time()
1198 | pretrain_epoch_time = end_time_pretrain - start_time_pretrain
1199 | print("The time required for one pre-training epoch is:", pretrain_epoch_time, "second")
1200 | print("rec_loss: "+str(rec_loss))
1201 |
1202 | # The fine-tuning phase
1203 | source_iters2 = []
1204 | for i in range(len(source_loader)):
1205 | source_iters2.append(iter(source_loader[i]))
1206 | fineTuneModel = DMMRFineTuningModel(cuda, preTrainModel, number_of_source=len(source_loader),
1207 | number_of_category=args.cls_classes, batch_size=args.batch_size,
1208 | time_steps=args.time_steps)
1209 |
1210 | optimizer_FineTuning = torch.optim.Adam(fineTuneModel.parameters(), **optimizer_config)
1211 | if cuda:
1212 | fineTuneModel = fineTuneModel.cuda()
1213 | for epoch in range(args.epoch_fineTuning):
1214 | print("epoch: " + str(epoch))
1215 | start_time = time.time()
1216 | fineTuneModel.train()
1217 | count = 0
1218 | data_set_all = 0
1219 | for i in range(1, iteration + 1):
1220 | p = float(i + epoch * iteration) / args.epoch_preTraining / iteration
1221 | m = 2. / (1. + np.exp(-10 * p)) - 1  # GRL coefficient; kept for symmetry with pre-training but unused in fine-tuning
1222 | batch_dict = defaultdict(list)
1223 | data_dict = defaultdict(list)
1224 | label_dict = defaultdict(list)
1225 | label_data_dict = defaultdict(set)
1226 |
1227 | for j in range(len(source_iters2)):
1228 | try:
1229 | batch_dict[j] = next(source_iters2[j])
1230 | except StopIteration:  # this subject's loader is exhausted; restart it
1231 | source_iters2[j] = iter(source_loader[j])
1232 | batch_dict[j] = next(source_iters2[j])
1233 | index = 0
1234 | for o in batch_dict[j][1]:
1235 | cur_label = o[0].item()
1236 | data_dict[j].append(batch_dict[j][0][index])
1237 | label_dict[j].append(cur_label)
1238 | index += 1
1239 | for j in range(len(source_iters2)):  # the fine-tuning iterators, not the pre-training ones
1240 | source_data, source_label = batch_dict[j]
1241 | if cuda:
1242 | source_data = source_data.cuda()
1243 | source_label = source_label.cuda()
1244 | data_set_all += len(source_label)
1245 | optimizer_FineTuning.zero_grad()
1246 | x_pred, x_logits, cls_loss = fineTuneModel(source_data, source_label)
1247 | cls_loss.backward()
1248 | optimizer_FineTuning.step()
1249 | _, pred = torch.max(x_pred, dim=1)
1250 | count += pred.eq(source_label.squeeze().data.view_as(pred)).sum()
1251 | end_time = time.time()
1252 | epoch_time = end_time - start_time
1253 | print("The time required for one fine-tuning epoch is:", epoch_time, "second")
1254 | print("data set amount: " + str(data_set_all))
1255 | acc = float(count) / data_set_all
1256 | writer.add_scalars('subject: ' + str(one_subject + 1) + ' ' + 'train DMMR/loss',
1257 | {'cls_loss': cls_loss.data}, epoch + 1)
1258 | writer.add_scalar('subject: ' + str(one_subject + 1) + ' ' + 'train DMMR/train accuracy', acc, epoch + 1)
1259 | print("acc: " + str(acc))
1260 |
1261 | testModel = DMMRTestModel(fineTuneModel)
1262 | acc_DMMR = testDMMR(data_loader_dict["test_loader"], testModel, cuda, args.batch_size)
1263 | print("acc_DMMR: " + str(acc_DMMR))
1264 | writer.add_scalars('subject: ' + str(one_subject + 1) + ' ' + 'test DMMR/test acc',
1265 | {'test acc': acc_DMMR}, epoch + 1)
1266 | if acc_DMMR > acc_final:
1267 | acc_final = acc_DMMR
1268 | best_pretrain_model = copy.deepcopy(preTrainModel.state_dict())
1269 | best_tune_model = copy.deepcopy(fineTuneModel.state_dict())
1270 | best_test_model = copy.deepcopy(testModel.state_dict())
1271 | modelDir = "model/" + args.way + "/" + args.index + "/"
1272 | try:
1273 | os.makedirs(modelDir)
1274 | except:
1275 | pass
1276 | torch.save(best_pretrain_model, modelDir + str(one_subject) + '_pretrain_model.pth')
1277 | torch.save(best_tune_model, modelDir + str(one_subject) + '_tune_model.pth')
1278 | torch.save(best_test_model, modelDir + str(one_subject) + '_test_model.pth')
1279 | return acc_final
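
# --- Editor's note (illustrative sketch; the helper name is ours) ---
# All six training functions above cycle each subject's DataLoader forever by
# catching StopIteration and rebuilding the iterator. The repeated try/except
# block could be factored into a single helper:
def _next_batch_cycling(iters, loaders, j):
    """Return the next batch from loader j, restarting its iterator when exhausted."""
    try:
        return next(iters[j])
    except StopIteration:
        iters[j] = iter(loaders[j])
        return next(iters[j])
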
1280 |
1281 | ############## T-SNE plots ##############
1282 | class FeatureVisualize(object):
1283 | '''
1284 | Visualize features by TSNE
1285 | '''
1286 |
1287 | def __init__(self, features, labels):
1288 | '''
1289 | features: (m,n)
1290 | labels: (m,)
1291 | '''
1292 | self.features = features
1293 | self.labels = labels
1294 |
1295 | def plot_tsne(self, save_filename, save_eps=False):
1296 | ''' Plot TSNE figure. Set save_eps=True if you want to save a .eps file.
1297 | '''
1298 | tsne = TSNE(n_components=2, init='pca', random_state=0)
1299 | features = tsne.fit_transform(self.features)
1300 | x_min, x_max = np.min(features, 0), np.max(features, 0)
1301 | data = (features - x_min) / (x_max - x_min)
1302 | del features
1303 | colors = plt.cm.tab20.colors  # fixed 20-color palette, hoisted out of the loop
1304 | for i in range(data.shape[0]):
1305 | plt.scatter(data[i, 0], data[i, 1], color=colors[self.labels[i]])
1306 | # plt.colorbar() removed: scatter with an explicit color registers no mappable, so it would raise here
1307 | plt.xticks([])
1308 | plt.yticks([])
1309 | plt.title('T-SNE')
1310 | if save_eps:
1311 | plt.savefig('tsne.eps', dpi=600, format='eps')
1312 | plt.savefig(save_filename, dpi=600)
1313 | plt.show()
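
# --- Editor's note (illustrative usage, kept as comments) ---
# FeatureVisualize expects an (m, n) feature matrix and integer labels below 20
# (tab20 supplies 20 colors). A minimal usage sketch with random data:
#
#     import numpy as np
#     feats = np.random.randn(100, 310)
#     labels = np.random.randint(0, 3, size=100)
#     FeatureVisualize(feats, labels).plot_tsne("demo_tsne.jpg")
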
1314 | def TSNEForDMMR(data_loader_dict, cuda, args):
1315 | source_loader = data_loader_dict['source_loader']
1316 | # The pre-training phase
1317 | target_loader = data_loader_dict["test_loader"]
1318 | preTrainModel = DMMRPreTrainingModel(cuda,
1319 | number_of_source=len(source_loader),
1320 | number_of_category=args.cls_classes,
1321 | batch_size=args.batch_size,
1322 | time_steps=args.time_steps)
1323 | # load the pre-trained model
1324 | preTrainModel.load_state_dict(torch.load("T-SNE/model/1_pretrain_model.pth", map_location='cpu'))
1325 | preTrainModel.eval()
1326 | pretrainReturnFeature = ModelReturnFeatures(preTrainModel, time_steps=args.time_steps)
1327 | fineTuneModel = DMMRFineTuningModel(cuda, preTrainModel, number_of_source=len(source_loader),
1328 | number_of_category=args.cls_classes,
1329 | batch_size=args.batch_size,
1330 | time_steps=args.time_steps)
1331 | # load the fine-tuned model
1332 | fineTuneModel.load_state_dict(torch.load("T-SNE/model/1_tune_model.pth", map_location='cpu'))
1333 | fineTuneModel.eval()
1334 | fineTuneModelReturnFeatures = ModelReturnFeatures(fineTuneModel, time_steps=args.time_steps)
1335 | fineTuneModelReturnFeatures.eval()
1336 |
1337 | source_iters = []
1338 | for i in range(len(source_loader)):
1339 | source_iters.append(iter(source_loader[i]))
1340 |
1341 | origin_features_list = []
1342 | origin_subject_id_list = []
1343 | label_list = []
1344 | pretrain_shared_features_list = []
1345 | shared_features_list = []
1346 | for i in range(1, 2):  # a single pass over the loaders is enough for the plots
1347 | for j in range(len(source_iters)):
1348 | try:
1349 | source_data, source_label = next(source_iters[j])
1350 | except StopIteration:  # this subject's loader is exhausted; restart it
1351 | source_iters[j] = iter(source_loader[j])
1352 | source_data, source_label = next(source_iters[j])
1353 | subject_id = torch.ones(args.batch_size)
1354 | subject_id = subject_id * j
1355 | subject_id = subject_id.long()
1356 |
1357 | _, pretrain_shared_feature = pretrainReturnFeature(source_data)
1358 | _, shared_feature = fineTuneModelReturnFeatures(source_data)
1359 |
1360 | num_samples = 50
1361 | # 50 samples are drawn from each subject's data
1362 | source_data_narray = source_data.numpy()
1363 | label_data_narray = source_label.squeeze().numpy()
1364 | # Reshape for sampling
1365 | source_data_narray = source_data_narray.reshape(source_data_narray.shape[0], -1)  # flatten each (30, 310) sample into one vector
1366 | # Randomly select 50 samples from it, giving an array of shape (50, 30 * 310).
1367 | random_indices = np.random.choice(source_data_narray.shape[0], num_samples, replace=False)
1368 | source_data_narray_50 = source_data_narray[random_indices]
1369 | subject_narray = np.full((num_samples,), j)
1370 | label_data_narray_50 = label_data_narray[random_indices]
1371 | # original features
1372 | origin_features_list.append(source_data_narray_50)
1373 | origin_subject_id_list.append(subject_narray)
1374 | label_list.append(label_data_narray_50)
1375 |
1376 | # pretrained feature
1377 | pretrain_shared_feature_narray = pretrain_shared_feature.detach().numpy()
1378 | pretrain_shared_feature_narray_50 = pretrain_shared_feature_narray[random_indices]
1379 | pretrain_shared_features_list.append(pretrain_shared_feature_narray_50)
1380 | # fine-tuned features
1381 | shared_feature_narray = shared_feature.detach().numpy()
1382 | shared_feature_narray_50 = shared_feature_narray[random_indices]
1383 | shared_features_list.append(shared_feature_narray_50)
1384 |
1385 | # generate target-subject data
1386 | target_data, target_label = next(iter(target_loader))
1387 | _, target_pretrain_shared_feature = pretrainReturnFeature(target_data)
1388 | _, target_shared_feature = fineTuneModelReturnFeatures(target_data)
1389 | target_data_narray = target_data.numpy()
1390 | target_label = target_label.squeeze().numpy()
1391 | target_data_narray = target_data_narray.reshape(target_data_narray.shape[0], -1)  # flatten each (30, 310) sample into one vector
1392 | random_indices_target = np.random.choice(target_data_narray.shape[0], num_samples, replace=False)
1393 | target_data_narray_50 = target_data_narray[random_indices_target]
1394 | target_subject_id = np.full((num_samples,), 14)
1395 | target_label_narray_50 = target_label[random_indices_target]  # use the target indices, not the last source batch's
1396 |
1397 |
1398 | # add target-subject data
1399 | origin_features_list.append(target_data_narray_50)
1400 | origin_subject_id_list.append(target_subject_id)
1401 | label_list.append(target_label_narray_50)
1402 |
1403 | target_pretrain_shared_feature_narray = target_pretrain_shared_feature.detach().numpy()
1404 | target_pretrain_shared_feature_narray_50 = target_pretrain_shared_feature_narray[random_indices_target]
1405 | pretrain_shared_features_list.append(target_pretrain_shared_feature_narray_50)
1406 |
1407 | target_shared_feature_narray = target_shared_feature.detach().numpy()
1408 | target_shared_feature_narray_50 = target_shared_feature_narray[random_indices_target]
1409 | shared_features_list.append(target_shared_feature_narray_50)
1410 |
1411 |
1412 | # concatenate features for joint normalization inside T-SNE
1413 | origin_stacked_feature = np.concatenate(origin_features_list, axis=0)
1414 | stacked_subject_id = np.concatenate(origin_subject_id_list, axis=0)
1415 | stacked_label = np.concatenate(label_list, axis=0)
1416 |
1417 | # T-SNE
1418 | # original input data
1419 | vis_pretrain_shared = FeatureVisualize(origin_stacked_feature, stacked_subject_id)
1420 | vis_pretrain_shared.plot_tsne('T-SNE/plot/origin_subject.jpg',save_eps=False)
1421 | vis_pretrain_shared = FeatureVisualize(origin_stacked_feature, stacked_label)
1422 | vis_pretrain_shared.plot_tsne("T-SNE/plot/origin_label.jpg",save_eps=False)
1423 |
1424 | # pretrained feature
1425 | pretrain_shared_stacked_feature = np.concatenate(pretrain_shared_features_list, axis=0)
1426 | vis_pretrain_shared = FeatureVisualize(pretrain_shared_stacked_feature, stacked_subject_id)
1427 | vis_pretrain_shared.plot_tsne('T-SNE/plot/pretrain_subject.jpg',save_eps=False)
1428 | vis_pretrain_shared = FeatureVisualize(pretrain_shared_stacked_feature, stacked_label)
1429 | vis_pretrain_shared.plot_tsne("T-SNE/plot/pretrain_label.jpg",save_eps=False)
1430 | # fine-tuned features
1431 | shared_stacked_feature = np.concatenate(shared_features_list, axis=0)
1432 | vis_shared = FeatureVisualize(shared_stacked_feature, stacked_subject_id)
1433 | vis_shared.plot_tsne("T-SNE/plot/tune_subject.jpg",save_eps=False)
1434 | vis_shared_label = FeatureVisualize(shared_stacked_feature, stacked_label)
1435 | vis_shared_label.plot_tsne("T-SNE/plot/tune_label.jpg",save_eps=False)
1436 | return 0
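
# --- Editor's note (illustrative usage, kept as comments) ---
# TSNEForDMMR is meant to be driven with the same data_loader_dict used for
# training, plus an args object providing cls_classes, batch_size and
# time_steps; the checkpoint paths under T-SNE/model/ are hard-coded above.
# A hedged sketch of the call:
#
#     TSNEForDMMR(data_loader_dict, cuda=False, args=args)
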
--------------------------------------------------------------------------------