├── .gitignore
├── LICENSE
├── README.md
├── algorithms
│   └── algorithms.py
├── configs
│   ├── data_model_configs.py
│   ├── hparams.py
│   └── sweep_params.py
├── dataloader
│   └── dataloader.py
├── main.py
├── main_sweep.py
├── misc
│   ├── adatime.PNG
│   └── results.PNG
├── models
│   ├── loss.py
│   ├── models.py
│   └── resnet18.py
├── trainers
│   ├── abstract_trainer.py
│   ├── sweep.py
│   └── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | wandb/
3 | experiments_logs/
4 | data/
5 | *.pyc
6 | main.py
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Emadeldeen Eldele
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [TKDD 2023] AdaTime: A Benchmarking Suite for Domain Adaptation on Time Series Data [[Paper](https://arxiv.org/abs/2203.08321)] [[Cite](#citation)]
2 | #### *by: Mohamed Ragab\*, Emadeldeen Eldele\*, Wee Ling Tan, Chuan-Sheng Foo, Zhenghua Chen☨, Min Wu, Chee-Keong Kwoh, Xiaoli Li* (\* Equal contribution, ☨ Corresponding author)
3 |
4 | ## Published in the [ACM Transactions on Knowledge Discovery from Data (TKDD)](https://dl.acm.org/doi/10.1145/3587937).
5 | **AdaTime** is a PyTorch suite to systematically and fairly evaluate different domain adaptation methods on time series data.
6 |
7 | 
8 | <img src="misc/adatime.PNG" alt="AdaTime framework" width="900">
9 | 
10 | 
11 | ## Requirements:
12 | - Python3
13 | - PyTorch==1.7
14 | - Numpy==1.20.1
15 | - scikit-learn==0.24.1
16 | - Pandas==1.2.4
17 | - skorch==0.10.0 (for DEV risk calculations)
18 | - openpyxl==3.0.7 (for classification reports)
19 | - Wandb==0.12.7 (for sweeps)
20 |
21 | ## Datasets
22 |
23 | ### Available Datasets
24 | We used five public datasets in this study. We also provide the **preprocessed** versions as follows:
25 | - [Sleep-EDF](https://researchdata.ntu.edu.sg/dataset.xhtml?persistentId=doi:10.21979/N9/UD1IM9)
26 | - [UCIHAR](https://researchdata.ntu.edu.sg/dataset.xhtml?persistentId=doi:10.21979/N9/0SYHTZ)
27 | - [HHAR](https://researchdata.ntu.edu.sg/dataset.xhtml?persistentId=doi:10.21979/N9/OWDFXO)
28 | - [WISDM](https://researchdata.ntu.edu.sg/dataset.xhtml?persistentId=doi:10.21979/N9/KJWE5B)
29 | - [FD](https://researchdata.ntu.edu.sg/dataset.xhtml?persistentId=doi:10.21979/N9/PU85XN)
30 |
31 | ### Adding New Dataset
32 |
33 | #### Structure of data
34 | To add a new dataset (*e.g.,* NewData), place it in a folder named NewData inside the datasets directory.
35 | 
36 | Since "NewData" has several domains, each domain should be split into train/test parts named
37 | "train_*x*.pt" and "test_*x*.pt", where *x* identifies the domain.
38 | 
39 | The data files should be dictionaries of the following form:
40 | `train.pt = {"samples": data, "labels": labels}`, and similarly for `test.pt`.
41 |
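For illustration, here is a minimal sketch of how such files could be produced (the tensor shapes, class count, and the domain id `0` are placeholders, not values required by AdaTime):

```
import torch

# hypothetical domain data: (num_samples, num_channels, sequence_length) and (num_samples,)
samples = torch.randn(100, 9, 128)
labels = torch.randint(0, 6, (100,))

# one domain of "NewData", saved in the expected dictionary form
torch.save({"samples": samples, "labels": labels}, "datasets/NewData/train_0.pt")
```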
42 | #### Configurations
43 | Next, add a class named NewData to the `configs/data_model_configs.py` file.
44 | You can use the classes of the existing datasets as guidelines.
45 | You also have to specify the cross-domain scenarios in the `self.scenarios` variable.
46 | 
47 | Last, add another class named NewData to the `configs/hparams.py` file to specify
48 | the training parameters.
49 |
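A hypothetical skeleton of such a dataset class (the `self.scenarios` format follows the existing classes; the remaining attribute names and values are placeholders to adapt to your data):

```
class NewData():
    def __init__(self):
        super(NewData, self).__init__()
        # cross-domain scenarios as (source domain, target domain) pairs
        self.scenarios = [("0", "1"), ("1", "2")]
        # model-related configurations (names modeled on the existing classes)
        self.sequence_len = 128
        self.input_channels = 9
        self.num_classes = 6
        self.final_out_channels = 128
        self.features_len = 1
```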
50 |
51 | ## Domain Adaptation Algorithms
52 | ### Existing Algorithms
53 | - [Deep Coral](https://arxiv.org/abs/1607.01719)
54 | - [MMDA](https://arxiv.org/abs/1901.00282)
55 | - [DANN](https://arxiv.org/abs/1505.07818)
56 | - [CDAN](https://arxiv.org/abs/1705.10667)
57 | - [DIRT-T](https://arxiv.org/abs/1802.08735)
58 | - [DSAN](https://ieeexplore.ieee.org/document/9085896)
59 | - [HoMM](https://arxiv.org/pdf/1912.11976.pdf)
60 | - [DDC](https://arxiv.org/abs/1412.3474)
61 | - [CoDATS](https://arxiv.org/pdf/2005.10996.pdf)
62 | - [AdvSKM](https://www.ijcai.org/proceedings/2021/0378.pdf)
63 | - [SASA](https://ojs.aaai.org/index.php/AAAI/article/view/16846/16653)
64 | - [CoTMix](https://arxiv.org/pdf/2212.01555.pdf)
65 |
66 |
67 | ### Adding New Algorithm
68 | To add a new algorithm, place it in the `algorithms/algorithms.py` file, following the skeleton sketched below.
69 |
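A minimal skeleton consistent with the existing classes (the name `NewDA` is a placeholder; `update()` is inherited from `Algorithm`, so only `training_epoch()` has to be implemented):

```
class NewDA(Algorithm):
    def __init__(self, backbone, configs, hparams, device):
        super().__init__(configs, backbone)
        self.optimizer = torch.optim.Adam(
            self.network.parameters(),
            lr=hparams["learning_rate"],
            weight_decay=hparams["weight_decay"]
        )
        self.hparams = hparams
        self.device = device

    def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
        # compute and back-propagate the adaptation losses here,
        # then report them through avg_meter[...].update(...)
        raise NotImplementedError
```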
70 |
71 | ## Training procedure
72 |
73 | The experiments are organised in a hierarchical way such that:
74 | - Several experiments are collected under one directory assigned by `--experiment_description`.
75 | - Each experiment could have different trials, each specified by `--run_description`.
76 | - For example, to experiment with different UDA methods with a CNN backbone, we can assign
77 | `--experiment_description CNN_backbones --run_description DANN`, then `--experiment_description CNN_backbones --run_description DDC`, and so on.
78 |
79 | ### Training a model
80 |
81 | To train a model:
82 |
83 | ```
84 | python main.py --phase train \
85 | --experiment_description exp1 \
86 | --da_method DANN \
87 | --dataset HHAR \
88 | --backbone CNN \
89 | --num_runs 5
90 | ```
91 | To test a model:
92 |
93 | ```
94 | python main.py --phase test \
95 | --experiment_description exp1 \
96 | --da_method DANN \
97 | --dataset HHAR \
98 | --backbone CNN \
99 | --num_runs 5
100 | ```
101 | ### Launching a sweep
102 | Sweeps here are deployed on [Wandb](https://wandb.ai/), which makes it easier for visualization, following the training progress, organizing sweeps, and collecting results.
103 |
104 | ```
105 | python main_sweep.py --experiment_description exp1_sweep \
106 | --run_description sweep_over_lr \
107 | --da_method DANN \
108 | --dataset HHAR \
109 | --backbone CNN \
110 | --num_runs 5 \
111 | --sweep_project_wandb TEST \
112 | --num_sweeps 50
113 | ```
114 | Once the sweep is running, you can follow its progress on the specified project page in wandb.
115 |
116 | `Note:` If you get a CUDA out-of-memory error during testing, it is probably due to the DEV risk calculations.
117 |
118 |
119 | ### Upper and Lower bounds
120 | - To obtain the source-only results (the lower bound), choose the da_method `NO_ADAPT`.
121 | - To obtain the target-only results (the upper bound), choose the da_method `TARGET_ONLY` (see the example below).
122 |
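For example, reusing the training command above:

```
python main.py --phase train --experiment_description exp1 --da_method NO_ADAPT --dataset HHAR --backbone CNN
```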
123 | ## Results
124 | - Each run will have all the cross-domain scenario results in the format `src_to_trg_run_x`, where `x`
125 | is the run id (you can have multiple runs by setting the `--num_runs` arg).
126 | - Under each directory, you will find the classification report, a log file, a checkpoint,
127 | and the different risk scores.
128 | - By the end of all the runs, you will find the overall average and std results in the run directory.
129 |
130 |
131 | 
132 | <img src="misc/results.PNG" alt="Results" width="900">
133 | 
134 | 
135 |
136 | ## Citation
137 | If you find this work useful, please consider citing it.
138 | ```
139 | @article{adatime,
140 | author = {Ragab, Mohamed and Eldele, Emadeldeen and Tan, Wee Ling and Foo, Chuan-Sheng and Chen, Zhenghua and Wu, Min and Kwoh, Chee-Keong and Li, Xiaoli},
141 | title = {ADATIME: A Benchmarking Suite for Domain Adaptation on Time Series Data},
142 | year = {2023},
143 | publisher = {Association for Computing Machinery},
144 | address = {New York, NY, USA},
145 | issn = {1556-4681},
146 | url = {https://doi.org/10.1145/3587937},
147 | doi = {10.1145/3587937},
148 | journal = {ACM Trans. Knowl. Discov. Data},
149 | month = {mar}
150 | }
151 | ```
152 |
153 |
154 | ## Contact
155 | For any issues/questions regarding the paper or reproducing the results, please contact any of the following.
156 |
157 | Mohamed Ragab: *mohamedr002{at}e.ntu.edu.sg*
158 |
159 | Emadeldeen Eldele: *emad0002{at}e.ntu.edu.sg*
160 |
161 | School of Computer Science and Engineering (SCSE),
162 | Nanyang Technological University (NTU), Singapore.
163 |
--------------------------------------------------------------------------------
/algorithms/algorithms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import itertools
5 |
6 | from models.models import classifier, ReverseLayerF, Discriminator, RandomLayer, Discriminator_CDAN, \
7 | codats_classifier, AdvSKM_Disc, CNN_ATTN
8 | from models.loss import MMD_loss, CORAL, ConditionalEntropyLoss, VAT, LMMD_loss, HoMM_loss, NTXentLoss, SupConLoss
9 | from utils import EMA
10 | from torch.optim.lr_scheduler import StepLR
11 | from copy import deepcopy
12 | import torch.nn.functional as F
13 |
14 | def get_algorithm_class(algorithm_name):
15 | """Return the algorithm class with the given name."""
16 | if algorithm_name not in globals():
17 | raise NotImplementedError("Algorithm not found: {}".format(algorithm_name))
18 | return globals()[algorithm_name]
19 |
20 |
21 | class Algorithm(torch.nn.Module):
22 | """
23 | A subclass of Algorithm implements a domain adaptation algorithm.
24 | Subclasses should implement the training_epoch() method.
25 | """
26 |
27 | def __init__(self, configs, backbone):
28 | super(Algorithm, self).__init__()
29 | self.configs = configs
30 |
31 | self.cross_entropy = nn.CrossEntropyLoss()
32 | self.feature_extractor = backbone(configs)
33 | self.classifier = classifier(configs)
34 | self.network = nn.Sequential(self.feature_extractor, self.classifier)
35 |
36 |
37 | # update function is common to all algorithms
38 | def update(self, src_loader, trg_loader, avg_meter, logger):
39 | # defining best and last model
40 | best_src_risk = float('inf')
41 | best_model = None
42 |
43 | for epoch in range(1, self.hparams["num_epochs"] + 1):
44 |
45 | # training loop
46 | self.training_epoch(src_loader, trg_loader, avg_meter, epoch)
47 |
48 | # saving the best model based on src risk
49 | if (epoch + 1) % 10 == 0 and avg_meter['Src_cls_loss'].avg < best_src_risk:
50 | best_src_risk = avg_meter['Src_cls_loss'].avg
51 | best_model = deepcopy(self.network.state_dict())
52 |
53 |
54 | logger.debug(f'[Epoch : {epoch}/{self.hparams["num_epochs"]}]')
55 | for key, val in avg_meter.items():
56 | logger.debug(f'{key}\t: {val.avg:2.4f}')
57 | logger.debug(f'-------------------------------------')
58 |
59 | last_model = self.network.state_dict()
60 |
61 | return last_model, best_model
62 |
63 | # the training loop varies from one method to another
64 | def training_epoch(self, *args, **kwargs):
65 | raise NotImplementedError
66 |
67 |
68 | class NO_ADAPT(Algorithm):
69 | """
70 | Lower bound: train on source and test on target.
71 | """
72 | def __init__(self, backbone, configs, hparams, device):
73 | super().__init__(configs, backbone)
74 |
75 | # optimizer and scheduler
76 | self.optimizer = torch.optim.Adam(
77 | self.network.parameters(),
78 | lr=hparams["learning_rate"],
79 | weight_decay=hparams["weight_decay"]
80 | )
81 | self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay'])
82 | # hparams
83 | self.hparams = hparams
84 | # device
85 | self.device = device
86 |
87 | def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
88 | for src_x, src_y in src_loader:
89 |
90 | src_x, src_y = src_x.to(self.device), src_y.to(self.device)
91 | src_feat = self.feature_extractor(src_x)
92 | src_pred = self.classifier(src_feat)
93 |
94 | src_cls_loss = self.cross_entropy(src_pred, src_y)
95 |
96 | loss = src_cls_loss
97 |
98 | self.optimizer.zero_grad()
99 | loss.backward()
100 | self.optimizer.step()
101 |
102 | losses = {'Src_cls_loss': src_cls_loss.item()}
103 |
104 | for key, val in losses.items():
105 | avg_meter[key].update(val, 32)
106 |
107 | self.lr_scheduler.step()
108 |
109 |
110 | class TARGET_ONLY(Algorithm):
111 | """
112 | Upper bound: train on target and test on target.
113 | """
114 |
115 | def __init__(self, backbone, configs, hparams, device):
116 | super().__init__(configs, backbone)
117 |
118 | # optimizer and scheduler
119 | self.optimizer = torch.optim.Adam(
120 | self.network.parameters(),
121 | lr=hparams["learning_rate"],
122 | weight_decay=hparams["weight_decay"]
123 | )
124 | self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay'])
125 | # hparams
126 | self.hparams = hparams
127 | # device
128 | self.device = device
129 |
130 | def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
131 |
132 | for trg_x, trg_y in trg_loader:
133 |
134 | trg_x, trg_y = trg_x.to(self.device), trg_y.to(self.device)
135 |
136 | trg_feat = self.feature_extractor(trg_x)
137 | trg_pred = self.classifier(trg_feat)
138 |
139 | trg_cls_loss = self.cross_entropy(trg_pred, trg_y)
140 |
141 | loss = trg_cls_loss
142 |
143 | self.optimizer.zero_grad()
144 | loss.backward()
145 | self.optimizer.step()
146 |
147 | losses = {'Trg_cls_loss': trg_cls_loss.item()}
148 |
149 | for key, val in losses.items():
150 | avg_meter[key].update(val, 32)
151 |
152 | self.lr_scheduler.step()
153 |
154 |
155 | class Deep_Coral(Algorithm):
156 | """
157 | Deep Coral: https://arxiv.org/abs/1607.01719
158 | """
159 | def __init__(self, backbone, configs, hparams, device):
160 | super().__init__(configs, backbone)
161 |
162 | # optimizer and scheduler
163 | self.optimizer = torch.optim.Adam(
164 | self.network.parameters(),
165 | lr=hparams["learning_rate"],
166 | weight_decay=hparams["weight_decay"]
167 | )
168 | self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay'])
169 | # hparams
170 | self.hparams = hparams
171 | # device
172 | self.device = device
173 |
174 | # correlation alignment loss
175 | self.coral = CORAL()
176 |
177 |
178 | def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
179 |
180 | # Construct Joint Loaders
181 | # (cycle the shorter loader so one epoch covers the longer domain)
182 | 
183 | if len(src_loader) > len(trg_loader):
184 | joint_loader = enumerate(zip(src_loader, itertools.cycle(trg_loader)))
185 | else:
186 | joint_loader = enumerate(zip(itertools.cycle(src_loader), trg_loader))
187 |
188 | for step, ((src_x, src_y), (trg_x, _)) in joint_loader:
189 | src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(self.device)
190 |
191 | src_feat = self.feature_extractor(src_x)
192 | src_pred = self.classifier(src_feat)
193 |
194 | src_cls_loss = self.cross_entropy(src_pred, src_y)
195 |
196 | trg_feat = self.feature_extractor(trg_x)
197 |
198 | coral_loss = self.coral(src_feat, trg_feat)
199 |
200 | loss = self.hparams["coral_wt"] * coral_loss + \
201 | self.hparams["src_cls_loss_wt"] * src_cls_loss
202 |
203 | self.optimizer.zero_grad()
204 | loss.backward()
205 | self.optimizer.step()
206 |
207 | losses = {'Total_loss': loss.item(), 'Src_cls_loss': src_cls_loss.item(),
208 | 'coral_loss': coral_loss.item()}
209 |
210 | for key, val in losses.items():
211 | avg_meter[key].update(val, 32)
212 |
213 | self.lr_scheduler.step()
214 |
215 | class MMDA(Algorithm):
216 | """
217 | MMDA: https://arxiv.org/abs/1901.00282
218 | """
219 |
220 | def __init__(self, backbone, configs, hparams, device):
221 | super().__init__(configs, backbone)
222 |
223 | # optimizer and scheduler
224 | self.optimizer = torch.optim.Adam(
225 | self.network.parameters(),
226 | lr=hparams["learning_rate"],
227 | weight_decay=hparams["weight_decay"]
228 | )
229 | self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay'])
230 | # hparams
231 | self.hparams = hparams
232 | # device
233 | self.device = device
234 |
235 | # Aligment losses
236 | self.mmd = MMD_loss()
237 | self.coral = CORAL()
238 | self.cond_ent = ConditionalEntropyLoss()
239 |
240 |
241 | def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
242 | 
243 | # Construct Joint Loaders
244 | joint_loader = enumerate(zip(src_loader, itertools.cycle(trg_loader)))
245 |
246 | for step, ((src_x, src_y), (trg_x, _)) in joint_loader:
247 | src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(self.device)
248 |
249 | src_feat = self.feature_extractor(src_x)
250 | src_pred = self.classifier(src_feat)
251 | 
252 | src_cls_loss = self.cross_entropy(src_pred, src_y)
253 | 
254 | trg_feat = self.feature_extractor(trg_x)
261 |
262 | coral_loss = self.coral(src_feat, trg_feat)
263 | mmd_loss = self.mmd(src_feat, trg_feat)
264 | cond_ent_loss = self.cond_ent(trg_feat)
265 |
266 | loss = self.hparams["coral_wt"] * coral_loss + \
267 | self.hparams["mmd_wt"] * mmd_loss + \
268 | self.hparams["cond_ent_wt"] * cond_ent_loss + \
269 | self.hparams["src_cls_loss_wt"] * src_cls_loss
270 |
271 | self.optimizer.zero_grad()
272 | loss.backward()
273 | self.optimizer.step()
274 |
275 | losses = {'Total_loss': loss.item(), 'Coral_loss': coral_loss.item(), 'MMD_loss': mmd_loss.item(),
276 | 'cond_ent_loss': cond_ent_loss.item(), 'Src_cls_loss': src_cls_loss.item()}
277 |
278 | for key, val in losses.items():
279 | avg_meter[key].update(val, 32)
280 |
281 | self.lr_scheduler.step()
282 |
283 |
284 | class DANN(Algorithm):
285 | """
286 | DANN: https://arxiv.org/abs/1505.07818
287 | """
288 |
289 | def __init__(self, backbone, configs, hparams, device):
290 | super().__init__(configs, backbone)
291 |
292 |
293 | # optimizer and scheduler
294 | self.optimizer = torch.optim.Adam(
295 | self.network.parameters(),
296 | lr=hparams["learning_rate"],
297 | weight_decay=hparams["weight_decay"]
298 | )
299 | self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay'])
300 | # hparams
301 | self.hparams = hparams
302 | # device
303 | self.device = device
304 |
305 | # Domain Discriminator
306 | self.domain_classifier = Discriminator(configs)
307 | self.optimizer_disc = torch.optim.Adam(
308 | self.domain_classifier.parameters(),
309 | lr=hparams["learning_rate"],
310 | weight_decay=hparams["weight_decay"], betas=(0.5, 0.99)
311 | )
312 |
313 | def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
314 | # Combine dataloaders
315 | # Method 1 (min len of both domains):
316 | # joint_loader = enumerate(zip(src_loader, trg_loader))
317 | 
318 | # Method 2 (max len of both domains):
319 | joint_loader = enumerate(zip(src_loader, itertools.cycle(trg_loader)))
320 | 
321 | num_batches = max(len(src_loader), len(trg_loader))
322 |
323 | for step, ((src_x, src_y), (trg_x, _)) in joint_loader:
324 |
325 | src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(self.device)
326 |
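# GRL schedule: p is the training progress in [0, 1]; alpha ramps from 0 toward 1 as training proceeds (Ganin et al.)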
327 | p = float(step + epoch * num_batches) / ((self.hparams["num_epochs"] + 1) * num_batches)
328 | alpha = 2. / (1. + np.exp(-10 * p)) - 1
329 |
330 | # zero grad
331 | self.optimizer.zero_grad()
332 | self.optimizer_disc.zero_grad()
333 |
334 | domain_label_src = torch.ones(len(src_x)).to(self.device)
335 | domain_label_trg = torch.zeros(len(trg_x)).to(self.device)
336 |
337 | src_feat = self.feature_extractor(src_x)
338 | src_pred = self.classifier(src_feat)
339 |
340 | trg_feat = self.feature_extractor(trg_x)
341 |
342 | # Task classification Loss
343 | src_cls_loss = self.cross_entropy(src_pred.squeeze(), src_y)
344 |
345 | # Domain classification loss
346 | # source
347 | src_feat_reversed = ReverseLayerF.apply(src_feat, alpha)
348 | src_domain_pred = self.domain_classifier(src_feat_reversed)
349 | src_domain_loss = self.cross_entropy(src_domain_pred, domain_label_src.long())
350 |
351 | # target
352 | trg_feat_reversed = ReverseLayerF.apply(trg_feat, alpha)
353 | trg_domain_pred = self.domain_classifier(trg_feat_reversed)
354 | trg_domain_loss = self.cross_entropy(trg_domain_pred, domain_label_trg.long())
355 |
356 | # Total domain loss
357 | domain_loss = src_domain_loss + trg_domain_loss
358 |
359 | loss = self.hparams["src_cls_loss_wt"] * src_cls_loss + \
360 | self.hparams["domain_loss_wt"] * domain_loss
361 |
362 | loss.backward()
363 | self.optimizer.step()
364 | self.optimizer_disc.step()
365 |
366 | losses = {'Total_loss': loss.item(), 'Domain_loss': domain_loss.item(), 'Src_cls_loss': src_cls_loss.item()}
367 |
368 | for key, val in losses.items():
369 | avg_meter[key].update(val, 32)
370 |
371 | self.lr_scheduler.step()
372 |
373 | class CDAN(Algorithm):
374 | """
375 | CDAN: https://arxiv.org/abs/1705.10667
376 | """
377 |
378 | def __init__(self, backbone, configs, hparams, device):
379 | super().__init__(configs, backbone)
380 |
381 |
382 | # optimizer and scheduler
383 | self.optimizer = torch.optim.Adam(
384 | self.network.parameters(),
385 | lr=hparams["learning_rate"],
386 | weight_decay=hparams["weight_decay"]
387 | )
388 | self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay'])
389 | # hparams
390 | self.hparams = hparams
391 | # device
392 | self.device = device
393 |
394 | # Aligment Losses
395 | self.criterion_cond = ConditionalEntropyLoss().to(device)
396 |
397 | self.domain_classifier = Discriminator_CDAN(configs)
398 | self.random_layer = RandomLayer([configs.features_len * configs.final_out_channels, configs.num_classes],
399 | configs.features_len * configs.final_out_channels)
400 | self.optimizer_disc = torch.optim.Adam(
401 | self.domain_classifier.parameters(),
402 | lr=hparams["learning_rate"],
403 | weight_decay=hparams["weight_decay"])
404 |
405 | def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
406 | 
407 | # Construct Joint Loaders
408 | joint_loader = enumerate(zip(src_loader, itertools.cycle(trg_loader)))
409 |
410 | for step, ((src_x, src_y), (trg_x, _)) in joint_loader:
411 | src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(self.device)
412 | # prepare true domain labels
413 | domain_label_src = torch.ones(len(src_x)).to(self.device)
414 | domain_label_trg = torch.zeros(len(trg_x)).to(self.device)
415 | domain_label_concat = torch.cat((domain_label_src, domain_label_trg), 0).long()
416 |
417 | # source features and predictions
418 | src_feat = self.feature_extractor(src_x)
419 | src_pred = self.classifier(src_feat)
420 |
421 | # target features and predictions
422 | trg_feat = self.feature_extractor(trg_x)
423 | trg_pred = self.classifier(trg_feat)
424 |
425 | # concatenate features and predictions
426 | feat_concat = torch.cat((src_feat, trg_feat), dim=0)
427 | pred_concat = torch.cat((src_pred, trg_pred), dim=0)
428 |
429 | # Domain classification loss
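# the discriminator input is the outer product of predictions and features, detached so this pass updates only the discriminator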
430 | feat_x_pred = torch.bmm(pred_concat.unsqueeze(2), feat_concat.unsqueeze(1)).detach()
431 | disc_prediction = self.domain_classifier(feat_x_pred.view(-1, pred_concat.size(1) * feat_concat.size(1)))
432 | disc_loss = self.cross_entropy(disc_prediction, domain_label_concat)
433 |
434 | # update Domain classification
435 | self.optimizer_disc.zero_grad()
436 | disc_loss.backward()
437 | self.optimizer_disc.step()
438 |
439 | # prepare fake domain labels for training the feature extractor
440 | domain_label_src = torch.zeros(len(src_x)).long().to(self.device)
441 | domain_label_trg = torch.ones(len(trg_x)).long().to(self.device)
442 | domain_label_concat = torch.cat((domain_label_src, domain_label_trg), 0)
443 |
444 | # Repeat predictions after updating discriminator
445 | feat_x_pred = torch.bmm(pred_concat.unsqueeze(2), feat_concat.unsqueeze(1))
446 | disc_prediction = self.domain_classifier(feat_x_pred.view(-1, pred_concat.size(1) * feat_concat.size(1)))
447 | # loss of domain discriminator according to fake labels
448 |
449 | domain_loss = self.cross_entropy(disc_prediction, domain_label_concat)
450 |
451 | # Task classification Loss
452 | src_cls_loss = self.cross_entropy(src_pred.squeeze(), src_y)
453 |
454 | # conditional entropy loss.
455 | loss_trg_cent = self.criterion_cond(trg_pred)
456 |
457 | # total loss
458 | loss = self.hparams["src_cls_loss_wt"] * src_cls_loss + self.hparams["domain_loss_wt"] * domain_loss + \
459 | self.hparams["cond_ent_wt"] * loss_trg_cent
460 |
461 | # update feature extractor
462 | self.optimizer.zero_grad()
463 | loss.backward()
464 | self.optimizer.step()
465 |
466 | losses = {'Total_loss': loss.item(), 'Domain_loss': domain_loss.item(), 'Src_cls_loss': src_cls_loss.item(),
467 | 'cond_ent_loss': loss_trg_cent.item()}
468 |
469 | for key, val in losses.items():
470 | avg_meter[key].update(val, 32)
471 | self.lr_scheduler.step()
472 |
473 | class DIRT(Algorithm):
474 | """
475 | DIRT-T: https://arxiv.org/abs/1802.08735
476 | """
477 |
478 | def __init__(self, backbone, configs, hparams, device):
479 | super().__init__(configs, backbone)
480 |
481 | # optimizer and scheduler
482 | self.optimizer = torch.optim.Adam(
483 | self.network.parameters(),
484 | lr=hparams["learning_rate"],
485 | weight_decay=hparams["weight_decay"]
486 | )
487 | self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay'])
488 | # hparams
489 | self.hparams = hparams
490 | # device
491 | self.device = device
492 |
493 |
494 | # Aligment losses
495 | self.criterion_cond = ConditionalEntropyLoss().to(device)
496 | self.vat_loss = VAT(self.network, device).to(device)
497 | self.ema = EMA(0.998)
498 | self.ema.register(self.network)
499 |
500 | # Discriminator
501 | self.domain_classifier = Discriminator(configs)
502 | self.optimizer_disc = torch.optim.Adam(
503 | self.domain_classifier.parameters(),
504 | lr=hparams["learning_rate"],
505 | weight_decay=hparams["weight_decay"]
506 | )
507 |
508 | def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
509 | 
510 | # Construct Joint Loaders
511 | joint_loader = enumerate(zip(src_loader, itertools.cycle(trg_loader)))
512 |
513 | for step, ((src_x, src_y), (trg_x, _)) in joint_loader:
514 | src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(self.device)
515 | # prepare true domain labels
516 | domain_label_src = torch.ones(len(src_x)).to(self.device)
517 | domain_label_trg = torch.zeros(len(trg_x)).to(self.device)
518 | domain_label_concat = torch.cat((domain_label_src, domain_label_trg), 0).long()
519 |
520 | src_feat = self.feature_extractor(src_x)
521 | src_pred = self.classifier(src_feat)
522 |
523 | # target features and predictions
524 | trg_feat = self.feature_extractor(trg_x)
525 | trg_pred = self.classifier(trg_feat)
526 |
527 | # concatenate features and predictions
528 | feat_concat = torch.cat((src_feat, trg_feat), dim=0)
529 |
530 | # Domain classification loss
531 | disc_prediction = self.domain_classifier(feat_concat.detach())
532 | disc_loss = self.cross_entropy(disc_prediction, domain_label_concat)
533 |
534 | # update Domain classification
535 | self.optimizer_disc.zero_grad()
536 | disc_loss.backward()
537 | self.optimizer_disc.step()
538 |
539 | # prepare fake domain labels for training the feature extractor
540 | domain_label_src = torch.zeros(len(src_x)).long().to(self.device)
541 | domain_label_trg = torch.ones(len(trg_x)).long().to(self.device)
542 | domain_label_concat = torch.cat((domain_label_src, domain_label_trg), 0)
543 |
544 | # Repeat predictions after updating discriminator
545 | disc_prediction = self.domain_classifier(feat_concat)
546 |
547 | # loss of domain discriminator according to fake labels
548 | domain_loss = self.cross_entropy(disc_prediction, domain_label_concat)
549 |
550 | # Task classification Loss
551 | src_cls_loss = self.cross_entropy(src_pred.squeeze(), src_y)
552 |
553 | # conditional entropy loss.
554 | loss_trg_cent = self.criterion_cond(trg_pred)
555 |
556 | # Virtual adversarial training loss
557 | loss_src_vat = self.vat_loss(src_x, src_pred)
558 | loss_trg_vat = self.vat_loss(trg_x, trg_pred)
559 | total_vat = loss_src_vat + loss_trg_vat
560 | # total loss
561 | loss = self.hparams["src_cls_loss_wt"] * src_cls_loss + self.hparams["domain_loss_wt"] * domain_loss + \
562 | self.hparams["cond_ent_wt"] * loss_trg_cent + self.hparams["vat_loss_wt"] * total_vat
563 |
564 | # update exponential moving average
565 | self.ema(self.network)
566 |
567 | # update feature extractor
568 | self.optimizer.zero_grad()
569 | loss.backward()
570 | self.optimizer.step()
571 |
572 | losses = {'Total_loss': loss.item(), 'Domain_loss': domain_loss.item(), 'Src_cls_loss': src_cls_loss.item(),
573 | 'cond_ent_loss': loss_trg_cent.item()}
574 |
575 | for key, val in losses.items():
576 | avg_meter[key].update(val, 32)
577 |
578 | self.lr_scheduler.step()
579 |
580 | class DSAN(Algorithm):
581 | """
582 | DSAN: https://ieeexplore.ieee.org/document/9085896
583 | """
584 |
585 | def __init__(self, backbone, configs, hparams, device):
586 | super().__init__(configs, backbone)
587 |
588 | # optimizer and scheduler
589 | self.optimizer = torch.optim.Adam(
590 | self.network.parameters(),
591 | lr=hparams["learning_rate"],
592 | weight_decay=hparams["weight_decay"]
593 | )
594 | self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay'])
595 | # hparams
596 | self.hparams = hparams
597 | # device
598 | self.device = device
599 |
600 | # Alignment losses
601 | self.loss_LMMD = LMMD_loss(device=device, class_num=configs.num_classes).to(device)
602 |
603 | def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
604 | 
605 | # Construct Joint Loaders
606 | joint_loader = enumerate(zip(src_loader, itertools.cycle(trg_loader)))
607 |
608 | for step, ((src_x, src_y), (trg_x, _)) in joint_loader:
609 | src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(self.device)
610 | src_feat = self.feature_extractor(src_x)
611 | src_pred = self.classifier(src_feat)
612 |
613 | # extract target features
614 | trg_feat = self.feature_extractor(trg_x)
615 | trg_pred = self.classifier(trg_feat)
616 |
617 | # calculate lmmd loss
618 | domain_loss = self.loss_LMMD.get_loss(src_feat, trg_feat, src_y, torch.nn.functional.softmax(trg_pred, dim=1))
619 |
620 | # calculate source classification loss
621 | src_cls_loss = self.cross_entropy(src_pred, src_y)
622 |
623 | # calculate the total loss
624 | loss = self.hparams["domain_loss_wt"] * domain_loss + \
625 | self.hparams["src_cls_loss_wt"] * src_cls_loss
626 |
627 | self.optimizer.zero_grad()
628 | loss.backward()
629 | self.optimizer.step()
630 |
631 | losses = {'Total_loss': loss.item(), 'LMMD_loss': domain_loss.item(), 'Src_cls_loss': src_cls_loss.item()}
632 |
633 | for key, val in losses.items():
634 | avg_meter[key].update(val, 32)
635 |
636 | self.lr_scheduler.step()
637 |
638 | class HoMM(Algorithm):
639 | """
640 | HoMM: https://arxiv.org/pdf/1912.11976.pdf
641 | """
642 |
643 | def __init__(self, backbone, configs, hparams, device):
644 | super().__init__(configs, backbone)
645 |
646 | # optimizer and scheduler
647 | self.optimizer = torch.optim.Adam(
648 | self.network.parameters(),
649 | lr=hparams["learning_rate"],
650 | weight_decay=hparams["weight_decay"]
651 | )
652 | self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay'])
653 | # hparams
654 | self.hparams = hparams
655 | # device
656 | self.device = device
657 |
658 | # aligment losses
659 | self.coral = CORAL()
660 | self.HoMM_loss = HoMM_loss()
661 |
662 | def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
663 | 
664 | # Construct Joint Loaders
665 | joint_loader = enumerate(zip(src_loader, itertools.cycle(trg_loader)))
666 |
667 | for step, ((src_x, src_y), (trg_x, _)) in joint_loader:
668 | src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(self.device)
669 |
670 | src_feat = self.feature_extractor(src_x)
671 | src_pred = self.classifier(src_feat)
672 |
673 | # extract target features
674 | trg_feat = self.feature_extractor(trg_x)
675 | trg_pred = self.classifier(trg_feat)
676 |
677 | # calculate source classification loss
678 | src_cls_loss = self.cross_entropy(src_pred, src_y)
679 |
680 | # calculate HoMM loss
681 | domain_loss = self.HoMM_loss(src_feat, trg_feat)
682 |
683 | # calculate the total loss
684 | loss = self.hparams["domain_loss_wt"] * domain_loss + \
685 | self.hparams["src_cls_loss_wt"] * src_cls_loss
686 |
687 | self.optimizer.zero_grad()
688 | loss.backward()
689 | self.optimizer.step()
690 |
691 | losses = {'Total_loss': loss.item(), 'HoMM_loss': domain_loss.item(), 'Src_cls_loss': src_cls_loss.item()}
692 |
693 | for key, val in losses.items():
694 | avg_meter[key].update(val, 32)
695 |
696 | self.lr_scheduler.step()
697 |
698 |
699 | class DDC(Algorithm):
700 | """
701 | DDC: https://arxiv.org/abs/1412.3474
702 | """
703 |
704 | def __init__(self, backbone, configs, hparams, device):
705 | super().__init__(configs, backbone)
706 |
707 | # optimizer and scheduler
708 | self.optimizer = torch.optim.Adam(
709 | self.network.parameters(),
710 | lr=hparams["learning_rate"],
711 | weight_decay=hparams["weight_decay"]
712 | )
713 | self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay'])
714 | # hparams
715 | self.hparams = hparams
716 | # device
717 | self.device = device
718 |
719 | # Aligment losses
720 | self.mmd_loss = MMD_loss()
721 |
722 | def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
723 |
724 | # Construct Joint Loaders
725 | joint_loader = enumerate(zip(src_loader, itertools.cycle(trg_loader)))
726 |
727 | for step, ((src_x, src_y), (trg_x, _)) in joint_loader:
728 | src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(self.device)
729 | # extract source features
730 | src_feat = self.feature_extractor(src_x)
731 | src_pred = self.classifier(src_feat)
732 |
733 | # extract target features
734 | trg_feat = self.feature_extractor(trg_x)
735 |
736 | # calculate source classification loss
737 | src_cls_loss = self.cross_entropy(src_pred, src_y)
738 |
739 | # calculate mmd loss
740 | domain_loss = self.mmd_loss(src_feat, trg_feat)
741 |
742 | # calculate the total loss
743 | loss = self.hparams["domain_loss_wt"] * domain_loss + \
744 | self.hparams["src_cls_loss_wt"] * src_cls_loss
745 |
746 | self.optimizer.zero_grad()
747 | loss.backward()
748 | self.optimizer.step()
749 |
750 | losses = {'Total_loss': loss.item(), 'MMD_loss': domain_loss.item(), 'Src_cls_loss': src_cls_loss.item()}
751 |
752 | for key, val in losses.items():
753 | avg_meter[key].update(val, 32)
754 |
755 | self.lr_scheduler.step()
756 |
757 | class CoDATS(Algorithm):
758 | """
759 | CoDATS: https://arxiv.org/pdf/2005.10996.pdf
760 | """
761 |
762 | def __init__(self, backbone, configs, hparams, device):
763 | super().__init__(configs, backbone)
764 |
765 | # we replace the original classifier with the CoDATS classifier
766 | # keep the same attribute name, self.classifier, as it is used for model evaluation
767 | self.classifier = codats_classifier(configs)
768 | self.network = nn.Sequential(self.feature_extractor, self.classifier)
769 |
770 | # optimizer and scheduler
771 | self.optimizer = torch.optim.Adam(
772 | self.network.parameters(),
773 | lr=hparams["learning_rate"],
774 | weight_decay=hparams["weight_decay"]
775 | )
776 | self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay'])
777 | # hparams
778 | self.hparams = hparams
779 | # device
780 | self.device = device
781 |
782 |
783 | # Domain classifier
784 | self.domain_classifier = Discriminator(configs)
785 |
786 | self.optimizer_disc = torch.optim.Adam(
787 | self.domain_classifier.parameters(),
788 | lr=hparams["learning_rate"],
789 | weight_decay=hparams["weight_decay"], betas=(0.5, 0.99)
790 | )
791 |
792 | def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
793 | 
794 | # Construct Joint Loaders
795 | joint_loader = enumerate(zip(src_loader, itertools.cycle(trg_loader)))
796 | num_batches = max(len(src_loader), len(trg_loader))
797 | for step, ((src_x, src_y), (trg_x, _)) in joint_loader:
798 | src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(self.device)
799 | 
800 | p = float(step + epoch * num_batches) / ((self.hparams["num_epochs"] + 1) * num_batches)
801 | alpha = 2. / (1. + np.exp(-10 * p)) - 1
802 |
803 | # zero grad
804 | self.optimizer.zero_grad()
805 | self.optimizer_disc.zero_grad()
806 |
807 | domain_label_src = torch.ones(len(src_x)).to(self.device)
808 | domain_label_trg = torch.zeros(len(trg_x)).to(self.device)
809 |
810 | src_feat = self.feature_extractor(src_x)
811 | src_pred = self.classifier(src_feat)
812 |
813 | trg_feat = self.feature_extractor(trg_x)
814 |
815 | # Task classification Loss
816 | src_cls_loss = self.cross_entropy(src_pred.squeeze(), src_y)
817 |
818 | # Domain classification loss
819 | # source
820 | src_feat_reversed = ReverseLayerF.apply(src_feat, alpha)
821 | src_domain_pred = self.domain_classifier(src_feat_reversed)
822 | src_domain_loss = self.cross_entropy(src_domain_pred, domain_label_src.long())
823 |
824 | # target
825 | trg_feat_reversed = ReverseLayerF.apply(trg_feat, alpha)
826 | trg_domain_pred = self.domain_classifier(trg_feat_reversed)
827 | trg_domain_loss = self.cross_entropy(trg_domain_pred, domain_label_trg.long())
828 |
829 | # Total domain loss
830 | domain_loss = src_domain_loss + trg_domain_loss
831 |
832 | loss = self.hparams["src_cls_loss_wt"] * src_cls_loss + \
833 | self.hparams["domain_loss_wt"] * domain_loss
834 |
835 | loss.backward()
836 | self.optimizer.step()
837 | self.optimizer_disc.step()
838 |
839 | losses = {'Total_loss': loss.item(), 'Domain_loss': domain_loss.item(), 'Src_cls_loss': src_cls_loss.item()}
840 | for key, val in losses.items():
841 | avg_meter[key].update(val, 32)
842 |
843 | self.lr_scheduler.step()
844 |
845 | class AdvSKM(Algorithm):
846 | """
847 | AdvSKM: https://www.ijcai.org/proceedings/2021/0378.pdf
848 | """
849 |
850 | def __init__(self, backbone, configs, hparams, device):
851 | super().__init__(configs, backbone)
852 |
853 | # optimizer and scheduler
854 | self.optimizer = torch.optim.Adam(
855 | self.network.parameters(),
856 | lr=hparams["learning_rate"],
857 | weight_decay=hparams["weight_decay"]
858 | )
859 | self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay'])
860 | # hparams
861 | self.hparams = hparams
862 | # device
863 | self.device = device
864 |
865 | # Aligment losses
866 | self.mmd_loss = MMD_loss()
867 | self.AdvSKM_embedder = AdvSKM_Disc(configs).to(device)
868 | self.optimizer_disc = torch.optim.Adam(
869 | self.AdvSKM_embedder.parameters(),
870 | lr=hparams["learning_rate"],
871 | weight_decay=hparams["weight_decay"]
872 | )
873 |
874 | def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
875 | 
876 | # Construct Joint Loaders
877 | joint_loader = enumerate(zip(src_loader, itertools.cycle(trg_loader)))
878 | for step, ((src_x, src_y), (trg_x, _)) in joint_loader:
879 | src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(self.device)
880 |
881 | src_feat = self.feature_extractor(src_x)
882 | src_pred = self.classifier(src_feat)
883 |
884 | # extract target features
885 | trg_feat = self.feature_extractor(trg_x)
886 |
887 | source_embedding_disc = self.AdvSKM_embedder(src_feat.detach())
888 | target_embedding_disc = self.AdvSKM_embedder(trg_feat.detach())
889 | mmd_loss = - self.mmd_loss(source_embedding_disc, target_embedding_disc)
890 | mmd_loss.requires_grad = True
891 |
892 | # update discriminator
893 | self.optimizer_disc.zero_grad()
894 | mmd_loss.backward()
895 | self.optimizer_disc.step()
896 |
897 | # calculate source classification loss
898 | src_cls_loss = self.cross_entropy(src_pred, src_y)
899 |
900 | # domain loss.
901 | source_embedding_disc = self.AdvSKM_embedder(src_feat)
902 | target_embedding_disc = self.AdvSKM_embedder(trg_feat)
903 |
904 | mmd_loss_adv = self.mmd_loss(source_embedding_disc, target_embedding_disc)
905 | mmd_loss_adv.requires_grad = True
906 |
907 | # calculate the total loss
908 | loss = self.hparams["domain_loss_wt"] * mmd_loss_adv + \
909 | self.hparams["src_cls_loss_wt"] * src_cls_loss
910 |
911 | # update optimizer
912 | self.optimizer.zero_grad()
913 | loss.backward()
914 | self.optimizer.step()
915 |
916 | losses = {'Total_loss': loss.item(), 'MMD_loss': mmd_loss_adv.item(), 'Src_cls_loss': src_cls_loss.item()}
917 | for key, val in losses.items():
918 | avg_meter[key].update(val, 32)
919 |
920 | self.lr_scheduler.step()
921 |
922 | class SASA(Algorithm):
923 |
924 | def __init__(self, backbone, configs, hparams, device):
925 | super().__init__(configs, backbone)
926 |
927 | # feature_length for classifier
928 | configs.features_len = 1
929 | self.classifier = classifier(configs)
930 | # feature length for feature extractor
931 | configs.features_len = 1
932 | self.feature_extractor = CNN_ATTN(configs)
933 | self.network = nn.Sequential(self.feature_extractor, self.classifier)
934 |
935 | # optimizer and scheduler
936 | self.optimizer = torch.optim.Adam(
937 | self.network.parameters(),
938 | lr=hparams["learning_rate"],
939 | weight_decay=hparams["weight_decay"]
940 | )
941 | self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay'])
942 | # hparams
943 | self.hparams = hparams
944 | # device
945 | self.device = device
946 |
947 |
948 | def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
949 | 
950 | # Construct Joint Loaders
951 | joint_loader = enumerate(zip(src_loader, itertools.cycle(trg_loader)))
952 | for step, ((src_x, src_y), (trg_x, _)) in joint_loader:
953 | src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(self.device)
954 |
955 | # Extract features
956 | src_feature = self.feature_extractor(src_x)
957 | tgt_feature = self.feature_extractor(trg_x)
958 |
959 | # source classification loss
960 | y_pred = self.classifier(src_feature)
961 | src_cls_loss = self.cross_entropy(y_pred, src_y)
962 |
963 | # MMD loss
964 | domain_loss_intra = self.mmd_loss(src_struct=src_feature,
965 | tgt_struct=tgt_feature, weight=self.hparams['domain_loss_wt'])
966 |
967 | # total loss
968 | total_loss = self.hparams['src_cls_loss_wt'] * src_cls_loss + domain_loss_intra
969 |
970 | # remove old gradients
971 | self.optimizer.zero_grad()
972 | # calculate gradients
973 | total_loss.backward()
974 | # update the weights
975 | self.optimizer.step()
976 |
977 | losses = {'Total_loss': total_loss.item(), 'MMD_loss': domain_loss_intra.item(),
978 | 'Src_cls_loss': src_cls_loss.item()}
979 | for key, val in losses.items():
980 | avg_meter[key].update(val, 32)
981 |
982 | self.lr_scheduler.step()
983 | def mmd_loss(self, src_struct, tgt_struct, weight):
984 | delta = torch.mean(src_struct - tgt_struct, dim=-2)
985 | loss_value = torch.norm(delta, 2) * weight
986 | return loss_value
987 |
988 |
989 | class CoTMix(Algorithm):
990 | def __init__(self, backbone, configs, hparams, device):
991 | super().__init__(configs, backbone)
992 |
993 | # optimizer and scheduler
994 | self.optimizer = torch.optim.Adam(
995 | self.network.parameters(),
996 | lr=hparams["learning_rate"],
997 | weight_decay=hparams["weight_decay"]
998 | )
999 | self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay'])
1000 | # hparams
1001 | self.hparams = hparams
1002 | # device
1003 | self.device = device
1004 |
1005 | # Aligment losses
1006 | self.contrastive_loss = NTXentLoss(device, hparams["batch_size"], 0.2, True)
1007 | self.entropy_loss = ConditionalEntropyLoss()
1008 | self.sup_contrastive_loss = SupConLoss(device)
1009 |
1010 | def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
1011 | 
1012 | # Construct Joint Loaders
1013 | joint_loader = enumerate(zip(src_loader, itertools.cycle(trg_loader)))
1014 | for step, ((src_x, src_y), (trg_x, _)) in joint_loader:
1015 | src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(self.device)
1016 |
1017 | # ====== Temporal Mixup =====================
1018 | src_dominant, trg_dominant = self.temporal_mixup(src_x, trg_x)
1019 |
1020 | # ====== Source =====================
1021 | self.optimizer.zero_grad()
1022 |
1023 | # Src original features
1024 | src_orig_feat = self.feature_extractor(src_x)
1025 | src_orig_logits = self.classifier(src_orig_feat)
1026 |
1027 | # Target original features
1028 | trg_orig_feat = self.feature_extractor(trg_x)
1029 | trg_orig_logits = self.classifier(trg_orig_feat)
1030 |
1031 | # ----------- The two main losses
1032 | # Cross-Entropy loss
1033 | src_cls_loss = self.cross_entropy(src_orig_logits, src_y)
1034 | loss = src_cls_loss * round(self.hparams["src_cls_weight"], 2)
1035 |
1036 | # Target Entropy loss
1037 | trg_entropy_loss = self.entropy_loss(trg_orig_logits)
1038 | loss += trg_entropy_loss * round(self.hparams["trg_entropy_weight"], 2)
1039 |
1040 | # ----------- Auxiliary losses
1041 | # Extract source-dominant mixup features.
1042 | src_dominant_feat = self.feature_extractor(src_dominant)
1043 | src_dominant_logits = self.classifier(src_dominant_feat)
1044 |
1045 | # supervised contrastive loss on source domain side
1046 | src_concat = torch.cat([src_orig_logits.unsqueeze(1), src_dominant_logits.unsqueeze(1)], dim=1)
1047 | src_supcon_loss = self.sup_contrastive_loss(src_concat, src_y)
1048 | loss += src_supcon_loss * round(self.hparams["src_supCon_weight"], 2)
1049 |
1050 | # Extract target-dominant mixup features.
1051 | trg_dominant_feat = self.feature_extractor(trg_dominant)
1052 | trg_dominant_logits = self.classifier(trg_dominant_feat)
1053 |
1054 | # Unsupervised contrastive loss on target domain side
1055 | trg_con_loss = self.contrastive_loss(trg_orig_logits, trg_dominant_logits)
1056 | loss += trg_con_loss * round(self.hparams["trg_cont_weight"], 2)
1057 |
1058 | loss.backward()
1059 | self.optimizer.step()
1060 |
1061 | losses = {'Total_loss': loss.item(),
1062 | 'src_cls_loss': src_cls_loss.item(),
1063 | 'trg_entropy_loss': trg_entropy_loss.item(),
1064 | 'src_supcon_loss': src_supcon_loss.item(),
1065 | 'trg_con_loss': trg_con_loss.item()
1066 | }
1067 | for key, val in losses.items():
1068 | avg_meter[key].update(val, 32)
1069 |
1070 | self.lr_scheduler.step()
1071 |
1072 | def temporal_mixup(self, src_x, trg_x):
1073 |
1074 | mix_ratio = round(self.hparams["mix_ratio"], 2)
1075 | temporal_shift = self.hparams["temporal_shift"]
1076 | h = temporal_shift // 2 # half
1077 |
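# each domain is mixed with a temporally smoothed view of the other domain (mean over +/-h shifted copies)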
1078 | src_dominant = mix_ratio * src_x + (1 - mix_ratio) * \
1079 | torch.mean(torch.stack([torch.roll(trg_x, -i, 2) for i in range(-h, h)], 2), 2)
1080 |
1081 | trg_dominant = mix_ratio * trg_x + (1 - mix_ratio) * \
1082 | torch.mean(torch.stack([torch.roll(src_x, -i, 2) for i in range(-h, h)], 2), 2)
1083 |
1084 | return src_dominant, trg_dominant
1085 |
1086 |
1087 |
1088 | # Untied Approaches: (MCD)
1089 | class MCD(Algorithm):
1090 | """
1091 | Maximum Classifier Discrepancy for Unsupervised Domain Adaptation
1092 | MCD: https://arxiv.org/pdf/1712.02560.pdf
1093 | """
1094 |
1095 | def __init__(self, backbone, configs, hparams, device):
1096 | super().__init__(configs, backbone)
1097 |
1098 | self.feature_extractor = backbone(configs)
1099 | self.classifier = classifier(configs)
1100 | self.classifier2 = classifier(configs)
1101 |
1102 | self.network = nn.Sequential(self.feature_extractor, self.classifier)
1103 |
1104 |
1105 | # optimizer and scheduler
1106 | self.optimizer_fe = torch.optim.Adam(
1107 | self.feature_extractor.parameters(),
1108 | lr=hparams["learning_rate"],
1109 | weight_decay=hparams["weight_decay"]
1110 | )
1111 | # optimizer and scheduler
1112 | self.optimizer_c1 = torch.optim.Adam(
1113 | self.classifier.parameters(),
1114 | lr=hparams["learning_rate"],
1115 | weight_decay=hparams["weight_decay"]
1116 | )
1117 | # optimizer and scheduler
1118 | self.optimizer_c2 = torch.optim.Adam(
1119 | self.classifier2.parameters(),
1120 | lr=hparams["learning_rate"],
1121 | weight_decay=hparams["weight_decay"]
1122 | )
1123 |
1124 | self.lr_scheduler_fe = StepLR(self.optimizer_fe, step_size=hparams['step_size'], gamma=hparams['lr_decay'])
1125 | self.lr_scheduler_c1 = StepLR(self.optimizer_c1, step_size=hparams['step_size'], gamma=hparams['lr_decay'])
1126 | self.lr_scheduler_c2 = StepLR(self.optimizer_c2, step_size=hparams['step_size'], gamma=hparams['lr_decay'])
1127 |
1128 | # hparams
1129 | self.hparams = hparams
1130 | # device
1131 | self.device = device
1132 |
1133 | # Aligment losses
1134 | self.mmd_loss = MMD_loss()
1135 |
1136 | def update(self, src_loader, trg_loader, avg_meter, logger):
1137 | # defining best and last model
1138 | best_src_risk = float('inf')
1139 | best_model = None
1140 |
1141 | for epoch in range(1, self.hparams["num_epochs"] + 1):
1142 |
1143 | # source pretraining loop
1144 | self.pretrain_epoch(src_loader, avg_meter)
1145 |
1146 | # training loop
1147 | self.training_epoch(src_loader, trg_loader, avg_meter, epoch)
1148 |
1149 | # saving the best model based on src risk
1150 | if (epoch + 1) % 10 == 0 and avg_meter['Src_cls_loss'].avg < best_src_risk:
1151 | best_src_risk = avg_meter['Src_cls_loss'].avg
1152 | best_model = deepcopy(self.network.state_dict())
1153 |
1154 |
1155 | logger.debug(f'[Epoch : {epoch}/{self.hparams["num_epochs"]}]')
1156 | for key, val in avg_meter.items():
1157 | logger.debug(f'{key}\t: {val.avg:2.4f}')
1158 | logger.debug(f'-------------------------------------')
1159 |
1160 | last_model = self.network.state_dict()
1161 |
1162 | return last_model, best_model
1163 |
1164 | def pretrain_epoch(self, src_loader, avg_meter):
1165 | for src_x, src_y in src_loader:
1166 | src_x, src_y = src_x.to(self.device), src_y.to(self.device)
1167 |
1168 | src_feat = self.feature_extractor(src_x)
1169 | src_pred1 = self.classifier(src_feat)
1170 | src_pred2 = self.classifier2(src_feat)
1171 |
1172 | src_cls_loss1 = self.cross_entropy(src_pred1, src_y)
1173 | src_cls_loss2 = self.cross_entropy(src_pred2, src_y)
1174 |
1175 | loss = src_cls_loss1 + src_cls_loss2
1176 |
1177 | self.optimizer_c1.zero_grad()
1178 | self.optimizer_c2.zero_grad()
1179 | self.optimizer_fe.zero_grad()
1180 |
1181 | loss.backward()
1182 |
1183 | self.optimizer_c1.step()
1184 | self.optimizer_c2.step()
1185 | self.optimizer_fe.step()
1186 |
1187 |
1188 | losses = {'Src_cls_loss': loss.item()}
1189 |
1190 | for key, val in losses.items():
1191 | avg_meter[key].update(val, 32)
1192 |
1193 | def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
1194 |
1195 | # Construct Joint Loaders
1196 | joint_loader = enumerate(zip(src_loader, itertools.cycle(trg_loader)))
1197 |
1198 | for step, ((src_x, src_y), (trg_x, _)) in joint_loader:
1199 | src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(self.device)
1200 |
1201 |
1202 | # extract source features
1203 | src_feat = self.feature_extractor(src_x)
1204 | src_pred1 = self.classifier(src_feat)
1205 | src_pred2 = self.classifier2(src_feat)
1206 |
1207 | # source losses
1208 | src_cls_loss1 = self.cross_entropy(src_pred1, src_y)
1209 | src_cls_loss2 = self.cross_entropy(src_pred2, src_y)
1210 | loss_s = src_cls_loss1 + src_cls_loss2
1211 |
1212 |
1213 | # Freeze the feature extractor
1214 | for k, v in self.feature_extractor.named_parameters():
1215 | v.requires_grad = False
1216 |             # update C1 and C2 to maximize their prediction discrepancy on target samples
1217 | trg_feat = self.feature_extractor(trg_x)
1218 | trg_pred1 = self.classifier(trg_feat.detach())
1219 | trg_pred2 = self.classifier2(trg_feat.detach())
1220 |
1221 |
1222 | loss_dis = self.discrepancy(trg_pred1, trg_pred2)
1223 |
1224 | loss = loss_s - loss_dis
1225 |
1226 | loss.backward()
1227 | self.optimizer_c1.step()
1228 | self.optimizer_c2.step()
1229 |
1230 | self.optimizer_c1.zero_grad()
1231 | self.optimizer_c2.zero_grad()
1232 | self.optimizer_fe.zero_grad()
1233 |
1234 | # Freeze the classifiers
1235 | for k, v in self.classifier.named_parameters():
1236 | v.requires_grad = False
1237 | for k, v in self.classifier2.named_parameters():
1238 | v.requires_grad = False
1239 |             # Unfreeze the feature extractor
1240 |             for k, v in self.feature_extractor.named_parameters():
1241 |                 v.requires_grad = True
1242 |             # update the feature extractor to minimize the discrepancy on target samples
1243 | trg_feat = self.feature_extractor(trg_x)
1244 | trg_pred1 = self.classifier(trg_feat)
1245 | trg_pred2 = self.classifier2(trg_feat)
1246 |
1247 |
1248 | loss_dis_t = self.discrepancy(trg_pred1, trg_pred2)
1249 | domain_loss = self.hparams["domain_loss_wt"] * loss_dis_t
1250 |
1251 | domain_loss.backward()
1252 | self.optimizer_fe.step()
1253 |
1254 | self.optimizer_fe.zero_grad()
1255 | self.optimizer_c1.zero_grad()
1256 | self.optimizer_c2.zero_grad()
1257 |
1258 |
1259 |             losses = {'Total_loss': loss.item(), 'MMD_loss': domain_loss.item()}  # 'MMD_loss' here logs the weighted target discrepancy, not an MMD term
1260 |
1261 | for key, val in losses.items():
1262 |                 avg_meter[key].update(val, src_x.size(0))
1263 |
1264 | self.lr_scheduler_fe.step()
1265 | self.lr_scheduler_c1.step()
1266 | self.lr_scheduler_c2.step()
1267 |
1268 |     def discrepancy(self, out1, out2):
1269 |         # L1 distance between the two classifiers' probability outputs
1270 |         return torch.mean(torch.abs(F.softmax(out1, dim=1) - F.softmax(out2, dim=1)))
1271 |
--------------------------------------------------------------------------------
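
Note on the MCD block above: each training step alternates three phases. (A) The feature extractor
and both classifiers minimize the source classification loss; (B) with the feature extractor frozen,
C1 and C2 are pushed apart on target samples (loss_s - loss_dis); (C) with the classifiers frozen,
the feature extractor is updated to shrink the discrepancy again. A minimal, self-contained sketch
of the discrepancy term (illustrative only; it mirrors discrepancy() above with an explicit softmax
dim):

    import torch
    import torch.nn.functional as F

    def l1_discrepancy(logits1, logits2):
        # mean absolute difference between the two classifiers' class probabilities
        return torch.mean(torch.abs(F.softmax(logits1, dim=1) - F.softmax(logits2, dim=1)))

    logits = torch.randn(8, 6)
    assert l1_discrepancy(logits, logits).item() == 0.0  # identical heads -> zero discrepancy
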
/configs/__pycache__/data_model_configs.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/configs/__pycache__/data_model_configs.cpython-310.pyc
--------------------------------------------------------------------------------
/configs/__pycache__/data_model_configs.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/configs/__pycache__/data_model_configs.cpython-38.pyc
--------------------------------------------------------------------------------
/configs/__pycache__/data_model_configs.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/configs/__pycache__/data_model_configs.cpython-39.pyc
--------------------------------------------------------------------------------
/configs/__pycache__/hparams.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/configs/__pycache__/hparams.cpython-310.pyc
--------------------------------------------------------------------------------
/configs/__pycache__/hparams.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/configs/__pycache__/hparams.cpython-38.pyc
--------------------------------------------------------------------------------
/configs/__pycache__/hparams.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/configs/__pycache__/hparams.cpython-39.pyc
--------------------------------------------------------------------------------
/configs/__pycache__/sweep_params.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/configs/__pycache__/sweep_params.cpython-310.pyc
--------------------------------------------------------------------------------
/configs/__pycache__/sweep_params.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/configs/__pycache__/sweep_params.cpython-38.pyc
--------------------------------------------------------------------------------
/configs/__pycache__/sweep_params.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/configs/__pycache__/sweep_params.cpython-39.pyc
--------------------------------------------------------------------------------
/configs/data_model_configs.py:
--------------------------------------------------------------------------------
1 | def get_dataset_class(dataset_name):
2 |     """Return the dataset configuration class with the given name."""
3 | if dataset_name not in globals():
4 | raise NotImplementedError("Dataset not found: {}".format(dataset_name))
5 | return globals()[dataset_name]
6 |
7 | class HAR():
8 | def __init__(self):
9 |         super(HAR, self).__init__()
10 | self.scenarios = [("2", "11"), ("6", "23"), ("7", "13"), ("9", "18"), ("12", "16"), ("18", "27"), ("20", "5"), ("24", "8"), ("28", "27"), ("30", "20")]
11 | self.class_names = ['walk', 'upstairs', 'downstairs', 'sit', 'stand', 'lie']
12 | self.sequence_len = 128
13 | self.shuffle = True
14 | self.drop_last = True
15 | self.normalize = True
16 |
17 | # model configs
18 | self.input_channels = 9
19 | self.kernel_size = 5
20 | self.stride = 1
21 | self.dropout = 0.5
22 | self.num_classes = 6
23 |
24 | # CNN and RESNET features
25 | self.mid_channels = 64
26 | self.final_out_channels = 128
27 | self.features_len = 1
28 |
29 | # TCN features
30 | self.tcn_layers = [75, 150]
31 | self.tcn_final_out_channles = self.tcn_layers[-1]
32 | self.tcn_kernel_size = 17
33 | self.tcn_dropout = 0.0
34 |
35 | # lstm features
36 | self.lstm_hid = 128
37 | self.lstm_n_layers = 1
38 | self.lstm_bid = False
39 |
40 | # discriminator
41 | self.disc_hid_dim = 64
42 | self.hidden_dim = 500
43 | self.DSKN_disc_hid = 128
44 |
45 |
46 | class EEG():
47 | def __init__(self):
48 | super(EEG, self).__init__()
49 | # data parameters
50 | self.num_classes = 5
51 | self.class_names = ['W', 'N1', 'N2', 'N3', 'REM']
52 | self.sequence_len = 3000
53 | self.scenarios = [("0", "11"), ("7", "18"), ("9", "14"), ("12", "5"), ("16", "1"),
54 | ("3", "19"), ("18", "12"), ("13", "17"), ("5", "15"), ("6", "2")]
55 | self.shuffle = True
56 | self.drop_last = True
57 | self.normalize = True
58 |
59 | # model configs
60 | self.input_channels = 1
61 | self.kernel_size = 25
62 | self.stride = 6
63 | self.dropout = 0.2
64 |
65 | # features
66 | self.mid_channels = 32
67 | self.final_out_channels = 128
68 | self.features_len = 1
69 |
70 | # TCN features
71 | self.tcn_layers = [32,64]
72 | self.tcn_final_out_channles = self.tcn_layers[-1]
73 |         self.tcn_kernel_size = 15
74 | self.tcn_dropout = 0.0
75 |
76 | # lstm features
77 | self.lstm_hid = 128
78 | self.lstm_n_layers = 1
79 | self.lstm_bid = False
80 |
81 | # discriminator
82 | self.DSKN_disc_hid = 128
83 | self.hidden_dim = 500
84 | self.disc_hid_dim = 100
85 |
86 |
87 | class WISDM(object):
88 | def __init__(self):
89 | super(WISDM, self).__init__()
90 | self.class_names = ['walk', 'jog', 'sit', 'stand', 'upstairs', 'downstairs']
91 | self.sequence_len = 128
92 | self.scenarios = [("7", "18"), ("20", "30"), ("35", "31"), ("17", "23"), ("6", "19"),
93 | ("2", "11"), ("33", "12"), ("5", "26"), ("28", "4"), ("23", "32")]
94 | self.num_classes = 6
95 | self.shuffle = True
96 | self.drop_last = False
97 | self.normalize = True
98 |
99 | # model configs
100 | self.input_channels = 3
101 | self.kernel_size = 5
102 | self.stride = 1
103 | self.dropout = 0.5
104 | self.num_classes = 6
105 |
106 | # features
107 | self.mid_channels = 64
108 | self.final_out_channels = 128
109 | self.features_len = 1
110 |
111 | # TCN features
112 | self.tcn_layers = [75,150,300]
113 | self.tcn_final_out_channles = self.tcn_layers[-1]
114 | self.tcn_kernel_size = 17
115 | self.tcn_dropout = 0.0
116 |
117 | # lstm features
118 | self.lstm_hid = 128
119 | self.lstm_n_layers = 1
120 | self.lstm_bid = False
121 |
122 | # discriminator
123 | self.disc_hid_dim = 64
124 | self.DSKN_disc_hid = 128
125 | self.hidden_dim = 500
126 |
127 |
128 | class HHAR(object):  # HHAR dataset (Samsung devices)
129 | def __init__(self):
130 | super(HHAR, self).__init__()
131 | self.sequence_len = 128
132 | self.scenarios = [("0", "6"), ("1", "6"), ("2", "7"), ("3", "8"), ("4", "5"),
133 | ("5", "0"), ("6", "1"), ("7", "4"), ("8", "3"), ("0", "2")]
134 | self.class_names = ['bike', 'sit', 'stand', 'walk', 'stairs_up', 'stairs_down']
135 | self.num_classes = 6
136 | self.shuffle = True
137 | self.drop_last = True
138 | self.normalize = True
139 |
140 | # model configs
141 | self.input_channels = 3
142 | self.kernel_size = 5
143 | self.stride = 1
144 | self.dropout = 0.5
145 |
146 | # features
147 | self.mid_channels = 64
148 | self.final_out_channels = 128
149 | self.features_len = 1
150 |
151 | # TCN features
152 | self.tcn_layers = [75,150]
153 | self.tcn_final_out_channles = self.tcn_layers[-1]
154 | self.tcn_kernel_size = 17
155 | self.tcn_dropout = 0.0
156 |
157 | # lstm features
158 | self.lstm_hid = 128
159 | self.lstm_n_layers = 1
160 | self.lstm_bid = False
161 |
162 | # discriminator
163 | self.disc_hid_dim = 64
164 | self.DSKN_disc_hid = 128
165 | self.hidden_dim = 500
166 |
167 |
168 |
169 | class FD(object):
170 | def __init__(self):
171 | super(FD, self).__init__()
172 | self.sequence_len = 5120
173 | self.scenarios = [("0", "1"), ("0", "3"), ("1", "0"), ("1", "2"),("1", "3"),
174 | ("2", "1"),("2", "3"), ("3", "0"), ("3", "1"), ("3", "2")]
175 | self.class_names = ['Healthy', 'D1', 'D2']
176 | self.num_classes = 3
177 | self.shuffle = True
178 | self.drop_last = True
179 | self.normalize = True
180 |
181 | # Model configs
182 | self.input_channels = 1
183 | self.kernel_size = 32
184 | self.stride = 6
185 | self.dropout = 0.5
186 |
187 | self.mid_channels = 64
188 | self.final_out_channels = 128
189 | self.features_len = 1
190 |
191 | # TCN features
192 | self.tcn_layers = [75, 150]
193 | self.tcn_final_out_channles = self.tcn_layers[-1]
194 | self.tcn_kernel_size = 17
195 | self.tcn_dropout = 0.0
196 |
197 | # lstm features
198 | self.lstm_hid = 128
199 | self.lstm_n_layers = 1
200 | self.lstm_bid = False
201 |
202 | # discriminator
203 | self.disc_hid_dim = 64
204 | self.DSKN_disc_hid = 128
205 | self.hidden_dim = 500
206 |
--------------------------------------------------------------------------------
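
Note: the configuration classes above are resolved by class name through globals(). A minimal
usage sketch (hypothetical caller):

    from configs.data_model_configs import get_dataset_class

    dataset_configs = get_dataset_class("HAR")()   # instantiate the HAR config
    print(dataset_configs.num_classes)             # 6
    print(dataset_configs.scenarios[0])            # ('2', '11'): adapt subject 2 -> subject 11
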
/configs/hparams.py:
--------------------------------------------------------------------------------
1 | ## The current hyper-parameter values are not necessarily the best ones for a specific risk.
2 | def get_hparams_class(dataset_name):
3 |     """Return the hyper-parameter class for the given dataset name."""
4 | if dataset_name not in globals():
5 | raise NotImplementedError("Dataset not found: {}".format(dataset_name))
6 | return globals()[dataset_name]
7 |
8 |
9 | class HAR():
10 | def __init__(self):
11 | super(HAR, self).__init__()
12 | self.train_params = {
13 | 'num_epochs': 40,
14 | 'batch_size': 32,
15 | 'weight_decay': 1e-4,
16 | 'step_size': 50,
17 | 'lr_decay': 0.5
18 |
19 | }
20 | self.alg_hparams = {
21 | 'NO_ADAPT': {'learning_rate': 1e-3, 'src_cls_loss_wt': 1},
22 | 'TARGET_ONLY': {'learning_rate': 1e-3, 'trg_cls_loss_wt': 1},
23 | "SASA": {
24 | "domain_loss_wt": 7.3937939938562,
25 | "learning_rate": 0.005,
26 | "src_cls_loss_wt": 4.185814373345016,
27 | "weight_decay": 0.0001
28 | },
29 | "DDC": {
30 | "learning_rate": 0.001,
31 | "mmd_wt": 3.7991920933520342,
32 | "src_cls_loss_wt": 6.286301875125623,
33 | "domain_loss_wt": 6.36,
34 | "weight_decay": 0.0001
35 | },
36 | "CoDATS": {
37 | "domain_loss_wt": 3.2750474868706925,
38 | "learning_rate": 0.001,
39 | "src_cls_loss_wt": 6.335109786953256,
40 | "weight_decay": 0.0001
41 | },
42 | "DANN": {
43 | "domain_loss_wt": 2.943729820531079,
44 | "learning_rate": 0.001,
45 | "src_cls_loss_wt": 5.1390077646202,
46 | "weight_decay": 0.0001
47 | },
48 | "DIRT": {
49 | "cond_ent_wt": 1.20721518968644,
50 | "domain_loss_wt": 1.9012145515129044,
51 | "learning_rate": 0.005,
52 | "src_cls_loss_wt": 9.67861021290254,
53 | "vat_loss_wt": 7.7102843136045855,
54 | "weight_decay": 0.0001
55 | },
56 | "DSAN": {
57 | "learning_rate": 0.001,
58 | "mmd_wt": 2.0872340713147786,
59 | "src_cls_loss_wt": 1.8744909939900247,
60 | "domain_loss_wt": 1.59,
61 | "weight_decay": 0.0001
62 | },
63 | "MMDA": {
64 | "cond_ent_wt": 1.383002023133561,
65 | "coral_wt": 8.36810764913737,
66 | "learning_rate": 0.001,
67 | "mmd_wt": 3.964042918489996,
68 | "src_cls_loss_wt": 6.794522068759213,
69 | "weight_decay": 0.0001
70 | },
71 | "Deep_Coral": {
72 | "coral_wt": 4.23035475456397,
73 | "learning_rate": 0.0005,
74 | "src_cls_loss_wt": 0.1013209750429822,
75 | "weight_decay": 0.0001
76 | },
77 | "CDAN": {
78 | "cond_ent_wt": 1.2920143348777362,
79 | "domain_loss_wt": 9.545761950873414,
80 | "learning_rate": 0.001,
81 | "src_cls_loss_wt": 9.430292987535724,
82 | "weight_decay": 0.0001
83 | },
84 | "AdvSKM": {
85 | "domain_loss_wt": 1.338788378230754,
86 | "learning_rate": 0.0005,
87 | "src_cls_loss_wt": 2.468525942065072,
88 | "weight_decay": 0.0001
89 | },
90 | "HoMM": {
91 | "hommd_wt": 2.8305712579412683,
92 | "learning_rate": 0.0005,
93 | "src_cls_loss_wt": 0.1282520874653523,
94 | "domain_loss_wt": 9.13,
95 | "weight_decay": 0.0001
96 | },
97 | 'CoTMix': {'learning_rate': 0.001, 'mix_ratio': 0.9, 'temporal_shift': 14,
98 | 'src_cls_weight': 0.78, 'src_supCon_weight': 0.1, 'trg_cont_weight': 0.1,
99 | 'trg_entropy_weight': 0.05},
100 | 'MCD': {'learning_rate': 1e-2, 'src_cls_loss_wt': 9.74, 'domain_loss_wt': 5.43},
101 |
102 | }
103 |
104 |
105 | class EEG():
106 | def __init__(self):
107 | super(EEG, self).__init__()
108 | self.train_params = {
109 | 'num_epochs': 40,
110 | 'batch_size': 128,
111 | 'weight_decay': 1e-4,
112 |
113 | }
114 | self.alg_hparams = {
115 | 'NO_ADAPT': {'learning_rate': 1e-3, 'src_cls_loss_wt': 1},
116 | 'TARGET_ONLY': {'learning_rate': 1e-3, 'trg_cls_loss_wt': 1},
117 | "SASA": {
118 | "domain_loss_wt": 5.8045319155819515,
119 | "learning_rate": 0.005,
120 | "src_cls_loss_wt": 4.438490884851632,
121 | "weight_decay": 0.0001
122 | },
123 | "CoDATS": {
124 | "domain_loss_wt": 0.3551260369189456,
125 | "learning_rate": 0.005,
126 | "src_cls_loss_wt": 1.2534327517723889,
127 | "weight_decay": 0.0001
128 | },
129 | "AdvSKM": {
130 | "domain_loss_wt": 5.600818539370264,
131 | "learning_rate": 0.0005,
132 | "src_cls_loss_wt": 4.231231335081738,
133 | "weight_decay": 0.0001
134 | },
135 | "Deep_Coral": {
136 | "coral_wt": 9.50224286095279,
137 | "learning_rate": 0.0005,
138 | "src_cls_loss_wt": 0.8149666724969482,
139 | "weight_decay": 0.0001
140 | },
141 | "DANN": {
142 | "domain_loss_wt": 0.27634197975549135,
143 | "learning_rate": 0.0005,
144 | "src_cls_loss_wt": 8.441929209893459,
145 | "weight_decay": 0.0001
146 | },
147 | "DDC": {
148 | "learning_rate": 0.0005,
149 | "mmd_wt": 5.900770246907044,
150 | "src_cls_loss_wt": 1.979307877348751,
151 | "domain_loss_wt": 8.923,
152 | "weight_decay": 0.0001
153 | },
154 | "DIRT": {
155 | "cond_ent_wt": 1.7021814402136783,
156 | "domain_loss_wt": 1.6488583075821344,
157 | "learning_rate": 0.01,
158 | "src_cls_loss_wt": 6.427127521674593,
159 | "vat_loss_wt": 5.078600240648073,
160 | "weight_decay": 0.0001
161 | },
162 | "MMDA": {
163 | "cond_ent_wt": 9.177841626283191,
164 | "coral_wt": 2.768290045896212,
165 | "learning_rate": 0.0005,
166 | "mmd_wt": 2.25231504738171,
167 | "src_cls_loss_wt": 8.64418208100774,
168 | "weight_decay": 0.0001
169 | },
170 | "DSAN": {
171 | "learning_rate": 0.001,
172 | "mmd_wt": 5.01196798268099,
173 | "src_cls_loss_wt": 7.774381653453339,
174 | "domain_loss_wt": 6.708,
175 | "weight_decay": 0.0001
176 | },
177 | "HoMM": {
178 | "hommd_wt": 3.843851397373747,
179 | "learning_rate": 0.001,
180 | "src_cls_loss_wt": 1.8311375304849091,
181 | "domain_loss_wt": 1.102,
182 | "weight_decay": 0.0001
183 | },
184 | "CDAN": {
185 | "cond_ent_wt": 0.7559091229767906,
186 | "domain_loss_wt": 0.17693531166083065,
187 | "learning_rate": 0.0005,
188 | "src_cls_loss_wt": 7.764624556216286,
189 | "weight_decay": 0.0001
190 | },
191 | 'CoTMix': {'learning_rate': 0.001, 'mix_ratio': 0.79, 'temporal_shift': 300,
192 | 'src_cls_weight': 0.96, 'src_supCon_weight': 0.1, 'trg_cont_weight': 0.1,
193 | 'trg_entropy_weight': 0.05}
194 |
195 | }
196 |
197 |
198 | class WISDM():
199 | def __init__(self):
200 | super().__init__()
201 | self.train_params = {
202 | 'num_epochs': 40,
203 | 'batch_size': 32,
204 | 'weight_decay': 1e-4,
205 |
206 | }
207 | self.alg_hparams = {
208 | 'NO_ADAPT': {'learning_rate': 1e-3, 'src_cls_loss_wt': 1},
209 | 'TARGET_ONLY': {'learning_rate': 1e-3, 'trg_cls_loss_wt': 1},
210 | "SASA": {
211 | "domain_loss_wt": 1.2632988839197083,
212 | "learning_rate": 0.005,
213 | "src_cls_loss_wt": 9.898676755625807,
214 | "weight_decay": 0.0001
215 | },
216 | "CDAN": {
217 | "cond_ent_wt": 0.837129024245748,
218 | "domain_loss_wt": 5.9197207530729266,
219 | "learning_rate": 0.001,
220 | "src_cls_loss_wt": 6.983963629299826,
221 | "weight_decay": 0.0001
222 | },
223 | "HoMM": {
224 | "hommd_wt": 6.799448304230478,
225 | "learning_rate": 0.005,
226 | "src_cls_loss_wt": 0.2563533185103576,
227 | "domain_loss_wt": 4.239,
228 | "weight_decay": 0.0001
229 | },
230 | "DANN": {
231 | "domain_loss_wt": 2.6051391453662873,
232 | "learning_rate": 0.005,
233 | "src_cls_loss_wt": 5.272383517138417,
234 | "weight_decay": 0.0001
235 | },
236 | "DIRT": {
237 | "cond_ent_wt": 1.6935884891647972,
238 | "domain_loss_wt": 7.774841143071709,
239 | "learning_rate": 0.005,
240 | "src_cls_loss_wt": 9.62463958771893,
241 | "vat_loss_wt": 4.644539486962429,
242 | "weight_decay": 0.0001
243 | },
244 | "AdvSKM": {
245 | "domain_loss_wt": 0.17573022784621156,
246 | "learning_rate": 0.001,
247 | "src_cls_loss_wt": 7.656694101023234,
248 | "weight_decay": 0.0001
249 | },
250 | "MMDA": {
251 | "cond_ent_wt": 7.555540424691775,
252 | "coral_wt": 5.254400971297628,
253 | "learning_rate": 0.005,
254 | "mmd_wt": 2.295549751091742,
255 | "src_cls_loss_wt": 6.653513071102565,
256 | "weight_decay": 0.0001
257 | },
258 | "Deep_Coral": {
259 | "coral_wt": 6.4881104202861755,
260 | "learning_rate": 0.001,
261 | "src_cls_loss_wt": 6.66305608395703,
262 | "weight_decay": 0.0001
263 | },
264 | "CoDATS": {
265 | "domain_loss_wt": 4.574872968982744,
266 | "learning_rate": 0.001,
267 | "src_cls_loss_wt": 5.860885469514424,
268 | "weight_decay": 0.0001
269 | },
270 | "DSAN": {
271 | "learning_rate": 0.005,
272 | "mmd_wt": 1.5468030830413808,
273 | "src_cls_loss_wt": 1.2981011362021273,
274 | "domain_loss_wt": 0.1,
275 | "weight_decay": 0.0001
276 | },
277 | "DDC": {
278 | "learning_rate": 0.001,
279 | "mmd_wt": 1.9901164953952095,
280 | "src_cls_loss_wt": 4.881899626451807,
281 | "domain_loss_wt": 7.595,
282 | "weight_decay": 0.0001
283 | },
284 | "CoTMix": {
285 | 'learning_rate': 0.001,
286 | 'mix_ratio': 0.72,
287 | 'temporal_shift': 14,
288 | 'src_cls_weight': 0.98,
289 | 'src_supCon_weight': 0.1,
290 | 'trg_cont_weight': 0.1,
291 | 'trg_entropy_weight': 0.05}
292 | }
293 |
294 |
295 | class HHAR():
296 | def __init__(self):
297 | super().__init__()
298 | self.train_params = {
299 | 'num_epochs': 40,
300 | 'batch_size': 32,
301 | 'weight_decay': 1e-4,
302 | }
303 | self.alg_hparams = {
304 | 'NO_ADAPT': {'learning_rate': 1e-3, 'src_cls_loss_wt': 1},
305 | 'TARGET_ONLY': {'learning_rate': 1e-3, 'trg_cls_loss_wt': 1},
306 |
307 | "SASA": {
308 | "domain_loss_wt": 5.760124609738364,
309 | "learning_rate": 0.001,
310 | "src_cls_loss_wt": 4.130742585941761,
311 | "weight_decay": 0.0001
312 | },
313 | "DSAN": {
314 | "learning_rate": 0.0005,
315 | "mmd_wt": 0.5993593617252002,
316 | "src_cls_loss_wt": 0.386167577207679,
317 | "domain_loss_wt": 0.16,
318 | "weight_decay": 0.0001
319 | },
320 | "CoDATS": {
321 | "domain_loss_wt": 9.314114040099962,
322 | "learning_rate": 0.0005,
323 | "src_cls_loss_wt": 7.700018679383289,
324 | "weight_decay": 0.0001
325 | },
326 | "HoMM": {
327 | "hommd_wt": 7.172430927893522,
328 | "learning_rate": 0.0005,
329 | "src_cls_loss_wt": 0.20121211752349172,
330 | "domain_loss_wt": 0.9824,
331 | "weight_decay": 0.0001
332 | },
333 | "DIRT": {
334 | "cond_ent_wt": 1.329734510542011,
335 | "domain_loss_wt": 6.632293308809388,
336 | "learning_rate": 0.001,
337 | "src_cls_loss_wt": 7.729881324550688,
338 | "vat_loss_wt": 6.912258476982827,
339 | "weight_decay": 0.0001
340 | },
341 | "AdvSKM": {
342 | "domain_loss_wt": 1.8649335076712072,
343 | "learning_rate": 0.001,
344 | "src_cls_loss_wt": 3.961611563054495,
345 | "weight_decay": 0.0001
346 | },
347 | "DDC": {
348 | "learning_rate": 0.0005,
349 | "mmd_wt": 8.355791702302787,
350 | "src_cls_loss_wt": 1.2079058664226126,
351 | "domain_loss_wt": 0.2048,
352 | "weight_decay": 0.0001
353 | },
354 | "CDAN": {
355 | "cond_ent_wt": 0.1841898900507932,
356 | "domain_loss_wt": 1.9307294194382076,
357 | "learning_rate": 0.0005,
358 | "src_cls_loss_wt": 4.15410157776963,
359 | "weight_decay": 0.0001
360 | },
361 | "DANN": {
362 | "domain_loss_wt": 1.0296390274908802,
363 | "learning_rate": 0.0005,
364 | "src_cls_loss_wt": 2.038458138479581,
365 | "weight_decay": 0.0001
366 | },
367 | "Deep_Coral": {
368 | "coral_wt": 5.9357031653707475,
369 | "learning_rate": 0.0005,
370 | "src_cls_loss_wt": 0.43859323168654,
371 | "weight_decay": 0.0001
372 | },
373 | "MMDA": {
374 | "cond_ent_wt": 6.707871745810609,
375 | "coral_wt": 5.903714930042433,
376 | "learning_rate": 0.005,
377 | "mmd_wt": 6.480169289397163,
378 | "src_cls_loss_wt": 0.18878476669902317,
379 | "weight_decay": 0.0001
380 | },
381 | 'CoTMix': {'learning_rate': 0.001, 'mix_ratio': 0.52, 'temporal_shift': 14,
382 | 'src_cls_weight': 0.8, 'src_supCon_weight': 0.1, 'trg_cont_weight': 0.1,
383 | 'trg_entropy_weight': 0.05}
384 |
385 | }
386 |
387 |
388 | class FD():
389 | def __init__(self):
390 | super().__init__()
391 | self.train_params = {
392 | 'num_epochs': 40,
393 | 'batch_size': 32,
394 | 'weight_decay': 1e-4,
395 | }
396 | self.alg_hparams = {
397 | 'NO_ADAPT': {'learning_rate': 1e-3, 'src_cls_loss_wt': 1},
398 | 'TARGET_ONLY': {'learning_rate': 1e-3, 'trg_cls_loss_wt': 1},
399 | "SASA": {
400 | "domain_loss_wt": 0.7821851095870519,
401 | "learning_rate": 0.005,
402 | "src_cls_loss_wt": 7.680225091930735,
403 | "weight_decay": 0.0001
404 | },
405 | "MMDA": {
406 | "cond_ent_wt": 8.12868726468387,
407 | "coral_wt": 7.2734249221691005,
408 | "learning_rate": 0.0005,
409 | "mmd_wt": 4.967077206689191,
410 | "src_cls_loss_wt": 0.30259189730747005,
411 | "weight_decay": 0.0001
412 | },
413 | "AdvSKM": {
414 | "domain_loss_wt": 9.377024659182622,
415 | "learning_rate": 0.001,
416 | "src_cls_loss_wt": 0.7569318345582794,
417 | "weight_decay": 0.0001
418 | },
419 | "HoMM": {
420 | "hommd_wt": 6.719563315664067,
421 | "learning_rate": 0.001,
422 | "src_cls_loss_wt": 1.5584167741262964,
423 | "domain_loss_wt": 0.9824,
424 | "weight_decay": 0.0001
425 | },
426 | "Deep_Coral": {
427 | "coral_wt": 7.493856538302936,
428 | "learning_rate": 0.001,
429 | "src_cls_loss_wt": 1.452466194151791,
430 | "weight_decay": 0.0001
431 | },
432 | "DIRT": {
433 | "cond_ent_wt": 4.753485587751647,
434 | "domain_loss_wt": 7.427507171955081,
435 | "learning_rate": 0.001,
436 | "src_cls_loss_wt": 9.818770948448943,
437 | "vat_loss_wt": 9.609164719194178,
438 | "weight_decay": 0.0001
439 | },
440 | "DSAN": {
441 | "learning_rate": 0.005,
442 | "mmd_wt": 7.278792967879357,
443 | "src_cls_loss_wt": 2.5146121077752395,
444 | "domain_loss_wt": 0.16,
445 | "weight_decay": 0.0001
446 | },
447 | "CDAN": {
448 | "cond_ent_wt": 0.553637609557987,
449 | "domain_loss_wt": 6.759045461432962,
450 | "learning_rate": 0.001,
451 | "src_cls_loss_wt": 6.854042579661701,
452 | "weight_decay": 0.0001
453 | },
454 | "DDC": {
455 | "learning_rate": 0.005,
456 | "mmd_wt": 6.701050990813831,
457 | "src_cls_loss_wt": 1.1626428404763771,
458 | "domain_loss_wt": 0.2048,
459 | "weight_decay": 0.0001
460 | },
461 | "CoDATS": {
462 | "domain_loss_wt": 0.6990097136753354,
463 | "learning_rate": 0.005,
464 | "src_cls_loss_wt": 9.57338373194037,
465 | "weight_decay": 0.0001
466 | },
467 | "DANN": {
468 | "domain_loss_wt": 5.221878412210977,
469 | "learning_rate": 0.001,
470 | "src_cls_loss_wt": 4.233865748743297,
471 | "weight_decay": 0.0001
472 | },
473 | 'CoTMix': {'learning_rate': 0.001, 'mix_ratio': 0.52, 'temporal_shift': 14,
474 | 'src_cls_weight': 0.8, 'src_supCon_weight': 0.1, 'trg_cont_weight': 0.1,
475 | 'trg_entropy_weight': 0.05}
476 | }
477 |
--------------------------------------------------------------------------------
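
Note: each dataset class exposes shared train_params plus per-algorithm alg_hparams; the trainer
presumably merges the two before building an algorithm. A rough sketch of that merge (illustrative,
not the trainer's actual code):

    from configs.hparams import get_hparams_class

    hparams_class = get_hparams_class("HAR")()
    hparams = {**hparams_class.train_params, **hparams_class.alg_hparams["DANN"]}  # later keys win
    print(hparams["num_epochs"], hparams["learning_rate"], hparams["domain_loss_wt"])
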
/configs/sweep_params.py:
--------------------------------------------------------------------------------
1 | sweep_train_hparams = {
2 | 'num_epochs': {'values': [3, 4, 5, 6]},
3 | 'batch_size': {'values': [32, 64]},
4 | 'learning_rate':{'values': [1e-2, 5e-3, 1e-3, 5e-4]},
5 | 'disc_lr': {'values': [1e-2, 5e-3, 1e-3, 5e-4]},
6 | 'weight_decay': {'values': [1e-4, 1e-5, 1e-6]},
7 | 'step_size': {'values': [5, 10, 30]},
8 | 'gamma': {'values': [5, 10, 15, 20, 25]},
9 | 'optimizer': {'values': ['adam']},
10 | }
11 | sweep_alg_hparams = {
12 | 'DANN': {
13 | 'learning_rate': {'values': [1e-2, 5e-3, 1e-3, 5e-4]},
14 | 'src_cls_loss_wt': {'distribution': 'uniform', 'min': 1e-1, 'max': 10},
15 | 'domain_loss_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10},
16 | },
17 |
18 | 'AdvSKM': {
19 | 'learning_rate': {'values': [1e-2, 5e-3, 1e-3, 5e-4]},
20 | 'src_cls_loss_wt': {'distribution': 'uniform', 'min': 1e-1, 'max': 10},
21 | 'domain_loss_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10},
22 | },
23 |
24 | 'CoDATS': {
25 | 'learning_rate': {'values': [1e-2, 5e-3, 1e-3, 5e-4]},
26 | 'src_cls_loss_wt': {'distribution': 'uniform', 'min': 1e-1, 'max': 10},
27 | 'domain_loss_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10},
28 | },
29 |
30 | 'CDAN': {
31 | 'learning_rate': {'values': [1e-2, 5e-3, 1e-3, 5e-4]},
32 | 'src_cls_loss_wt': {'distribution': 'uniform', 'min': 1e-1, 'max': 10},
33 | 'domain_loss_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10},
34 | 'cond_ent_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10},
35 | },
36 |
37 | 'Deep_Coral': {
38 | 'learning_rate': {'values': [1e-2, 5e-3, 1e-3, 5e-4]},
39 | 'src_cls_loss_wt': {'distribution': 'uniform', 'min': 1e-1, 'max': 10},
40 | 'coral_wt': {'distribution': 'uniform', 'min': 1e-1, 'max': 10},
41 | },
42 |
43 | 'DIRT': {
44 | 'learning_rate': {'values': [1e-2, 5e-3, 1e-3, 5e-4]},
45 | 'src_cls_loss_wt': {'distribution': 'uniform', 'min': 1e-1, 'max': 10},
46 | 'domain_loss_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10},
47 | 'cond_ent_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10},
48 | 'vat_loss_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10},
49 | },
50 |
51 | 'HoMM': {
52 | 'learning_rate': {'values': [1e-2, 5e-3, 1e-3, 5e-4]},
53 | 'src_cls_loss_wt': {'distribution': 'uniform', 'min': 1e-1, 'max': 10},
54 | 'hommd_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10},
55 | },
56 |
57 | 'MMDA': {
58 | 'learning_rate': {'values': [1e-2, 5e-3, 1e-3, 5e-4]},
59 | 'src_cls_loss_wt': {'distribution': 'uniform', 'min': 1e-1, 'max': 10},
60 | 'coral_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10},
61 | 'cond_ent_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10},
62 | 'mmd_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10},
63 | },
64 |
65 | 'DSAN': {
66 | 'learning_rate': {'values': [1e-2, 5e-3, 1e-3, 5e-4]},
67 | 'src_cls_loss_wt': {'distribution': 'uniform', 'min': 1e-1, 'max': 10},
68 | 'mmd_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10},
69 | },
70 |
71 | 'DDC': {
72 | 'learning_rate': {'values': [1e-2, 5e-3, 1e-3, 5e-4]},
73 | 'src_cls_loss_wt': {'distribution': 'uniform', 'min': 1e-1, 'max': 10},
74 | 'mmd_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10},
75 | },
76 |
77 | 'SASA': {
78 | 'learning_rate': {'values': [1e-2, 5e-3, 1e-3, 5e-4]},
79 | 'src_cls_loss_wt': {'distribution': 'uniform', 'min': 1e-1, 'max': 10},
80 | 'domain_loss_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10},
81 | },
82 |
83 | 'CoTMix': {
84 | 'learning_rate': {'values': [1e-2, 5e-3, 1e-3, 5e-4]},
85 | 'temporal_shift': {'values': [5, 10, 15, 20, 30, 50]},
86 | 'src_cls_weight': {'distribution': 'uniform', 'min': 1e-1, 'max': 1},
87 | 'mix_ratio': {'distribution': 'uniform', 'min': 0.5, 'max': 0.99},
88 | 'src_supCon_weight': {'distribution': 'uniform', 'min': 1e-3, 'max': 1},
89 | 'trg_cont_weight': {'distribution': 'uniform', 'min': 1e-3, 'max': 1},
90 | 'trg_entropy_weight': {'distribution': 'uniform', 'min': 1e-3, 'max': 1},
91 | },
92 | }
93 |
94 |
--------------------------------------------------------------------------------
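
Note: these dictionaries follow wandb's sweep-parameter schema ('values' for discrete choices,
'distribution'/'min'/'max' for continuous ranges). A sketch of wrapping one search space into a
sweep config (illustrative; the actual wiring lives in trainers/sweep.py):

    import wandb
    from configs.sweep_params import sweep_alg_hparams

    sweep_config = {
        "method": "random",                                  # or "grid" / "bayes"
        "metric": {"name": "src_risk", "goal": "minimize"},  # cf. --metric_to_minimize
        "parameters": sweep_alg_hparams["DANN"],
    }
    sweep_id = wandb.sweep(sweep_config, project="ADATIME_refactor")
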
/dataloader/__pycache__/dataloader.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/dataloader/__pycache__/dataloader.cpython-310.pyc
--------------------------------------------------------------------------------
/dataloader/__pycache__/dataloader.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/dataloader/__pycache__/dataloader.cpython-38.pyc
--------------------------------------------------------------------------------
/dataloader/__pycache__/dataloader.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/dataloader/__pycache__/dataloader.cpython-39.pyc
--------------------------------------------------------------------------------
/dataloader/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | from torch.utils.data import Dataset
4 | from torchvision import transforms
5 |
6 | from sklearn.model_selection import train_test_split
7 |
8 | import os, sys
9 | import numpy as np
10 | import random
11 |
12 |
13 | class Load_Dataset(Dataset):
14 | def __init__(self, dataset, dataset_configs):
15 | super().__init__()
16 | self.num_channels = dataset_configs.input_channels
17 |
18 | # Load samples
19 | x_data = dataset["samples"]
20 |
21 | # Load labels
22 | y_data = dataset.get("labels")
23 | if y_data is not None and isinstance(y_data, np.ndarray):
24 | y_data = torch.from_numpy(y_data)
25 |
26 | # Convert to torch tensor
27 | if isinstance(x_data, np.ndarray):
28 | x_data = torch.from_numpy(x_data)
29 |
30 | # Check samples dimensions.
31 | # The dimension of the data is expected to be (N, C, L)
32 | # where N is the #samples, C: #channels, and L is the sequence length
33 | if len(x_data.shape) == 2:
34 | x_data = x_data.unsqueeze(1)
35 | elif len(x_data.shape) == 3 and x_data.shape[1] != self.num_channels:
36 | x_data = x_data.transpose(1, 2)
37 |
38 | # Normalize data
39 | if dataset_configs.normalize:
40 | data_mean = torch.mean(x_data, dim=(0, 2))
41 | data_std = torch.std(x_data, dim=(0, 2))
42 | self.transform = transforms.Normalize(mean=data_mean, std=data_std)
43 | else:
44 | self.transform = None
45 | self.x_data = x_data.float()
46 | self.y_data = y_data.long() if y_data is not None else None
47 | self.len = x_data.shape[0]
48 |
49 |
50 | def __getitem__(self, index):
51 | x = self.x_data[index]
52 |         if self.transform:
53 |             # transforms.Normalize expects (C, H, W): reshape to (C, L, 1), normalize per channel, then restore the shape
54 |             x = self.transform(x.reshape(self.num_channels, -1, 1)).reshape(self.x_data[index].shape)
54 | y = self.y_data[index] if self.y_data is not None else None
55 | return x, y
56 |
57 | def __len__(self):
58 | return self.len
59 |
60 |
61 | def data_generator(data_path, domain_id, dataset_configs, hparams, dtype):
62 | # loading dataset file from path
63 | dataset_file = torch.load(os.path.join(data_path, f"{dtype}_{domain_id}.pt"))
64 |
65 | # Loading datasets
66 | dataset = Load_Dataset(dataset_file, dataset_configs)
67 |
68 | if dtype == "test": # you don't need to shuffle or drop last batch while testing
69 | shuffle = False
70 | drop_last = False
71 | else:
72 | shuffle = dataset_configs.shuffle
73 | drop_last = dataset_configs.drop_last
74 |
75 | # Dataloaders
76 | data_loader = torch.utils.data.DataLoader(dataset=dataset,
77 | batch_size=hparams["batch_size"],
78 | shuffle=shuffle,
79 | drop_last=drop_last,
80 | num_workers=0)
81 |
82 | return data_loader
83 |
84 |
85 |
86 | def data_generator_old(data_path, domain_id, dataset_configs, hparams):
87 | # loading path
88 | train_dataset = torch.load(os.path.join(data_path, "train_" + domain_id + ".pt"))
89 | test_dataset = torch.load(os.path.join(data_path, "test_" + domain_id + ".pt"))
90 |
91 | # Loading datasets
92 | train_dataset = Load_Dataset(train_dataset, dataset_configs)
93 | test_dataset = Load_Dataset(test_dataset, dataset_configs)
94 |
95 | # Dataloaders
96 | batch_size = hparams["batch_size"]
97 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size,
98 | shuffle=True, drop_last=True, num_workers=0)
99 |
100 | test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size,
101 | shuffle=False, drop_last=dataset_configs.drop_last, num_workers=0)
102 | return train_loader, test_loader
103 |
104 |
105 |
106 | def few_shot_data_generator(data_loader, dataset_configs, num_samples=5):
107 | x_data = data_loader.dataset.x_data
108 | y_data = data_loader.dataset.y_data
109 |
110 | NUM_SAMPLES_PER_CLASS = num_samples
111 | NUM_CLASSES = len(torch.unique(y_data))
112 |
113 | counts = [y_data.eq(i).sum().item() for i in range(NUM_CLASSES)]
114 | samples_count_dict = {i: min(counts[i], NUM_SAMPLES_PER_CLASS) for i in range(NUM_CLASSES)}
115 |
116 | samples_ids = {i: torch.where(y_data == i)[0] for i in range(NUM_CLASSES)}
117 | selected_ids = {i: torch.randperm(samples_ids[i].size(0))[:samples_count_dict[i]] for i in range(NUM_CLASSES)}
118 |
119 | selected_x = torch.cat([x_data[samples_ids[i][selected_ids[i]]] for i in range(NUM_CLASSES)], dim=0)
120 | selected_y = torch.cat([y_data[samples_ids[i][selected_ids[i]]] for i in range(NUM_CLASSES)], dim=0)
121 |
122 | few_shot_dataset = {"samples": selected_x, "labels": selected_y}
123 | few_shot_dataset = Load_Dataset(few_shot_dataset, dataset_configs)
124 |
125 | few_shot_loader = torch.utils.data.DataLoader(dataset=few_shot_dataset, batch_size=len(few_shot_dataset),
126 | shuffle=False, drop_last=False, num_workers=0)
127 |
128 | return few_shot_loader
129 |
130 |
--------------------------------------------------------------------------------
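
Note: Load_Dataset expects a dict with "samples" shaped (N, C, L) (2-D inputs get a channel axis,
channel-last inputs are transposed) plus optional "labels". A self-contained sketch with a toy
config object that only defines the fields the class reads (illustrative):

    import torch
    from dataloader.dataloader import Load_Dataset

    class ToyConfigs:
        input_channels = 3
        normalize = True

    data = {"samples": torch.randn(16, 3, 128), "labels": torch.randint(0, 6, (16,))}
    dataset = Load_Dataset(data, ToyConfigs())
    x, y = dataset[0]
    print(x.shape, y)   # torch.Size([3, 128]) and a class index
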
/main.py:
--------------------------------------------------------------------------------
1 | from trainers.train import Trainer
2 |
3 | import argparse
4 | parser = argparse.ArgumentParser()
5 |
6 | if __name__ == "__main__":
7 |
8 | # ======== Experiments Phase ================
9 | parser.add_argument('--phase', default='train', type=str, help='train, test')
10 |
11 | # ======== Experiments Name ================
12 | parser.add_argument('--save_dir', default='experiments_logs', type=str, help='Directory containing all experiments')
13 | parser.add_argument('--exp_name', default='EXP1', type=str, help='experiment name')
14 |
15 | # ========= Select the DA methods ============
16 |     parser.add_argument('--da_method', default='MCD', type=str, help='NO_ADAPT, Deep_Coral, MMDA, DANN, CDAN, DIRT, DSAN, HoMM, CoDATS, AdvSKM, SASA, CoTMix, MCD, TARGET_ONLY')
17 |
18 | # ========= Select the DATASET ==============
19 |     parser.add_argument('--data_path', default=r'../ADATIME_data', type=str, help='Path containing the datasets')
20 |     parser.add_argument('--dataset', default='HAR', type=str, help='Dataset of choice: (WISDM - EEG - HAR - HHAR - FD)')
21 |
22 | # ========= Select the BACKBONE ==============
23 | parser.add_argument('--backbone', default='CNN', type=str, help='Backbone of choice: (CNN - RESNET18 - TCN)')
24 |
25 | # ========= Experiment settings ===============
26 |     parser.add_argument('--num_runs', default=1, type=int, help='Number of consecutive runs with different seeds')
27 | parser.add_argument('--device', default= "cuda", type=str, help='cpu or cuda')
28 |
29 | # arguments
30 | args = parser.parse_args()
31 |
32 |     # create trainer object
33 | trainer = Trainer(args)
34 |
35 | # train and test
36 | if args.phase == 'train':
37 | trainer.fit()
38 | elif args.phase == 'test':
39 | trainer.test()
40 |
41 |
42 |
43 | #TODO:
44 | # 1- Change the naming of the functions ---> ( Done)
45 | # 2- Change the algorithms following DCORAL --> (Done)
46 | # 3- Keep one trainer for both train and test -->(Done)
47 | # 4- Create the new joint loader that considers all possible batches --> Done
48 | # 5- Implement Lower/Upper Bound Approach --> Done
49 | # 6- Add the best hparams --> Done
50 | # 7- Add pretrain based methods (ADDA, MCD, MDD)
51 |
--------------------------------------------------------------------------------
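
Example invocations (assuming the preprocessed datasets live under ../ADATIME_data):

    python main.py --phase train --da_method DANN --dataset HAR --backbone CNN --num_runs 3
    python main.py --phase test  --da_method DANN --dataset HAR --backbone CNN
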
/main_sweep.py:
--------------------------------------------------------------------------------
1 | from trainers.sweep import Trainer
2 | import argparse
3 | parser = argparse.ArgumentParser()
4 |
5 |
6 |
7 |
8 | if __name__ == "__main__":
9 | # ========= Select the DA methods ============
10 | parser.add_argument('--da_method', default='Deep_Coral', type=str,
11 | help='DANN, Deep_Coral, WDGRL, MMDA, VADA, DIRT, CDAN, ADDA, HoMM, CoDATS')
12 |
13 | # ========= Select the DATASET ==============
14 |     parser.add_argument('--data_path', default=r'../ADATIME_data', type=str, help='Path containing the datasets')
15 |     parser.add_argument('--dataset', default='HAR', type=str, help='Dataset of choice: (WISDM - EEG - HAR - HHAR - FD)')
16 |
17 | # ========= Select the BACKBONE ==============
18 | parser.add_argument('--backbone', default='CNN', type=str, help='Backbone of choice: (CNN - RESNET18 - TCN)')
19 |
20 | # ========= Experiment settings ===============
21 |     parser.add_argument('--num_runs', default=1, type=int, help='Number of consecutive runs with different seeds')
22 | parser.add_argument('--device', default="cuda", type=str, help='cpu or cuda')
23 | parser.add_argument('--exp_name', default='sweep_EXP1', type=str, help='experiment name')
24 |
25 | # ======== sweep settings =====================
26 |     parser.add_argument('--num_sweeps', default=1, type=int, help='Number of sweep runs')
27 |
28 |     # We run sweeps using the wandb platform, so the following parameters are wandb-specific.
29 | parser.add_argument('--sweep_project_wandb', default='ADATIME_refactor', type=str, help='Project name in Wandb')
30 | parser.add_argument('--wandb_entity', type=str,
31 | help='Entity name in Wandb (can be left blank if there is a default entity)')
32 |     parser.add_argument('--hp_search_strategy', default="random", type=str,
33 |                         help='Hyper-parameter search strategy (random - grid - bayes); see https://docs.wandb.ai/guides/sweeps/configuration')
34 | parser.add_argument('--metric_to_minimize', default="src_risk", type=str,
35 | help='select one of: (src_risk - trg_risk - few_shot_trg_risk - dev_risk)')
36 |
37 | # ======== Experiments Name ================
38 | parser.add_argument('--save_dir', default='experiments_logs/sweep_logs', type=str,
39 | help='Directory containing all experiments')
40 |
41 | args = parser.parse_args()
42 |
43 | trainer = Trainer(args)
44 |
45 | trainer.train()
46 |
--------------------------------------------------------------------------------
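
Example sweep launch (requires a wandb login; the entity is a placeholder to fill in):

    python main_sweep.py --da_method Deep_Coral --dataset HAR --backbone CNN --num_sweeps 20 --wandb_entity <entity>
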
/misc/adatime.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/misc/adatime.PNG
--------------------------------------------------------------------------------
/misc/results.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/misc/results.PNG
--------------------------------------------------------------------------------
/models/__pycache__/loss.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/models/__pycache__/loss.cpython-310.pyc
--------------------------------------------------------------------------------
/models/__pycache__/loss.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/models/__pycache__/loss.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/loss.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/models/__pycache__/loss.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/models.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/models/__pycache__/models.cpython-310.pyc
--------------------------------------------------------------------------------
/models/__pycache__/models.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/models/__pycache__/models.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/models.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/models/__pycache__/models.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resnet18.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/models/__pycache__/resnet18.cpython-310.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resnet18.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/models/__pycache__/resnet18.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resnet18.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/models/__pycache__/resnet18.cpython-39.pyc
--------------------------------------------------------------------------------
/models/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 |
7 | class ConditionalEntropyLoss(torch.nn.Module):
8 | def __init__(self):
9 | super(ConditionalEntropyLoss, self).__init__()
10 |
11 | def forward(self, x):
12 | b = F.softmax(x, dim=1) * F.log_softmax(x, dim=1)
13 | b = b.sum(dim=1)
14 | return -1.0 * b.mean(dim=0)
15 |
16 |
17 | class VAT(nn.Module):
18 | def __init__(self, model, device):
19 | super(VAT, self).__init__()
20 | self.n_power = 1
21 | self.XI = 1e-6
22 | self.model = model
23 | self.epsilon = 3.5
24 | self.device = device
25 |
26 | def forward(self, X, logit):
27 | vat_loss = self.virtual_adversarial_loss(X, logit)
28 | return vat_loss
29 |
30 | def generate_virtual_adversarial_perturbation(self, x, logit):
31 | d = torch.randn_like(x, device=self.device)
32 |
33 | for _ in range(self.n_power):
34 | d = self.XI * self.get_normalized_vector(d).requires_grad_()
35 | logit_m = self.model(x + d)
36 | dist = self.kl_divergence_with_logit(logit, logit_m)
37 | grad = torch.autograd.grad(dist, [d])[0]
38 | d = grad.detach()
39 |
40 | return self.epsilon * self.get_normalized_vector(d)
41 |
42 | def kl_divergence_with_logit(self, q_logit, p_logit):
43 | q = F.softmax(q_logit, dim=1)
44 | qlogq = torch.mean(torch.sum(q * F.log_softmax(q_logit, dim=1), dim=1))
45 | qlogp = torch.mean(torch.sum(q * F.log_softmax(p_logit, dim=1), dim=1))
46 | return qlogq - qlogp
47 |
48 | def get_normalized_vector(self, d):
49 | return F.normalize(d.view(d.size(0), -1), p=2, dim=1).reshape(d.size())
50 |
51 | def virtual_adversarial_loss(self, x, logit):
52 | r_vadv = self.generate_virtual_adversarial_perturbation(x, logit)
53 | logit_p = logit.detach()
54 | logit_m = self.model(x + r_vadv)
55 | loss = self.kl_divergence_with_logit(logit_p, logit_m)
56 | return loss
57 |
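# Usage note (added, illustrative): VAT perturbs the input adversarially and
# penalizes the KL divergence between predictions on x and on x + r_adv:
#
#   vat = VAT(model=network, device=device)
#   logits = network(x)
#   loss_vat = vat(x, logits)   # weighted by vat_loss_wt in the DIRT hparams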
58 |
59 | class MMD_loss(nn.Module):
60 | def __init__(self, kernel_type='rbf', kernel_mul=2.0, kernel_num=5):
61 | super(MMD_loss, self).__init__()
62 | self.kernel_num = kernel_num
63 | self.kernel_mul = kernel_mul
64 | self.fix_sigma = None
65 | self.kernel_type = kernel_type
66 |
67 | def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
68 | n_samples = int(source.size()[0]) + int(target.size()[0])
69 | total = torch.cat([source, target], dim=0)
70 | total0 = total.unsqueeze(0).expand(
71 | int(total.size(0)), int(total.size(0)), int(total.size(1)))
72 | total1 = total.unsqueeze(1).expand(
73 | int(total.size(0)), int(total.size(0)), int(total.size(1)))
74 | L2_distance = ((total0 - total1) ** 2).sum(2)
75 | if fix_sigma:
76 | bandwidth = fix_sigma
77 | else:
78 | bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples)
79 | bandwidth /= kernel_mul ** (kernel_num // 2)
80 | bandwidth_list = [bandwidth * (kernel_mul ** i)
81 | for i in range(kernel_num)]
82 | kernel_val = [torch.exp(-L2_distance / bandwidth_temp)
83 | for bandwidth_temp in bandwidth_list]
84 | return sum(kernel_val)
85 |
86 | def linear_mmd2(self, f_of_X, f_of_Y):
87 | loss = 0.0
88 | delta = f_of_X.float().mean(0) - f_of_Y.float().mean(0)
89 | loss = delta.dot(delta.T)
90 | return loss
91 |
92 | def forward(self, source, target):
93 | if self.kernel_type == 'linear':
94 | return self.linear_mmd2(source, target)
95 | elif self.kernel_type == 'rbf':
96 | batch_size = int(source.size()[0])
97 | kernels = self.guassian_kernel(
98 | source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma)
99 |             # keep these reductions on the autograd graph: computing them under
100 |             # torch.no_grad() detaches the loss, and the MMD term then yields no gradients
101 |             XX = torch.mean(kernels[:batch_size, :batch_size])
102 |             YY = torch.mean(kernels[batch_size:, batch_size:])
103 |             XY = torch.mean(kernels[:batch_size, batch_size:])
104 |             YX = torch.mean(kernels[batch_size:, :batch_size])
105 |             loss = XX + YY - XY - YX
106 |             torch.cuda.empty_cache()
107 |             return loss
107 |
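# Usage note (added, illustrative): MMD_loss compares two (batch, feature_dim)
# batches and returns a scalar alignment penalty:
#
#   mmd = MMD_loss(kernel_type='rbf')
#   loss = mmd(torch.randn(32, 128), torch.randn(32, 128))
#
# The rbf path uses a multi-kernel bandwidth schedule around the mean pairwise
# squared distance; the 'linear' path reduces to the squared distance between
# the two batch means.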
108 |
109 | class CORAL(nn.Module):
110 | def __init__(self):
111 | super(CORAL, self).__init__()
112 |
113 | def forward(self, source, target):
114 | d = source.size(1)
115 |
116 | # source covariance
117 | xm = torch.mean(source, 0, keepdim=True) - source
118 | xc = xm.t() @ xm
119 |
120 | # target covariance
121 | xmt = torch.mean(target, 0, keepdim=True) - target
122 | xct = xmt.t() @ xmt
123 |
124 | # frobenius norm between source and target
125 | loss = torch.mean(torch.mul((xc - xct), (xc - xct)))
126 | loss = loss / (4 * d * d)
127 | return loss
128 |
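# Note (added): CORAL follows Deep CORAL. With mean-centred (unnormalized)
# covariances C_s = X_s^T X_s and C_t = X_t^T X_t, it returns roughly
# ||C_s - C_t||_F^2 / (4 d^2), except that torch.mean over the squared
# difference (rather than a Frobenius sum) adds a further 1/d^2 factor;
# the coral_wt hyper-parameter absorbs the overall scale.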
129 |
130 | ### FOR DCAN #######################
131 | def EntropyLoss(input_):
132 | mask = input_.ge(0.0000001)
133 | mask_out = torch.masked_select(input_, mask)
134 | entropy = - (torch.sum(mask_out * torch.log(mask_out)))
135 | return entropy / float(input_.size(0))
136 |
137 |
138 | def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
139 | n_samples = int(source.size()[0]) + int(target.size()[0])
140 | total = torch.cat([source, target], dim=0)
141 | total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
142 | total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
143 | L2_distance = ((total0 - total1) ** 2).sum(2)
144 | if fix_sigma:
145 | bandwidth = fix_sigma
146 | else:
147 | bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples)
148 | bandwidth /= kernel_mul ** (kernel_num // 2)
149 | bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
150 | kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
151 |     return sum(kernel_val)
152 |
153 |
154 | def MMD(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
155 | batch_size = int(source.size()[0])
156 | kernels = guassian_kernel(source, target,
157 | kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
158 | loss = 0
159 | for i in range(batch_size):
160 | s1, s2 = i, (i + 1) % batch_size
161 | t1, t2 = s1 + batch_size, s2 + batch_size
162 | loss += kernels[s1, s2] + kernels[t1, t2]
163 | loss -= kernels[s1, t2] + kernels[s2, t1]
164 | return loss / float(batch_size)
165 |
166 |
167 | def MMD_reg(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
168 | batch_size_source = int(source.size()[0])
169 | batch_size_target = int(target.size()[0])
170 | kernels = guassian_kernel(source, target,
171 | kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
172 | loss = 0
173 | for i in range(batch_size_source):
174 | s1, s2 = i, (i + 1) % batch_size_source
175 | t1, t2 = s1 + batch_size_target, s2 + batch_size_target
176 | loss += kernels[s1, s2] + kernels[t1, t2]
177 | loss -= kernels[s1, t2] + kernels[s2, t1]
178 | return loss / float(batch_size_source + batch_size_target)
179 |
180 |
181 | ### FOR HoMM #######################
182 | class HoMM_loss(nn.Module):
183 | def __init__(self):
184 | super(HoMM_loss, self).__init__()
185 |
186 | def forward(self, xs, xt):
187 | xs = xs - torch.mean(xs, axis=0)
188 | xt = xt - torch.mean(xt, axis=0)
189 | xs = torch.unsqueeze(xs, axis=-1)
190 | xs = torch.unsqueeze(xs, axis=-1)
191 | xt = torch.unsqueeze(xt, axis=-1)
192 | xt = torch.unsqueeze(xt, axis=-1)
193 | xs_1 = xs.permute(0, 2, 1, 3)
194 | xs_2 = xs.permute(0, 2, 3, 1)
195 | xt_1 = xt.permute(0, 2, 1, 3)
196 | xt_2 = xt.permute(0, 2, 3, 1)
197 | HR_Xs = xs * xs_1 * xs_2 # dim: b*L*L*L
198 | HR_Xs = torch.mean(HR_Xs, axis=0) # dim: L*L*L
199 | HR_Xt = xt * xt_1 * xt_2
200 | HR_Xt = torch.mean(HR_Xt, axis=0)
201 | return torch.mean((HR_Xs - HR_Xt) ** 2)
202 |
203 |
204 | ### FOR DSAN #######################
205 | class LMMD_loss(nn.Module):
206 | def __init__(self, device, class_num=3, kernel_type='rbf', kernel_mul=2.0, kernel_num=5, fix_sigma=None):
207 | super(LMMD_loss, self).__init__()
208 | self.class_num = class_num
209 | self.kernel_num = kernel_num
210 | self.kernel_mul = kernel_mul
211 | self.fix_sigma = fix_sigma
212 | self.kernel_type = kernel_type
213 | self.device = device
214 |
215 | def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
216 | n_samples = int(source.size()[0]) + int(target.size()[0])
217 | total = torch.cat([source, target], dim=0)
218 | total0 = total.unsqueeze(0).expand(
219 | int(total.size(0)), int(total.size(0)), int(total.size(1)))
220 | total1 = total.unsqueeze(1).expand(
221 | int(total.size(0)), int(total.size(0)), int(total.size(1)))
222 | L2_distance = ((total0 - total1) ** 2).sum(2)
223 | if fix_sigma:
224 | bandwidth = fix_sigma
225 | else:
226 | bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples)
227 | bandwidth /= kernel_mul ** (kernel_num // 2)
228 | bandwidth_list = [bandwidth * (kernel_mul ** i)
229 | for i in range(kernel_num)]
230 | kernel_val = [torch.exp(-L2_distance / bandwidth_temp)
231 | for bandwidth_temp in bandwidth_list]
232 | return sum(kernel_val)
233 |
234 | def get_loss(self, source, target, s_label, t_label):
235 | batch_size = source.size()[0]
236 | weight_ss, weight_tt, weight_st = self.cal_weight(
237 | s_label, t_label, batch_size=batch_size, class_num=self.class_num)
238 | weight_ss = torch.from_numpy(weight_ss).to(self.device)
239 | weight_tt = torch.from_numpy(weight_tt).to(self.device)
240 | weight_st = torch.from_numpy(weight_st).to(self.device)
241 |
242 | kernels = self.guassian_kernel(source, target,
243 | kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma)
244 | loss = torch.Tensor([0]).to(self.device)
245 | if torch.sum(torch.isnan(sum(kernels))):
246 | return loss
247 | SS = kernels[:batch_size, :batch_size]
248 | TT = kernels[batch_size:, batch_size:]
249 | ST = kernels[:batch_size, batch_size:]
250 |
251 | loss += torch.sum(weight_ss * SS + weight_tt * TT - 2 * weight_st * ST)
252 | return loss
253 |
254 | def convert_to_onehot(self, sca_label, class_num=31):
255 | return np.eye(class_num)[sca_label]
256 |
257 | def cal_weight(self, s_label, t_label, batch_size=32, class_num=4):
258 | batch_size = s_label.size()[0]
259 | s_sca_label = s_label.cpu().data.numpy()
260 | s_vec_label = self.convert_to_onehot(s_sca_label, class_num=self.class_num)
261 | s_sum = np.sum(s_vec_label, axis=0).reshape(1, class_num)
262 | s_sum[s_sum == 0] = 100
263 | s_vec_label = s_vec_label / s_sum
264 |
265 | t_sca_label = t_label.cpu().data.max(1)[1].numpy()
266 | t_vec_label = t_label.cpu().data.numpy()
267 | t_sum = np.sum(t_vec_label, axis=0).reshape(1, class_num)
268 | t_sum[t_sum == 0] = 100
269 | t_vec_label = t_vec_label / t_sum
270 |
271 | index = list(set(s_sca_label) & set(t_sca_label))
272 | mask_arr = np.zeros((batch_size, class_num))
273 | mask_arr[:, index] = 1
274 | t_vec_label = t_vec_label * mask_arr
275 | s_vec_label = s_vec_label * mask_arr
276 |
277 | weight_ss = np.matmul(s_vec_label, s_vec_label.T)
278 | weight_tt = np.matmul(t_vec_label, t_vec_label.T)
279 | weight_st = np.matmul(s_vec_label, t_vec_label.T)
280 |
281 | length = len(index)
282 | if length != 0:
283 | weight_ss = weight_ss / length
284 | weight_tt = weight_tt / length
285 | weight_st = weight_st / length
286 | else:
287 | weight_ss = np.array([0])
288 | weight_tt = np.array([0])
289 | weight_st = np.array([0])
290 | return weight_ss.astype('float32'), weight_tt.astype('float32'), weight_st.astype('float32')
291 |
292 |
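# Usage note (added, illustrative): LMMD_loss (the DSAN alignment term)
# reweights the kernel matrix with class-conditional weights from cal_weight().
# Source labels are hard labels; target "labels" are soft predictions:
#
#   lmmd = LMMD_loss(device, class_num=configs.num_classes)
#   loss = lmmd.get_loss(src_feat, trg_feat, src_y, F.softmax(trg_logits, dim=1))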
293 |
294 | class NTXentLoss(torch.nn.Module):
295 |
296 | def __init__(self, device, batch_size, temperature, use_cosine_similarity):
297 | super(NTXentLoss, self).__init__()
298 | self.batch_size = batch_size
299 | self.temperature = temperature
300 | self.device = device
301 | self.softmax = torch.nn.Softmax(dim=-1)
302 | self.mask_samples_from_same_repr = self._get_correlated_mask().type(torch.bool)
303 | self.similarity_function = self._get_similarity_function(use_cosine_similarity)
304 | self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")
305 |
306 | def _get_similarity_function(self, use_cosine_similarity):
307 | if use_cosine_similarity:
308 | self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
309 | return self._cosine_simililarity
310 | else:
311 | return self._dot_simililarity
312 |
313 | def _get_correlated_mask(self):
314 | diag = np.eye(2 * self.batch_size)
315 | l1 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=-self.batch_size)
316 | l2 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=self.batch_size)
317 | mask = torch.from_numpy((diag + l1 + l2))
318 | mask = (1 - mask).type(torch.bool)
319 | return mask.to(self.device)
320 |
321 | @staticmethod
322 | def _dot_simililarity(x, y):
323 | v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
324 | # x shape: (N, 1, C)
325 | # y shape: (1, C, 2N)
326 | # v shape: (N, 2N)
327 | return v
328 |
329 | def _cosine_simililarity(self, x, y):
330 | # x shape: (N, 1, C)
331 | # y shape: (1, 2N, C)
332 | # v shape: (N, 2N)
333 | v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
334 | return v
335 |
336 | def forward(self, zis, zjs):
337 | representations = torch.cat([zjs, zis], dim=0)
338 |
339 | similarity_matrix = self.similarity_function(representations, representations)
340 |
341 | # filter out the scores from the positive samples
342 | l_pos = torch.diag(similarity_matrix, self.batch_size)
343 | r_pos = torch.diag(similarity_matrix, -self.batch_size)
344 | positives = torch.cat([l_pos, r_pos]).view(2 * self.batch_size, 1)
345 |
346 | negatives = similarity_matrix[self.mask_samples_from_same_repr].view(2 * self.batch_size, -1)
347 |
348 | logits = torch.cat((positives, negatives), dim=1)
349 | logits /= self.temperature
350 |
351 | labels = torch.zeros(2 * self.batch_size).to(self.device).long()
352 | loss = self.criterion(logits, labels)
353 |
354 | return loss / (2 * self.batch_size)
355 |
356 |
357 | class SupConLoss(torch.nn.Module):
358 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
359 | It also supports the unsupervised contrastive loss in SimCLR"""
360 |
361 | def __init__(self, device, temperature=0.2, contrast_mode='all'):
362 | super(SupConLoss, self).__init__()
363 | self.temperature = temperature
364 | self.contrast_mode = contrast_mode
365 | self.device = device
366 |
367 | def forward(self, features, labels=None, mask=None):
368 | """Compute loss for model. If both `labels` and `mask` are None,
369 | it degenerates to SimCLR unsupervised loss:
370 | https://arxiv.org/pdf/2002.05709.pdf
371 | Args:
372 | features: hidden vector of shape [bsz, n_views, ...].
373 | labels: ground truth of shape [bsz].
374 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
375 | has the same class as sample i. Can be asymmetric.
376 | Returns:
377 | A loss scalar.
378 | """
379 | device = self.device # 'cuda' #(torch.device('cuda')
380 | # if features.is_cuda
381 | # else torch.device('cpu'))
382 |
383 | if len(features.shape) < 3:
384 | raise ValueError('`features` needs to be [bsz, n_views, ...],'
385 | 'at least 3 dimensions are required')
386 | if len(features.shape) > 3:
387 | features = features.view(features.shape[0], features.shape[1], -1)
388 |
389 | batch_size = features.shape[0]
390 | if labels is not None and mask is not None:
391 | raise ValueError('Cannot define both `labels` and `mask`')
392 | elif labels is None and mask is None:
393 | mask = torch.eye(batch_size, dtype=torch.float32).to(device)
394 | elif labels is not None:
395 | labels = labels.contiguous().view(-1, 1)
396 | if labels.shape[0] != batch_size:
397 | raise ValueError('Num of labels does not match num of features')
398 | mask = torch.eq(labels, labels.T).float().to(device)
399 | else:
400 | mask = mask.float().to(device)
401 |
402 | contrast_count = features.shape[1]
403 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
404 | if self.contrast_mode == 'one':
405 | anchor_feature = features[:, 0]
406 | anchor_count = 1
407 | elif self.contrast_mode == 'all':
408 | anchor_feature = contrast_feature
409 | anchor_count = contrast_count
410 | else:
411 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode))
412 |
413 | # compute logits
414 | anchor_dot_contrast = torch.div(
415 | torch.matmul(anchor_feature, contrast_feature.T),
416 | self.temperature)
417 | # for numerical stability
418 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
419 | logits = anchor_dot_contrast - logits_max.detach()
420 |
421 | # tile mask
422 | mask = mask.repeat(anchor_count, contrast_count)
423 | # mask-out self-contrast cases
424 | logits_mask = torch.scatter(
425 | torch.ones_like(mask),
426 | 1,
427 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
428 | 0
429 | )
430 | mask = mask * logits_mask
431 |
432 | # compute log_prob
433 | exp_logits = torch.exp(logits) * logits_mask
434 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
435 |
436 | # compute mean of log-likelihood over positive
437 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
438 |
439 | # loss
440 | # loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
441 | loss = - self.temperature * mean_log_prob_pos
442 | loss = loss.view(anchor_count, batch_size).mean()
443 |
444 | return loss
445 |
--------------------------------------------------------------------------------
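
The two contrastive losses above consume raw embedding batches. A minimal usage sketch, assuming the module path from this repo's layout; the batch size, feature dimension, and temperature are illustrative:

    import torch
    import torch.nn.functional as F
    from models.loss import NTXentLoss, SupConLoss

    device = torch.device("cpu")
    batch_size, feat_dim, num_classes = 8, 128, 4

    # NT-Xent takes two aligned augmented views (zis, zjs) of the same batch.
    zis = F.normalize(torch.randn(batch_size, feat_dim), dim=1)
    zjs = F.normalize(torch.randn(batch_size, feat_dim), dim=1)
    nt_xent = NTXentLoss(device, batch_size, temperature=0.2, use_cosine_similarity=True)
    print(nt_xent(zis, zjs))

    # SupConLoss takes [bsz, n_views, feat_dim] features plus optional labels;
    # without labels (and without a mask) it reduces to the SimCLR loss.
    features = torch.stack([zis, zjs], dim=1)
    labels = torch.randint(0, num_classes, (batch_size,))
    sup_con = SupConLoss(device, temperature=0.2)
    print(sup_con(features, labels=labels))
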
/models/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import math
4 | from torch.autograd import Function
5 | from torch.nn.utils import weight_norm
6 | import torch.nn.functional as F
7 | from .resnet18 import resnet18
8 |
9 |
10 | # from utils import weights_init
11 |
12 | def get_backbone_class(backbone_name):
13 | """Return the algorithm class with the given name."""
14 | if backbone_name not in globals():
15 | raise NotImplementedError("Algorithm not found: {}".format(backbone_name))
16 | return globals()[backbone_name]
17 |
18 |
19 | ##################################################
20 | ########## BACKBONE NETWORKS ###################
21 | ##################################################
22 |
23 | ########## CNN #############################
24 | class CNN(nn.Module):
25 | def __init__(self, configs):
26 | super(CNN, self).__init__()
27 |
28 | self.conv_block1 = nn.Sequential(
29 | nn.Conv1d(configs.input_channels, configs.mid_channels, kernel_size=configs.kernel_size,
30 | stride=configs.stride, bias=False, padding=(configs.kernel_size // 2)),
31 | nn.BatchNorm1d(configs.mid_channels),
32 | nn.ReLU(),
33 | nn.MaxPool1d(kernel_size=2, stride=2, padding=1),
34 | nn.Dropout(configs.dropout)
35 | )
36 |
37 | self.conv_block2 = nn.Sequential(
38 | nn.Conv1d(configs.mid_channels, configs.mid_channels * 2, kernel_size=8, stride=1, bias=False, padding=4),
39 | nn.BatchNorm1d(configs.mid_channels * 2),
40 | nn.ReLU(),
41 | nn.MaxPool1d(kernel_size=2, stride=2, padding=1)
42 | )
43 |
44 | self.conv_block3 = nn.Sequential(
45 | nn.Conv1d(configs.mid_channels * 2, configs.final_out_channels, kernel_size=8, stride=1, bias=False,
46 | padding=4),
47 | nn.BatchNorm1d(configs.final_out_channels),
48 | nn.ReLU(),
49 | nn.MaxPool1d(kernel_size=2, stride=2, padding=1),
50 | )
51 |
52 | self.adaptive_pool = nn.AdaptiveAvgPool1d(configs.features_len)
53 |
54 | def forward(self, x_in):
55 | x = self.conv_block1(x_in)
56 | x = self.conv_block2(x)
57 | x = self.conv_block3(x)
58 | x = self.adaptive_pool(x)
59 |
60 | x_flat = x.reshape(x.shape[0], -1)
61 | return x_flat
62 |
63 |
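A quick shape check for the CNN backbone above; the config values here are invented for illustration (the real ones live in configs/data_model_configs.py):

    import torch
    from types import SimpleNamespace
    from models.models import CNN  # import path per this repo's layout

    cfg = SimpleNamespace(input_channels=9, mid_channels=64, final_out_channels=128,
                          kernel_size=5, stride=1, dropout=0.5, features_len=18)
    cnn = CNN(cfg)
    x = torch.randn(2, 9, 128)  # (batch, channels, sequence length)
    print(cnn(x).shape)         # -> torch.Size([2, 2304]), i.e. final_out_channels * features_len
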
64 |
65 | class classifier(nn.Module):
66 | def __init__(self, configs):
67 | super(classifier, self).__init__()
68 | self.logits = nn.Linear(configs.features_len * configs.final_out_channels, configs.num_classes)
69 | self.configs = configs
70 |
71 | def forward(self, x):
72 |
73 | predictions = self.logits(x)
74 |
75 | return predictions
76 |
77 |
78 |
79 | ########## TCN #############################
80 | torch.backends.cudnn.benchmark = True  # may be needed to speed up TCN
81 |
82 |
83 | class Chomp1d(nn.Module):
84 | def __init__(self, chomp_size):
85 | super(Chomp1d, self).__init__()
86 | self.chomp_size = chomp_size
87 |
88 | def forward(self, x):
89 | return x[:, :, :-self.chomp_size].contiguous()
90 |
91 |
92 | class TCN(nn.Module):
93 | def __init__(self, configs):
94 | super(TCN, self).__init__()
95 |
96 | in_channels0 = configs.input_channels
97 | out_channels0 = configs.tcn_layers[1]
98 | kernel_size = configs.tcn_kernel_size
99 | stride = 1
100 | dilation0 = 1
101 | padding0 = (kernel_size - 1) * dilation0
102 | # NOTE: net0 and net1 below are defined but never used; forward() relies on conv_block1/conv_block2.
103 | self.net0 = nn.Sequential(
104 | weight_norm(nn.Conv1d(in_channels0, out_channels0, kernel_size, stride=stride, padding=padding0,
105 | dilation=dilation0)),
106 | nn.ReLU(),
107 | weight_norm(nn.Conv1d(out_channels0, out_channels0, kernel_size, stride=stride, padding=padding0,
108 | dilation=dilation0)),
109 | nn.ReLU(),
110 | )
111 |
112 | self.downsample0 = nn.Conv1d(in_channels0, out_channels0, 1) if in_channels0 != out_channels0 else None
113 | self.relu = nn.ReLU()
114 |
115 | in_channels1 = configs.tcn_layers[0]
116 | out_channels1 = configs.tcn_layers[1]
117 | dilation1 = 2
118 | padding1 = (kernel_size - 1) * dilation1
119 | self.net1 = nn.Sequential(
120 | nn.Conv1d(in_channels0, out_channels1, kernel_size, stride=stride, padding=padding1, dilation=dilation1),
121 | nn.ReLU(),
122 | nn.Conv1d(out_channels1, out_channels1, kernel_size, stride=stride, padding=padding1, dilation=dilation1),
123 | nn.ReLU(),
124 | )
125 | self.downsample1 = nn.Conv1d(out_channels1, out_channels1, 1) if in_channels1 != out_channels1 else None
126 |
127 | self.conv_block1 = nn.Sequential(
128 | nn.Conv1d(in_channels0, out_channels0, kernel_size=kernel_size, stride=stride, bias=False, padding=padding0,
129 | dilation=dilation0),
130 | Chomp1d(padding0),
131 | nn.BatchNorm1d(out_channels0),
132 | nn.ReLU(),
133 |
134 | nn.Conv1d(out_channels0, out_channels0, kernel_size=kernel_size, stride=stride, bias=False,
135 | padding=padding0, dilation=dilation0),
136 | Chomp1d(padding0),
137 | nn.BatchNorm1d(out_channels0),
138 | nn.ReLU(),
139 | )
140 |
141 | self.conv_block2 = nn.Sequential(
142 | nn.Conv1d(out_channels0, out_channels1, kernel_size=kernel_size, stride=stride, bias=False,
143 | padding=padding1, dilation=dilation1),
144 | Chomp1d(padding1),
145 | nn.BatchNorm1d(out_channels1),
146 | nn.ReLU(),
147 |
148 | nn.Conv1d(out_channels1, out_channels1, kernel_size=kernel_size, stride=stride, bias=False,
149 | padding=padding1, dilation=dilation1),
150 | Chomp1d(padding1),
151 | nn.BatchNorm1d(out_channels1),
152 | nn.ReLU(),
153 | )
154 |
155 | def forward(self, inputs):
156 | """Inputs have to have dimension (N, C_in, L_in)"""
157 | x0 = self.conv_block1(inputs)
158 | res0 = inputs if self.downsample0 is None else self.downsample0(inputs)
159 | out_0 = self.relu(x0 + res0)
160 |
161 | x1 = self.conv_block2(out_0)
162 | res1 = out_0 if self.downsample1 is None else self.downsample1(out_0)
163 | out_1 = self.relu(x1 + res1)
164 |
165 | out = out_1[:, :, -1]
166 | return out
167 |
168 |
169 | ######## RESNET ##############################################
170 |
171 | class RESNET18(nn.Module):
172 | def __init__(self, configs):
173 | super(RESNET18, self).__init__()
174 | self.resnet = resnet18(configs)
175 | def forward(self, x_in):
176 | x = self.resnet(x_in)
177 | x_flat = x.reshape(x.shape[0], -1)
178 | return x_flat
179 |
180 | class BasicBlock(nn.Module):
181 | expansion = 1
182 |
183 | def __init__(self, inplanes, planes, stride=1, downsample=None):
184 | super(BasicBlock, self).__init__()
185 | self.conv1 = nn.Conv1d(inplanes, planes, kernel_size=1, stride=stride,
186 | bias=False)
187 | self.bn1 = nn.BatchNorm1d(planes)
188 |
189 | self.downsample = downsample
190 | self.stride = stride
191 |
192 | def forward(self, x):
193 | residual = x
194 |
195 | out = self.conv1(x)
196 | out = self.bn1(out)
197 | out = F.relu(out)
198 |
199 | if self.downsample is not None:
200 | residual = self.downsample(x)
201 |
202 | out += residual
203 | out = F.relu(out)
204 |
205 | return out
206 |
207 |
208 | ##################################################
209 | ########## OTHER NETWORKS ######################
210 | ##################################################
211 |
212 | class codats_classifier(nn.Module):
213 | def __init__(self, configs):
214 | super(codats_classifier, self).__init__()
215 | model_output_dim = configs.features_len
216 | self.hidden_dim = configs.hidden_dim
217 | self.logits = nn.Sequential(
218 | nn.Linear(model_output_dim * configs.final_out_channels, self.hidden_dim),
219 | nn.ReLU(),
220 | nn.Linear(self.hidden_dim, self.hidden_dim),
221 | nn.ReLU(),
222 | nn.Linear(self.hidden_dim, configs.num_classes))
223 |
224 | def forward(self, x_in):
225 | predictions = self.logits(x_in)
226 | return predictions
227 |
228 |
229 | class Discriminator(nn.Module):
230 | """Discriminator model for source domain."""
231 |
232 | def __init__(self, configs):
233 | """Init discriminator."""
234 | super(Discriminator, self).__init__()
235 |
236 | self.layer = nn.Sequential(
237 | nn.Linear(configs.features_len * configs.final_out_channels, configs.disc_hid_dim),
238 | nn.ReLU(),
239 | nn.Linear(configs.disc_hid_dim, configs.disc_hid_dim),
240 | nn.ReLU(),
241 | nn.Linear(configs.disc_hid_dim, 2)
242 | # nn.LogSoftmax(dim=1)
243 | )
244 |
245 | def forward(self, input):
246 | """Forward the discriminator."""
247 | out = self.layer(input)
248 | return out
249 |
250 |
251 | #### Codes required by DANN ##############
252 | class ReverseLayerF(Function):
253 | @staticmethod
254 | def forward(ctx, x, alpha):
255 | ctx.alpha = alpha
256 | return x.view_as(x)
257 |
258 | @staticmethod
259 | def backward(ctx, grad_output):
260 | output = grad_output.neg() * ctx.alpha
261 | return output, None
262 |
263 |
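ReverseLayerF above is the gradient reversal trick used by DANN: an identity in the forward pass whose backward pass multiplies incoming gradients by -alpha. A short sanity check with illustrative shapes:

    import torch
    from models.models import ReverseLayerF  # import path per this repo's layout

    feat = torch.randn(4, 16, requires_grad=True)
    out = ReverseLayerF.apply(feat, 1.0)  # custom autograd Functions are invoked via .apply()
    out.sum().backward()
    assert torch.allclose(feat.grad, -torch.ones_like(feat))  # gradients flipped by -alpha
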
264 | #### Codes required by CDAN ##############
265 | class RandomLayer(nn.Module):
266 | def __init__(self, input_dim_list=[], output_dim=1024):
267 | super(RandomLayer, self).__init__()
268 | self.input_num = len(input_dim_list)
269 | self.output_dim = output_dim
270 | self.random_matrix = [torch.randn(input_dim_list[i], output_dim) for i in range(self.input_num)]
271 |
272 | def forward(self, input_list):
273 | return_list = [torch.mm(input_list[i], self.random_matrix[i]) for i in range(self.input_num)]
274 | return_tensor = return_list[0] / math.pow(float(self.output_dim), 1.0 / len(return_list))
275 | for single in return_list[1:]:
276 | return_tensor = torch.mul(return_tensor, single)
277 | return return_tensor
278 |
279 | def cuda(self):
280 | super(RandomLayer, self).cuda()
281 | self.random_matrix = [val.cuda() for val in self.random_matrix]
282 |
283 |
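RandomLayer implements CDAN's randomized multilinear map: each input (features and classifier probabilities) is projected with a fixed random matrix, and the projections are combined by elementwise multiplication. A shape sketch with invented dimensions:

    import torch
    from models.models import RandomLayer  # import path per this repo's layout

    rand_layer = RandomLayer(input_dim_list=[128, 5], output_dim=1024)
    feats = torch.randn(4, 128)                      # backbone features
    probs = torch.softmax(torch.randn(4, 5), dim=1)  # classifier outputs
    print(rand_layer([feats, probs]).shape)          # -> torch.Size([4, 1024])
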
284 | class Discriminator_CDAN(nn.Module):
285 | """Discriminator model for CDAN ."""
286 |
287 | def __init__(self, configs):
288 | """Init discriminator."""
289 | super(Discriminator_CDAN, self).__init__()
290 |
291 | self.restored = False
292 |
293 | self.layer = nn.Sequential(
294 | nn.Linear(configs.features_len * configs.final_out_channels * configs.num_classes, configs.disc_hid_dim),
295 | nn.ReLU(),
296 | nn.Linear(configs.disc_hid_dim, configs.disc_hid_dim),
297 | nn.ReLU(),
298 | nn.Linear(configs.disc_hid_dim, 2)
299 | # nn.LogSoftmax(dim=1)
300 | )
301 |
302 | def forward(self, input):
303 | """Forward the discriminator."""
304 | out = self.layer(input)
305 | return out
306 |
307 |
308 | #### Codes required by AdvSKM ##############
309 | class Cosine_act(nn.Module):
310 | def __init__(self):
311 | super(Cosine_act, self).__init__()
312 |
313 | def forward(self, input):
314 | return torch.cos(input)
315 |
316 |
317 | cos_act = Cosine_act()
318 |
319 | class AdvSKM_Disc(nn.Module):
320 | """Discriminator model for source domain."""
321 |
322 | def __init__(self, configs):
323 | """Init discriminator."""
324 | super(AdvSKM_Disc, self).__init__()
325 |
326 | self.input_dim = configs.features_len * configs.final_out_channels
327 | self.hid_dim = configs.DSKN_disc_hid
328 | self.branch_1 = nn.Sequential(
329 | nn.Linear(self.input_dim, self.hid_dim),
330 | nn.Linear(self.hid_dim, self.hid_dim),
331 | nn.BatchNorm1d(self.hid_dim),
332 | cos_act,
333 | nn.Linear(self.hid_dim, self.hid_dim // 2),
334 | nn.Linear(self.hid_dim // 2, self.hid_dim // 2),
335 | nn.BatchNorm1d(self.hid_dim // 2),
336 | cos_act
337 | )
338 | self.branch_2 = nn.Sequential(
339 | nn.Linear(configs.features_len * configs.final_out_channels, configs.disc_hid_dim),
340 | nn.Linear(configs.disc_hid_dim, configs.disc_hid_dim),
341 | nn.BatchNorm1d(configs.disc_hid_dim),
342 | nn.ReLU(),
343 | nn.Linear(configs.disc_hid_dim, configs.disc_hid_dim // 2),
344 | nn.Linear(configs.disc_hid_dim // 2, configs.disc_hid_dim // 2),
345 | nn.BatchNorm1d(configs.disc_hid_dim // 2),
346 | nn.ReLU())
347 |
348 | def forward(self, input):
349 | """Forward the discriminator."""
350 | out_cos = self.branch_1(input)
351 | out_rel = self.branch_2(input)
352 | total_out = torch.cat((out_cos, out_rel), dim=1)
353 | return total_out
354 |
355 | # SASA model
356 | class CNN_ATTN(nn.Module):
357 | def __init__(self, configs):
358 | super(CNN_ATTN, self).__init__()
359 |
360 | self.conv_block1 = nn.Sequential(
361 | nn.Conv1d(configs.input_channels, configs.mid_channels, kernel_size=configs.kernel_size,
362 | stride=configs.stride, bias=False, padding=(configs.kernel_size // 2)),
363 | nn.BatchNorm1d(configs.mid_channels),
364 | nn.ReLU(),
365 | nn.MaxPool1d(kernel_size=2, stride=2, padding=1),
366 | nn.Dropout(configs.dropout)
367 | )
368 |
369 | self.conv_block2 = nn.Sequential(
370 | nn.Conv1d(configs.mid_channels, configs.mid_channels * 2, kernel_size=8, stride=1, bias=False, padding=4),
371 | nn.BatchNorm1d(configs.mid_channels * 2),
372 | nn.ReLU(),
373 | nn.MaxPool1d(kernel_size=2, stride=2, padding=1)
374 | )
375 |
376 | self.conv_block3 = nn.Sequential(
377 | nn.Conv1d(configs.mid_channels * 2, configs.final_out_channels, kernel_size=8, stride=1, bias=False,
378 | padding=4),
379 | nn.BatchNorm1d(configs.final_out_channels),
380 | nn.ReLU(),
381 | nn.MaxPool1d(kernel_size=2, stride=2, padding=1),
382 | )
383 |
384 | self.adaptive_pool = nn.AdaptiveAvgPool1d(configs.features_len)
385 | self.attn_network = attn_network(configs)
386 | self.sparse_max = Sparsemax(dim=-1)
387 | self.feat_len = configs.features_len
388 |
389 | def forward(self, x_in):
390 | x = self.conv_block1(x_in)
391 | x = self.conv_block2(x)
392 | x = self.conv_block3(x)
393 | x = self.adaptive_pool(x)
394 | x_flat = x.reshape(x.shape[0], -1)
395 | attentive_feat = self.calculate_attentive_feat(x_flat)
396 | return attentive_feat
397 |
398 | def self_attention(self, Q, K, scale=True, sparse=True, k=3):
399 |
400 | attention_weight = torch.bmm(Q.view(Q.shape[0], self.feat_len, -1), K.view(K.shape[0], -1, self.feat_len))
401 |
402 | attention_weight = torch.mean(attention_weight, dim=2, keepdim=True)
403 |
404 | if scale:
405 | d_k = torch.tensor(K.shape[-1]).float()
406 | attention_weight = attention_weight / torch.sqrt(d_k)
407 | if sparse:
408 | attention_weight_sparse = self.sparse_max(torch.reshape(attention_weight, [-1, self.feat_len]))
409 | attention_weight = torch.reshape(attention_weight_sparse, [-1, attention_weight.shape[1],
410 | attention_weight.shape[2]])
411 | else:
412 | attention_weight = F.softmax(attention_weight, dim=1)  # fixed: self.softmax was never defined on CNN_ATTN
413 |
414 | return attention_weight
415 |
416 | def attention_fn(self, Q, K, scaled=False, sparse=True, k=1):
417 |
418 | attention_weight = torch.matmul(F.normalize(Q, p=2, dim=-1),
419 | F.normalize(K, p=2, dim=-1).view(K.shape[0], K.shape[1], -1, self.feat_len))
420 |
421 | if scaled:
422 | d_k = torch.tensor(K.shape[-1]).float()
423 | attention_weight = attention_weight / torch.sqrt(d_k)
424 | attention_weight = k * torch.log(torch.tensor(self.feat_len, dtype=torch.float32)) * attention_weight
425 |
426 | if sparse:
427 | attention_weight_sparse = self.sparse_max(torch.reshape(attention_weight, [-1, self.feat_len]))
428 |
429 | attention_weight = torch.reshape(attention_weight_sparse, attention_weight.shape)
430 | else:
431 | attention_weight = F.softmax(attention_weight, dim=-1)  # fixed: self.softmax was never defined on CNN_ATTN
432 |
433 | return attention_weight
434 |
435 | def calculate_attentive_feat(self, candidate_representation_xi):
436 | Q_xi, K_xi, V_xi = self.attn_network(candidate_representation_xi)
437 | intra_attention_weight_xi = self.self_attention(Q=Q_xi, K=K_xi, sparse=True)
438 | Z_i = torch.bmm(intra_attention_weight_xi.view(intra_attention_weight_xi.shape[0], 1, -1),
439 | V_xi.view(V_xi.shape[0], self.feat_len, -1))
440 | final_feature = F.normalize(Z_i, dim=-1).view(Z_i.shape[0],-1)
441 |
442 | return final_feature
443 |
444 | class attn_network(nn.Module):
445 | def __init__(self, configs):
446 | super(attn_network, self).__init__()
447 |
448 | self.h_dim = configs.features_len * configs.final_out_channels
449 | self.self_attn_Q = nn.Sequential(nn.Linear(in_features=self.h_dim, out_features=self.h_dim),
450 | nn.ELU()
451 | )
452 | self.self_attn_K = nn.Sequential(nn.Linear(in_features=self.h_dim, out_features=self.h_dim),
453 | nn.LeakyReLU()
454 | )
455 | self.self_attn_V = nn.Sequential(nn.Linear(in_features=self.h_dim, out_features=self.h_dim),
456 | nn.LeakyReLU()
457 | )
458 |
459 | def forward(self, x):
460 | Q = self.self_attn_Q(x)
461 | K = self.self_attn_K(x)
462 | V = self.self_attn_V(x)
463 |
464 | return Q, K, V
465 |
466 |
467 | # Sparse max
468 | class Sparsemax(nn.Module):
469 | """Sparsemax function."""
470 |
471 | def __init__(self, dim=None):
472 | """Initialize sparsemax activation
473 |
474 | Args:
475 | dim (int, optional): The dimension over which to apply the sparsemax function.
476 | """
477 | super(Sparsemax, self).__init__()
478 |
479 | self.dim = -1 if dim is None else dim
480 |
481 | def forward(self, input):
482 | """Forward function.
483 | Args:
484 | input (torch.Tensor): Input tensor. First dimension should be the batch size
485 | Returns:
486 | torch.Tensor: [batch_size x number_of_logits] Output tensor
487 | """
488 | # Sparsemax currently only handles 2-dim tensors,
489 | # so we reshape to a convenient shape and reshape back after sparsemax
490 | input = input.transpose(0, self.dim)
491 | original_size = input.size()
492 | input = input.reshape(input.size(0), -1)
493 | input = input.transpose(0, 1)
494 | dim = 1
495 |
496 | number_of_logits = input.size(dim)
497 |
498 | # Translate input by max for numerical stability
499 | input = input - torch.max(input, dim=dim, keepdim=True)[0].expand_as(input)
500 |
501 | # Sort input in descending order.
502 | # (NOTE: Can be replaced with linear time selection method described here:
503 | # http://stanford.edu/~jduchi/projects/DuchiShSiCh08.html)
504 | zs = torch.sort(input=input, dim=dim, descending=True)[0]
505 | range = torch.arange(start=1, end=number_of_logits + 1, step=1,
506 | device=input.device, dtype=input.dtype).view(1, -1)
507 | range = range.expand_as(zs)
508 |
509 | # Determine sparsity of projection
510 | bound = 1 + range * zs
511 | cumulative_sum_zs = torch.cumsum(zs, dim)
512 | is_gt = torch.gt(bound, cumulative_sum_zs).type(input.type())
513 | k = torch.max(is_gt * range, dim, keepdim=True)[0]
514 |
515 | # Compute threshold function
516 | zs_sparse = is_gt * zs
517 |
518 | # Compute taus
519 | taus = (torch.sum(zs_sparse, dim, keepdim=True) - 1) / k
520 | taus = taus.expand_as(input)
521 |
522 | # Sparsemax
523 | self.output = torch.max(torch.zeros_like(input), input - taus)
524 |
525 | # Reshape back to original shape
526 | output = self.output
527 | output = output.transpose(0, 1)
528 | output = output.reshape(original_size)
529 | output = output.transpose(0, self.dim)
530 |
531 | return output
532 |
533 | def backward(self, grad_output):
534 | """Backward function."""
535 | dim = 1
536 |
537 | nonzeros = torch.ne(self.output, 0)
538 | sum = torch.sum(grad_output * nonzeros, dim=dim) / torch.sum(nonzeros, dim=dim)
539 | self.grad_input = nonzeros * (grad_output - sum.expand_as(grad_output))
540 |
541 | return self.grad_input
542 |
--------------------------------------------------------------------------------
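
The Sparsemax module that closes models.py behaves like softmax but projects onto the probability simplex, so low-scoring entries are driven exactly to zero. A small hand-checked example:

    import torch
    from models.models import Sparsemax

    sparsemax = Sparsemax(dim=-1)
    probs = sparsemax(torch.tensor([[1.0, 2.0, 3.0],
                                    [0.0, 0.0, 0.0]]))
    print(probs)              # -> [[0, 0, 1], [1/3, 1/3, 1/3]]
    print(probs.sum(dim=-1))  # each row sums to 1
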
/models/resnet18.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, List, Optional, Type
2 |
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch import Tensor
7 |
8 |
9 |
10 |
11 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv1d:
12 | """3x3 convolution with padding"""
13 | return nn.Conv1d(
14 | in_planes,
15 | out_planes,
16 | kernel_size=3,
17 | stride=stride,
18 | padding=dilation,
19 | groups=groups,
20 | bias=False,
21 | dilation=dilation
22 | )
23 |
24 |
25 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv1d:
26 | """1x1 convolution"""
27 | return nn.Conv1d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
28 |
29 |
30 | class BasicBlock(nn.Module):
31 | expansion: int = 1
32 |
33 | def __init__(
34 | self,
35 | inplanes: int,
36 | planes: int,
37 | stride: int = 1,
38 | downsample: Optional[nn.Module] = None,
39 | groups: int = 1,
40 | base_width: int = 64,
41 | dilation: int = 1,
42 | norm_layer: Optional[Callable[..., nn.Module]] = None
43 | ) -> None:
44 | super().__init__()
45 | if norm_layer is None:
46 | norm_layer = nn.BatchNorm1d
47 | if groups != 1 or base_width != 64:
48 | raise ValueError("BasicBlock only supports groups=1 and base_width=64")
49 | if dilation > 1:
50 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
51 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
52 | self.conv1 = conv3x3(inplanes, planes, stride)
53 | self.bn1 = norm_layer(planes)
54 | self.relu = nn.ReLU(inplace=True)
55 | self.conv2 = conv3x3(planes, planes)
56 | self.bn2 = norm_layer(planes)
57 | self.downsample = downsample
58 | self.stride = stride
59 |
60 | def forward(self, x: Tensor) -> Tensor:
61 | identity = x
62 |
63 | out = self.conv1(x)
64 | out = self.bn1(out)
65 | out = self.relu(out)
66 |
67 | out = self.conv2(out)
68 | out = self.bn2(out)
69 |
70 | if self.downsample is not None:
71 | identity = self.downsample(x)
72 |
73 | out += identity
74 | out = self.relu(out)
75 |
76 | return out
77 |
78 |
79 | class ResNet(nn.Module):
80 | def __init__(
81 | self,
82 | block: Type[BasicBlock],
83 | layers: List[int],
84 | configs,
85 | num_classes: int = 5,
86 | zero_init_residual: bool = False,
87 | groups: int = 1,
88 | width_per_group: int = 64,
89 | replace_stride_with_dilation: Optional[List[bool]] = None,
90 | norm_layer: Optional[Callable[..., nn.Module]] = None
91 | ) -> None:
92 | super().__init__()
93 | if norm_layer is None:
94 | norm_layer = nn.BatchNorm1d
95 | self._norm_layer = norm_layer
96 |
97 | self.inplanes = 64
98 | self.dilation = 1
99 | if replace_stride_with_dilation is None:
100 | # each element in the tuple indicates if we should replace
101 | # the 2x2 stride with a dilated convolution instead
102 | replace_stride_with_dilation = [False, False, False]
103 | if len(replace_stride_with_dilation) != 3:
104 | raise ValueError(
105 | "replace_stride_with_dilation should be None "
106 | f"or a 3-element tuple, got {replace_stride_with_dilation}"
107 | )
108 | self.groups = groups
109 | self.base_width = width_per_group
110 | self.conv1 = nn.Conv1d(configs.input_channels, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
111 | self.bn1 = norm_layer(self.inplanes)
112 | self.relu = nn.ReLU(inplace=True)
113 | self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
114 | self.layer1 = self._make_layer(block, 64, layers[0])
115 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
116 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
117 | self.layer4 = self._make_layer(block, configs.final_out_channels, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
118 |
119 | self.avgpool = nn.AdaptiveAvgPool1d(1)
120 | self.fc = nn.Linear(512 * block.expansion, num_classes)
121 |
122 | for m in self.modules():
123 | if isinstance(m, nn.Conv1d):
124 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
125 | elif isinstance(m, (nn.BatchNorm1d, nn.GroupNorm)):
126 | nn.init.constant_(m.weight, 1)
127 | nn.init.constant_(m.bias, 0)
128 |
129 | # Zero-initialize the last BN in each residual branch,
130 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
131 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
132 | if zero_init_residual:
133 | for m in self.modules():
134 | # NOTE: only BasicBlock is handled here; the torchvision Bottleneck branch was dropped
135 | # along with the Bottleneck class itself when this file was trimmed.
136 | if isinstance(m, BasicBlock) and m.bn2.weight is not None:
137 | nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]
138 |
139 | def _make_layer(
140 | self,
141 | block: Type[BasicBlock],
142 | planes: int,
143 | blocks: int,
144 | stride: int = 1,
145 | dilate: bool = False
146 | ) -> nn.Sequential:
147 | norm_layer = self._norm_layer
148 | downsample = None
149 | previous_dilation = self.dilation
150 | if dilate:
151 | self.dilation *= stride
152 | stride = 1
153 | if stride != 1 or self.inplanes != planes * block.expansion:
154 | downsample = nn.Sequential(
155 | conv1x1(self.inplanes, planes * block.expansion, stride),
156 | norm_layer(planes * block.expansion)
157 | )
158 |
159 | layers = []
160 | layers.append(
161 | block(
162 | self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
163 | )
164 | )
165 | self.inplanes = planes * block.expansion
166 | for _ in range(1, blocks):
167 | layers.append(
168 | block(
169 | self.inplanes,
170 | planes,
171 | groups=self.groups,
172 | base_width=self.base_width,
173 | dilation=self.dilation,
174 | norm_layer=norm_layer
175 | )
176 | )
177 |
178 | return nn.Sequential(*layers)
179 |
180 | def _forward_impl(self, x: Tensor) -> Tensor:
181 | # See note [TorchScript super()]
182 | x = self.conv1(x)
183 | x = self.bn1(x)
184 | x = self.relu(x)
185 | x = self.maxpool(x)
186 |
187 | x = self.layer1(x)
188 | x = self.layer2(x)
189 | x = self.layer3(x)
190 | x = self.layer4(x)
191 |
192 | x = self.avgpool(x)
193 | # x = torch.flatten(x, 1)
194 | # x = self.fc(x)
195 |
196 | return x
197 |
198 | def forward(self, x: Tensor) -> Tensor:
199 | return self._forward_impl(x)
200 |
201 |
202 | def resnet18(configs) -> ResNet:
203 | layers = [2, 2, 2, 2]
204 | model = ResNet(BasicBlock, layers, configs)
205 |
206 | return model
207 |
--------------------------------------------------------------------------------
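
A shape sketch for the 1-D ResNet-18 backbone above; the config object is a stand-in (real configs come from configs/data_model_configs.py). Note that forward() returns pooled but unflattened features, which the RESNET18 wrapper in models.py then flattens:

    import torch
    from types import SimpleNamespace
    from models.resnet18 import resnet18

    cfg = SimpleNamespace(input_channels=9, final_out_channels=512)
    net = resnet18(cfg)
    x = torch.randn(2, 9, 128)  # (batch, channels, sequence length)
    print(net(x).shape)         # -> torch.Size([2, 512, 1]) after adaptive average pooling
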
/trainers/abstract_trainer.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../../ADATIME/')
3 | import torch
4 | import torch.nn.functional as F
5 | from torchmetrics import Accuracy, AUROC, F1Score
6 | import os
7 | import wandb
8 | import pandas as pd
9 | import numpy as np
10 | import warnings
11 | import sklearn.exceptions
12 | import collections
13 |
14 |
15 | from dataloader.dataloader import data_generator, few_shot_data_generator
16 | from configs.data_model_configs import get_dataset_class
17 | from configs.hparams import get_hparams_class
18 | from configs.sweep_params import sweep_alg_hparams
19 | from utils import fix_randomness, starting_logs, DictAsObject,AverageMeter
20 | from algorithms.algorithms import get_algorithm_class
21 | from models.models import get_backbone_class
22 |
23 | warnings.filterwarnings("ignore", category=sklearn.exceptions.UndefinedMetricWarning)
24 |
25 | class AbstractTrainer(object):
26 | """
27 | This class contains the main training functions for AdaTime.
28 | """
29 |
30 | def __init__(self, args):
31 | self.da_method = args.da_method # Selected DA Method
32 | self.dataset = args.dataset # Selected Dataset
33 | self.backbone = args.backbone
34 | self.device = torch.device(args.device) # device
35 |
36 | # Exp Description
37 | self.experiment_description = args.dataset
38 | self.run_description = f"{args.da_method}_{args.exp_name}"
39 |
40 | # paths
41 | self.home_path = os.getcwd() #os.path.dirname(os.getcwd())
42 | self.save_dir = args.save_dir
43 | self.data_path = os.path.join(args.data_path, self.dataset)
44 | # self.create_save_dir(os.path.join(self.home_path, self.save_dir ))
45 | self.exp_log_dir = os.path.join(self.home_path, self.save_dir, self.experiment_description, f"{self.run_description}")
46 | os.makedirs(self.exp_log_dir, exist_ok=True)
47 |
48 |
49 |
50 |
51 | # Specify runs
52 | self.num_runs = args.num_runs
53 |
54 | # get dataset and base model configs
55 | self.dataset_configs, self.hparams_class = self.get_configs()
56 |
57 | # to fix dimension of features in classifier and discriminator networks.
58 | self.dataset_configs.final_out_channels = self.dataset_configs.tcn_final_out_channles if args.backbone == "TCN" else self.dataset_configs.final_out_channels
59 |
60 | # Merge algorithm-specific and shared training hyperparameters
61 | self.hparams = {**self.hparams_class.alg_hparams[self.da_method],
62 | **self.hparams_class.train_params}
63 |
64 | # metrics
65 | self.num_classes = self.dataset_configs.num_classes
66 | self.ACC = Accuracy(task="multiclass", num_classes=self.num_classes)
67 | self.F1 = F1Score(task="multiclass", num_classes=self.num_classes, average="macro")
68 | self.AUROC = AUROC(task="multiclass", num_classes=self.num_classes)
69 |
70 |
71 |
72 | def sweep(self):
73 | # sweep configurations
74 | pass
75 |
76 | def initialize_algorithm(self):
77 | # get algorithm class
78 | algorithm_class = get_algorithm_class(self.da_method)
79 | backbone_fe = get_backbone_class(self.backbone)
80 |
81 | # Initialize the algorithm
82 | self.algorithm = algorithm_class(backbone_fe, self.dataset_configs, self.hparams, self.device)
83 | self.algorithm.to(self.device)
84 |
85 | def load_checkpoint(self, model_dir):
86 | checkpoint = torch.load(os.path.join(self.home_path, model_dir, 'checkpoint.pt'))
87 | last_model = checkpoint['last']
88 | best_model = checkpoint['best']
89 | return last_model, best_model
90 |
91 | def train_model(self):
92 | # Get the algorithm and the backbone network
93 | algorithm_class = get_algorithm_class(self.da_method)
94 | backbone_fe = get_backbone_class(self.backbone)
95 |
96 | # Initialize the algorithm
97 | self.algorithm = algorithm_class(backbone_fe, self.dataset_configs, self.hparams, self.device)
98 | self.algorithm.to(self.device)
99 |
100 | # Training the model
101 | self.last_model, self.best_model = self.algorithm.update(self.src_train_dl, self.trg_train_dl, self.loss_avg_meters, self.logger)
102 | return self.last_model, self.best_model
103 |
104 | def evaluate(self, test_loader):
105 | feature_extractor = self.algorithm.feature_extractor.to(self.device)
106 | classifier = self.algorithm.classifier.to(self.device)
107 |
108 | feature_extractor.eval()
109 | classifier.eval()
110 |
111 | total_loss, preds_list, labels_list = [], [], []
112 |
113 | with torch.no_grad():
114 | for data, labels in test_loader:
115 | data = data.float().to(self.device)
116 | labels = labels.view((-1)).long().to(self.device)
117 |
118 | # forward pass
119 | features = feature_extractor(data)
120 | predictions = classifier(features)
121 |
122 | # compute loss
123 | loss = F.cross_entropy(predictions, labels)
124 | total_loss.append(loss.item())
125 | pred = predictions.detach() # .argmax(dim=1) # get the index of the max log-probability
126 |
127 | # append predictions and labels
128 | preds_list.append(pred)
129 | labels_list.append(labels)
130 |
131 | self.loss = torch.tensor(total_loss).mean() # average loss
132 | self.full_preds = torch.cat(preds_list)
133 | self.full_labels = torch.cat(labels_list)
134 |
135 | def get_configs(self):
136 | dataset_class = get_dataset_class(self.dataset)
137 | hparams_class = get_hparams_class(self.dataset)
138 | return dataset_class(), hparams_class()
139 |
140 | def load_data(self, src_id, trg_id):
141 | self.src_train_dl = data_generator(self.data_path, src_id, self.dataset_configs, self.hparams, "train")
142 | self.src_test_dl = data_generator(self.data_path, src_id, self.dataset_configs, self.hparams, "test")
143 |
144 | self.trg_train_dl = data_generator(self.data_path, trg_id, self.dataset_configs, self.hparams, "train")
145 | self.trg_test_dl = data_generator(self.data_path, trg_id, self.dataset_configs, self.hparams, "test")
146 |
147 | self.few_shot_dl_5 = few_shot_data_generator(self.trg_test_dl, self.dataset_configs,
148 | 5)  # change 5 to use a different k for the k-shot few-shot evaluation
149 |
150 | def create_save_dir(self, save_dir):
151 | if not os.path.exists(save_dir):
152 | os.mkdir(save_dir)
153 |
154 | def calculate_metrics_risks(self):
155 | # risk computed on the source test data
156 | self.evaluate(self.src_test_dl)
157 | src_risk = self.loss.item()
158 | # risk computed on the few-shot target data
159 | self.evaluate(self.few_shot_dl_5)
160 | fst_risk = self.loss.item()
161 | # risk computed on the target test data
162 | self.evaluate(self.trg_test_dl)
163 | trg_risk = self.loss.item()
164 |
165 | # calculate metrics
166 | acc = self.ACC(self.full_preds.argmax(dim=1).cpu(), self.full_labels.cpu()).item()
167 | # f1_torch
168 | f1 = self.F1(self.full_preds.argmax(dim=1).cpu(), self.full_labels.cpu()).item()
169 | auroc = self.AUROC(self.full_preds.cpu(), self.full_labels.cpu()).item()
170 | # f1_sk learn
171 | # f1 = f1_score(self.full_preds.argmax(dim=1).cpu().numpy(), self.full_labels.cpu().numpy(), average='macro')
172 |
173 | risks = src_risk, fst_risk, trg_risk
174 | metrics = acc, f1, auroc
175 |
176 | return risks, metrics
177 |
178 | def save_tables_to_file(self,table_results, name):
179 | # save to file if needed
180 | table_results.to_csv(os.path.join(self.exp_log_dir,f"{name}.csv"))
181 |
182 | def save_checkpoint(self, home_path, log_dir, last_model, best_model):
183 | save_dict = {
184 | "last": last_model,
185 | "best": best_model
186 | }
187 | # save classification report
188 | save_path = os.path.join(home_path, log_dir, f"checkpoint.pt")
189 | torch.save(save_dict, save_path)
190 |
191 | def calculate_avg_std_wandb_table(self, results):
192 |
193 | avg_metrics = [np.mean(results.get_column(metric)) for metric in results.columns[2:]]
194 | std_metrics = [np.std(results.get_column(metric)) for metric in results.columns[2:]]
195 | summary_metrics = {metric: np.mean(results.get_column(metric)) for metric in results.columns[2:]}
196 |
197 | results.add_data('mean', '-', *avg_metrics)
198 | results.add_data('std', '-', *std_metrics)
199 |
200 | return results, summary_metrics
201 |
202 | def log_summary_metrics_wandb(self, results, risks):
203 |
204 | # Calculate average and standard deviation for metrics
205 | avg_metrics = [np.mean(results.get_column(metric)) for metric in results.columns[2:]]
206 | std_metrics = [np.std(results.get_column(metric)) for metric in results.columns[2:]]
207 |
208 | avg_risks = [np.mean(risks.get_column(risk)) for risk in risks.columns[2:]]
209 | std_risks = [np.std(risks.get_column(risk)) for risk in risks.columns[2:]]
210 |
211 | # Estimate summary metrics
212 | summary_metrics = {metric: np.mean(results.get_column(metric)) for metric in results.columns[2:]}
213 | summary_risks = {risk: np.mean(risks.get_column(risk)) for risk in risks.columns[2:]}
214 |
215 |
216 | # append avg and std values to metrics
217 | results.add_data('mean', '-', *avg_metrics)
218 | results.add_data('std', '-', *std_metrics)
219 |
220 | # append avg and std values to risks
221 | risks.add_data('mean', '-', *avg_risks)  # fixed: was appended to `results`, putting risk means in the wrong table
222 | risks.add_data('std', '-', *std_risks)
223 |
224 | def wandb_logging(self, total_results, total_risks, summary_metrics, summary_risks):
225 | # log wandb
226 | wandb.log({'results': total_results})
227 | wandb.log({'risks': total_risks})
228 | wandb.log({'hparams': wandb.Table(dataframe=pd.DataFrame(dict(self.hparams).items(), columns=['parameter', 'value']), allow_mixed_types=True)})
229 | wandb.log(summary_metrics)
230 | wandb.log(summary_risks)
231 |
232 | def calculate_metrics(self):
233 |
234 | self.evaluate(self.trg_test_dl)
235 | # accuracy
236 | acc = self.ACC(self.full_preds.argmax(dim=1).cpu(), self.full_labels.cpu()).item()
237 | # f1
238 | f1 = self.F1(self.full_preds.argmax(dim=1).cpu(), self.full_labels.cpu()).item()
239 | # auroc
240 | auroc = self.AUROC(self.full_preds.cpu(), self.full_labels.cpu()).item()
241 |
242 | return acc, f1, auroc
243 |
244 | def calculate_risks(self):
245 | # risk computed on the source test data
246 | self.evaluate(self.src_test_dl)
247 | src_risk = self.loss.item()
248 | # risk computed on the few-shot target data
249 | self.evaluate(self.few_shot_dl_5)
250 | fst_risk = self.loss.item()
251 | # risk computed on the target test data
252 | self.evaluate(self.trg_test_dl)
253 | trg_risk = self.loss.item()
254 |
255 | return src_risk, fst_risk, trg_risk
256 |
257 | def append_results_to_tables(self, table, scenario, run_id, metrics):
258 |
259 | # Create metrics and risks rows
260 | results_row = [scenario, run_id, *metrics]
261 |
262 | # Create new dataframes for each row
263 | results_df = pd.DataFrame([results_row], columns=table.columns)
264 |
265 | # Concatenate new dataframes with original dataframes
266 | table = pd.concat([table, results_df], ignore_index=True)
267 |
268 | return table
269 |
270 | def add_mean_std_table(self, table, columns):
271 | # Calculate average and standard deviation for metrics
272 | avg_metrics = [table[metric].mean() for metric in columns[2:]]
273 | std_metrics = [table[metric].std() for metric in columns[2:]]
274 |
275 | # Create dataframes for mean and std values
276 | mean_metrics_df = pd.DataFrame([['mean', '-', *avg_metrics]], columns=columns)
277 | std_metrics_df = pd.DataFrame([['std', '-', *std_metrics]], columns=columns)
278 |
279 | # Concatenate original dataframes with mean and std dataframes
280 | table = pd.concat([table, mean_metrics_df, std_metrics_df], ignore_index=True)
281 |
282 | # Create a formatting function to format each element in the tables
283 | format_func = lambda x: f"{x:.4f}" if isinstance(x, float) else x
284 |
285 | # Apply the formatting function to each element in the tables
286 | table = table.applymap(format_func)
287 |
288 | return table
--------------------------------------------------------------------------------
/trainers/sweep.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | sys.path.append('../')
4 | import torch
5 | import torch.nn.functional as F
6 | import os
7 | import wandb
8 | import pandas as pd
9 | import numpy as np
10 | import warnings
11 | import sklearn.exceptions
12 | import collections
13 | import argparse
14 |
15 |
16 |
17 | from configs.sweep_params import sweep_alg_hparams
18 | from utils import fix_randomness, starting_logs, DictAsObject
19 | from algorithms.algorithms import get_algorithm_class
20 | from models.models import get_backbone_class
21 | from utils import AverageMeter
22 |
23 | from trainers.abstract_trainer import AbstractTrainer
24 |
25 | warnings.filterwarnings("ignore", category=sklearn.exceptions.UndefinedMetricWarning)
26 | parser = argparse.ArgumentParser()
27 |
28 |
29 | class Trainer(AbstractTrainer):
30 | """
31 | This class contains the main training functions for AdaTime's hyperparameter sweeps.
32 | """
33 |
34 | def __init__(self, args):
35 | super(Trainer, self).__init__(args)
36 |
37 | # sweep parameters
38 | self.num_sweeps = args.num_sweeps
39 | self.sweep_project_wandb = args.sweep_project_wandb
40 | self.wandb_entity = args.wandb_entity
41 | self.hp_search_strategy = args.hp_search_strategy
42 | self.metric_to_minimize = args.metric_to_minimize
43 |
44 | # Logging
45 | self.exp_log_dir = os.path.join(self.home_path, self.save_dir)
46 | os.makedirs(self.exp_log_dir, exist_ok=True)
47 |
48 | def sweep(self):
49 | # sweep configurations
50 | sweep_runs_count = self.num_sweeps
51 | sweep_config = {
52 | 'method': self.hp_search_strategy,
53 | 'metric': {'name': self.metric_to_minimize, 'goal': 'minimize'},
54 | 'name': self.da_method + '_' + self.backbone,
55 | 'parameters': {**sweep_alg_hparams[self.da_method]}
56 | }
57 | sweep_id = wandb.sweep(sweep_config, project=self.sweep_project_wandb, entity=self.wandb_entity)
58 |
59 | wandb.agent(sweep_id, self.train, count=sweep_runs_count)
60 |
61 | def train(self):
62 | run = wandb.init(config=self.hparams)
63 | self.hparams = wandb.config
64 |
65 | # create tables for results and risks
66 | columns = ["scenario", "run", "acc", "f1_score", "auroc"]
67 | table_results = wandb.Table(columns=columns, allow_mixed_types=True)
68 | columns = ["scenario", "run", "src_risk", "few_shot_risk", "trg_risk"]
69 | table_risks = wandb.Table(columns=columns, allow_mixed_types=True)
70 |
71 | for src_id, trg_id in self.dataset_configs.scenarios:
72 | for run_id in range(self.num_runs):
73 | # set random seed and create logger
74 | fix_randomness(run_id)
75 | self.logger, self.scenario_log_dir = starting_logs(self.dataset, self.da_method, self.exp_log_dir, src_id, trg_id, run_id)
76 |
77 | # average meters
78 | self.loss_avg_meters = collections.defaultdict(lambda: AverageMeter())
79 |
80 | # load data and train model
81 | self.load_data(src_id, trg_id)
82 |
83 | # initiate the domain adaptation algorithm
84 | self.initialize_algorithm()
85 |
86 | # Train the domain adaptation algorithm
87 | self.last_model, self.best_model = self.algorithm.update(self.src_train_dl, self.trg_train_dl, self.loss_avg_meters, self.logger)
88 |
89 | # calculate metrics and risks
90 | metrics = self.calculate_metrics()
91 | risks = self.calculate_risks()
92 |
93 | # append results to tables
94 | scenario = f"{src_id}_to_{trg_id}"
95 | table_results.add_data(scenario, run_id, *metrics)
96 | table_risks.add_data(scenario, run_id, *risks)
97 |
98 | # calculate overall metrics and risks
99 | total_results, summary_metrics = self.calculate_avg_std_wandb_table(table_results)
100 | total_risks, summary_risks = self.calculate_avg_std_wandb_table(table_risks)
101 |
102 | # log results to WandB
103 | self.wandb_logging(total_results, total_risks, summary_metrics, summary_risks)
104 |
105 | # finish the run
106 | run.finish()
107 |
108 |
--------------------------------------------------------------------------------
/trainers/train.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | import os
6 | import wandb
7 | import pandas as pd
8 | import numpy as np
9 | import warnings
10 | import sklearn.exceptions
11 | import collections
12 | import argparse
13 |
14 |
15 |
16 | from utils import fix_randomness, starting_logs, AverageMeter
17 | from algorithms.algorithms import get_algorithm_class
18 | from models.models import get_backbone_class
19 | from trainers.abstract_trainer import AbstractTrainer
20 | warnings.filterwarnings("ignore", category=sklearn.exceptions.UndefinedMetricWarning)
21 | parser = argparse.ArgumentParser()
22 |
23 |
24 |
25 | class Trainer(AbstractTrainer):
26 | """
27 | This class contains the main training and testing functions for AdaTime.
28 | """
29 |
30 | def __init__(self, args):
31 | super().__init__(args)
32 |
33 | self.results_columns = ["scenario", "run", "acc", "f1_score", "auroc"]
34 | self.risks_columns = ["scenario", "run", "src_risk", "few_shot_risk", "trg_risk"]
35 |
36 |
37 | def fit(self):
38 |
39 | # table with metrics
40 | table_results = pd.DataFrame(columns=self.results_columns)
41 |
42 | # table with risks
43 | table_risks = pd.DataFrame(columns=self.risks_columns)
44 |
45 |
46 | # Trainer
47 | for src_id, trg_id in self.dataset_configs.scenarios:
48 | for run_id in range(self.num_runs):
49 | # fixing random seed
50 | fix_randomness(run_id)
51 |
52 | # Logging
53 | self.logger, self.scenario_log_dir = starting_logs(self.dataset, self.da_method, self.exp_log_dir,
54 | src_id, trg_id, run_id)
55 | # Average meters
56 | self.loss_avg_meters = collections.defaultdict(lambda: AverageMeter())
57 |
58 | # Load data
59 | self.load_data(src_id, trg_id)
60 |
61 | # initiate the domain adaptation algorithm
62 | self.initialize_algorithm()
63 |
64 | # Train the domain adaptation algorithm
65 | self.last_model, self.best_model = self.algorithm.update(self.src_train_dl, self.trg_train_dl, self.loss_avg_meters, self.logger)
66 |
67 | # Save checkpoint
68 | self.save_checkpoint(self.home_path, self.scenario_log_dir, self.last_model, self.best_model)
69 |
70 | # Calculate risks and metrics
71 | metrics = self.calculate_metrics()
72 | risks = self.calculate_risks()
73 |
74 | # Append results to tables
75 | scenario = f"{src_id}_to_{trg_id}"
76 | table_results = self.append_results_to_tables(table_results, scenario, run_id, metrics)
77 | table_risks = self.append_results_to_tables(table_risks, scenario, run_id, risks)
78 |
79 | # Calculate and append mean and std to tables
80 | table_results = self.add_mean_std_table(table_results, self.results_columns)
81 | table_risks = self.add_mean_std_table(table_risks, self.risks_columns)
82 |
83 |
84 | # Save tables to file if needed
85 | self.save_tables_to_file(table_results, 'results')
86 | self.save_tables_to_file(table_risks, 'risks')
87 |
88 | def test(self):
89 | # Results dataframes
90 | last_results = pd.DataFrame(columns=self.results_columns)
91 | best_results = pd.DataFrame(columns=self.results_columns)
92 |
93 | # Cross-domain scenarios
94 | for src_id, trg_id in self.dataset_configs.scenarios:
95 | for run_id in range(self.num_runs):
96 | # fixing random seed
97 | fix_randomness(run_id)
98 |
99 | # Logging
100 | self.scenario_log_dir = os.path.join(self.exp_log_dir, src_id + "_to_" + trg_id + "_run_" + str(run_id))
101 |
102 | self.loss_avg_meters = collections.defaultdict(lambda: AverageMeter())
103 |
104 | # Load data
105 | self.load_data(src_id, trg_id)
106 |
107 | # Build model
108 | self.initialize_algorithm()
109 |
110 | # Load checkpoint
111 | last_chk, best_chk = self.load_checkpoint(self.scenario_log_dir)
112 |
113 | # Testing the last model
114 | self.algorithm.network.load_state_dict(last_chk)
115 | self.evaluate(self.trg_test_dl)
116 | last_metrics = self.calculate_metrics()
117 | last_results = self.append_results_to_tables(last_results, f"{src_id}_to_{trg_id}", run_id,
118 | last_metrics)
119 |
120 |
121 | # Testing the best model
122 | self.algorithm.network.load_state_dict(best_chk)
123 | self.evaluate(self.trg_test_dl)
124 | best_metrics = self.calculate_metrics()
125 | # Append results to tables
126 | best_results = self.append_results_to_tables(best_results, f"{src_id}_to_{trg_id}", run_id,
127 | best_metrics)
128 |
129 | last_scenario_mean_std = last_results.groupby('scenario')[['acc', 'f1_score', 'auroc']].agg(['mean', 'std'])
130 | best_scenario_mean_std = best_results.groupby('scenario')[['acc', 'f1_score', 'auroc']].agg(['mean', 'std'])
131 |
132 |
133 | # Save tables to file if needed
134 | self.save_tables_to_file(last_scenario_mean_std, 'last_results')
135 | self.save_tables_to_file(best_scenario_mean_std, 'best_results')
136 |
137 | # printing summary
138 | summary_last = {metric: np.mean(last_results[metric]) for metric in self.results_columns[2:]}
139 | summary_best = {metric: np.mean(best_results[metric]) for metric in self.results_columns[2:]}
140 | for summary_name, summary in [('Last', summary_last), ('Best', summary_best)]:
141 | for key, val in summary.items():
142 | print(f'{summary_name}: {key}\t: {val:2.4f}')
143 |
144 |
145 |
--------------------------------------------------------------------------------
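
A hypothetical driver for the Trainer above. The attribute names mirror what AbstractTrainer.__init__ reads from its args; the real entry point and its flag names live in main.py, so treat the values here as placeholders:

    from types import SimpleNamespace
    from trainers.train import Trainer

    args = SimpleNamespace(
        da_method="DANN", dataset="HAR", backbone="CNN", device="cuda",
        exp_name="demo", save_dir="experiments_logs", data_path="data",
        num_runs=1,
    )
    trainer = Trainer(args)
    trainer.fit()   # train every (source, target) scenario and save checkpoints
    trainer.test()  # reload checkpoints and report last/best metrics
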
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn as nn
4 |
5 | import random
6 | import os
7 | import sys
8 | import logging
9 | import numpy as np
10 | import pandas as pd
11 | from shutil import copy
12 | from datetime import datetime
13 |
14 | from skorch import NeuralNetClassifier # for DIV Risk
15 | from sklearn.model_selection import train_test_split
16 | from sklearn.metrics import classification_report, accuracy_score
17 |
18 |
19 | class AverageMeter(object):
20 | """Computes and stores the average and current value"""
21 |
22 | def __init__(self):
23 | self.reset()
24 |
25 | def reset(self):
26 | self.val = 0
27 | self.avg = 0
28 | self.sum = 0
29 | self.count = 0
30 |
31 | def update(self, val, n=1):
32 | self.val = val
33 | self.sum += val * n
34 | self.count += n
35 | self.avg = self.sum / self.count
36 |
37 |
38 | def fix_randomness(SEED):
39 | random.seed(SEED)
40 | np.random.seed(SEED)
41 | torch.manual_seed(SEED)
42 | torch.cuda.manual_seed(SEED)
43 | torch.backends.cudnn.deterministic = True
44 | torch.backends.cudnn.benchmark = False
45 |
46 |
47 | def _logger(logger_name, level=logging.DEBUG):
48 | """
49 | Method to return a custom logger with the given name and level
50 | """
51 | logger = logging.getLogger(logger_name)
52 | logger.setLevel(level)
53 | format_string = "%(message)s"
54 | log_format = logging.Formatter(format_string)
55 | # Creating and adding the console handler
56 | console_handler = logging.StreamHandler(sys.stdout)
57 | console_handler.setFormatter(log_format)
58 | logger.addHandler(console_handler)
59 | # Creating and adding the file handler
60 | file_handler = logging.FileHandler(logger_name, mode='a')
61 | file_handler.setFormatter(log_format)
62 | logger.addHandler(file_handler)
63 | return logger
64 |
65 |
66 | def starting_logs(data_type, da_method, exp_log_dir, src_id, tgt_id, run_id):
67 | log_dir = os.path.join(exp_log_dir, src_id + "_to_" + tgt_id + "_run_" + str(run_id))
68 | os.makedirs(log_dir, exist_ok=True)
69 | log_file_name = os.path.join(log_dir, f"logs_{datetime.now().strftime('%d_%m_%Y_%H_%M_%S')}.log")
70 | logger = _logger(log_file_name)
71 | logger.debug("=" * 45)
72 | logger.debug(f'Dataset: {data_type}')
73 | logger.debug(f'Method: {da_method}')
74 | logger.debug("=" * 45)
75 | logger.debug(f'Source: {src_id} ---> Target: {tgt_id}')
76 | logger.debug(f'Run ID: {run_id}')
77 | logger.debug("=" * 45)
78 | return logger, log_dir
79 |
80 |
81 | def save_checkpoint(home_path, algorithm, log_dir, last_model, best_model):
82 | save_dict = {
83 | "last": last_model,
84 | "best": best_model
85 | }
86 | # save classification report
87 | save_path = os.path.join(home_path, log_dir, f"checkpoint.pt")
88 |
89 | torch.save(save_dict, save_path)
90 |
91 |
92 | def weights_init(m):
93 | classname = m.__class__.__name__
94 | if classname.find('Conv') != -1:
95 | m.weight.data.normal_(0.0, 0.02)
96 | elif classname.find('BatchNorm') != -1:
97 | m.weight.data.normal_(1.0, 0.02)
98 | m.bias.data.fill_(0)
99 | elif classname.find('Linear') != -1:
100 | m.weight.data.normal_(0.0, 0.1)
101 | m.bias.data.fill_(0)
102 |
103 |
104 | def _calc_metrics(pred_labels, true_labels, log_dir, home_path, target_names):
105 | pred_labels = np.array(pred_labels).astype(int)
106 | true_labels = np.array(true_labels).astype(int)
107 |
108 | r = classification_report(true_labels, pred_labels, target_names=target_names, digits=6, output_dict=True)
109 |
110 | df = pd.DataFrame(r)
111 | accuracy = accuracy_score(true_labels, pred_labels)
112 | df["accuracy"] = accuracy
113 | df = df * 100
114 |
115 | # save classification report
116 | file_name = "classification_report.xlsx"
117 | report_Save_path = os.path.join(home_path, log_dir, file_name)
118 | df.to_excel(report_Save_path)
119 |
120 | return accuracy * 100, r["macro avg"]["f1-score"] * 100
121 |
122 |
123 | def copy_Files(destination):
124 | destination_dir = os.path.join(destination, "MODEL_BACKUP_FILES")
125 | os.makedirs(destination_dir, exist_ok=True)
126 | copy("main.py", os.path.join(destination_dir, "main.py"))
127 | copy("utils.py", os.path.join(destination_dir, "utils.py"))
128 | copy(f"trainer.py", os.path.join(destination_dir, f"trainer.py"))
129 | copy(f"same_domain_trainer.py", os.path.join(destination_dir, f"same_domain_trainer.py"))
130 | copy("dataloader/dataloader.py", os.path.join(destination_dir, "dataloader.py"))
131 | copy(f"models/models.py", os.path.join(destination_dir, f"models.py"))
132 | copy(f"models/loss.py", os.path.join(destination_dir, f"loss.py"))
133 | copy("algorithms/algorithms.py", os.path.join(destination_dir, "algorithms.py"))
134 | copy(f"configs/data_model_configs.py", os.path.join(destination_dir, f"data_model_configs.py"))
135 | copy(f"configs/hparams.py", os.path.join(destination_dir, f"hparams.py"))
136 | copy(f"configs/sweep_params.py", os.path.join(destination_dir, f"sweep_params.py"))
137 |
138 |
139 |
140 |
141 | def get_iwcv_value(weight, error):
142 | N, d = weight.shape
143 | _N, _d = error.shape
144 | assert N == _N and d == _d, 'dimension mismatch!'
145 | weighted_error = weight * error
146 | return np.mean(weighted_error)
147 |
148 |
149 | def get_dev_value(weight, error):
150 | """
151 | :param weight: shape [N, 1], the importance weight for N source samples in the validation set
152 | :param error: shape [N, 1], the error value for each source sample in the validation set
153 | (typically 0 for correct classification and 1 for wrong classification)
154 | """
155 | N, d = weight.shape
156 | _N, _d = error.shape
157 | assert N == _N and d == _d, 'dimension mismatch!'
158 | weighted_error = weight * error
159 | cov = np.cov(np.concatenate((weighted_error, weight), axis=1), rowvar=False)[0][1]
160 | var_w = np.var(weight, ddof=1)
161 | eta = - cov / var_w
162 | return np.mean(weighted_error) + eta * np.mean(weight) - eta
163 |
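# A minimal, self-contained sanity check for the two estimators above
# (illustrative only; nothing else in this file calls it). With near-uniform
# weights, both estimates should land close to the plain mean error, since
# DEV's control-variate term eta * (mean(weight) - 1) vanishes when the
# weights average to one.
def _dev_sanity_check():
    rng = np.random.RandomState(0)
    weight = 1.0 + 0.01 * rng.rand(100, 1)           # near-uniform importance weights
    error = (rng.rand(100, 1) > 0.7).astype(float)   # 0/1 errors, ~30% error rate
    print("IWCV estimate:", get_iwcv_value(weight, error))
    print("DEV estimate: ", get_dev_value(weight, error))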
164 |
165 | class simple_MLP(nn.Module):
166 | def __init__(self, inp_units, out_units=2):
167 | super(simple_MLP, self).__init__()
168 |
169 | self.dense0 = nn.Linear(inp_units, inp_units // 2)
170 | self.nonlin = nn.ReLU()
171 | self.output = nn.Linear(inp_units // 2, out_units)
172 | self.softmax = nn.Softmax(dim=-1)
173 |
174 | def forward(self, x, **kwargs):
175 | x = self.nonlin(self.dense0(x))
176 | x = self.softmax(self.output(x))
177 | return x
178 |
179 |
180 | def get_weight_gpu(source_feature, target_feature, validation_feature, configs, device):
181 | """
182 | :param source_feature: shape [N_tr, d], features from training set
183 | :param target_feature: shape [N_te, d], features from test set
184 | :param validation_feature: shape [N_v, d], features from validation set
185 |     :return: shape [N_v, 1], importance weights estimating p_target(x) / p_source(x) for each validation sample
186 | """
187 |     import copy  # local import: the module-level name "copy" is the file-copy helper used in copy_Files
188 | N_s, d = source_feature.shape
189 | N_t, _d = target_feature.shape
190 |     source_feature = copy.deepcopy(source_feature.detach().cpu())  # defensive copy, detached from the graph
191 |     target_feature = copy.deepcopy(target_feature.detach().cpu())
192 |     source_feature = source_feature.to(device)
193 |     target_feature = target_feature.to(device)
194 | all_feature = torch.cat((source_feature, target_feature), dim=0)
195 | all_label = torch.from_numpy(np.asarray([1] * N_s + [0] * N_t, dtype=np.int32)).long()
196 |
197 | feature_for_train, feature_for_test, label_for_train, label_for_test = train_test_split(all_feature, all_label,
198 | train_size=0.8)
199 | learning_rates = [1e-1, 5e-2, 1e-2]
200 | val_acc = []
201 | domain_classifiers = []
202 |
203 | for lr in learning_rates:
204 | domain_classifier = NeuralNetClassifier(
205 | simple_MLP,
206 | module__inp_units=configs.final_out_channels * configs.features_len,
207 | max_epochs=30,
208 | lr=lr,
209 | device=device,
210 | # Shuffle training data on each epoch
211 | iterator_train__shuffle=True,
212 | callbacks="disable"
213 | )
214 | domain_classifier.fit(feature_for_train.float(), label_for_train.long())
215 | output = domain_classifier.predict(feature_for_test)
216 | acc = np.mean((label_for_test.numpy() == output).astype(np.float32))
217 | val_acc.append(acc)
218 | domain_classifiers.append(domain_classifier)
219 |
220 | index = val_acc.index(max(val_acc))
221 | domain_classifier = domain_classifiers[index]
222 |
223 | domain_out = domain_classifier.predict_proba(validation_feature.to(device).float())
224 |     # target/source odds, rescaled by N_s / N_t to undo the class prior, estimate p_target(x) / p_source(x)
225 |     return domain_out[:, :1] / domain_out[:, 1:] * N_s * 1.0 / N_t
226 |
227 | def calc_dev_risk(target_model, src_train_dl, tgt_train_dl, src_valid_dl, configs, device):
228 | src_train_feats = target_model.feature_extractor(src_train_dl.dataset.x_data.float().to(device))
229 | tgt_train_feats = target_model.feature_extractor(tgt_train_dl.dataset.x_data.float().to(device))
230 | src_valid_feats = target_model.feature_extractor(src_valid_dl.dataset.x_data.float().to(device))
231 | src_valid_pred = target_model.classifier(src_valid_feats)
232 |
233 | dev_weights = get_weight_gpu(src_train_feats.to(device), tgt_train_feats.to(device),
234 | src_valid_feats.to(device), configs, device)
235 | dev_error = F.cross_entropy(src_valid_pred, src_valid_dl.dataset.y_data.long().to(device), reduction='none')
236 | dev_risk = get_dev_value(dev_weights, dev_error.unsqueeze(1).detach().cpu().numpy())
237 | # iwcv_risk = get_iwcv_value(dev_weights, dev_error.unsqueeze(1).detach().cpu().numpy())
238 | return dev_risk
239 |
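# Usage sketch: DEV risk enables unsupervised model selection, e.g. keeping
# the candidate with the lowest estimated target risk. `candidate_models` is
# a hypothetical list of trained models:
#
#   risks = [calc_dev_risk(m, src_train_dl, tgt_train_dl, src_valid_dl, configs, device)
#            for m in candidate_models]
#   chosen = candidate_models[int(np.argmin(risks))]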
240 |
241 | def calculate_risk(target_model, risk_dataloader, device):
242 |     if isinstance(risk_dataloader, tuple):
243 |         x_data = torch.cat((risk_dataloader[0].dataset.x_data, risk_dataloader[1].dataset.x_data), dim=0)
244 |         y_data = torch.cat((risk_dataloader[0].dataset.y_data, risk_dataloader[1].dataset.y_data), dim=0)
245 | else:
246 | x_data = risk_dataloader.dataset.x_data
247 | y_data = risk_dataloader.dataset.y_data
248 |
249 | feat = target_model.feature_extractor(x_data.float().to(device))
250 | pred = target_model.classifier(feat)
251 | cls_loss = F.cross_entropy(pred, y_data.long().to(device))
252 | return cls_loss.item()
253 |
254 |
255 | class DictAsObject:
256 | def __init__(self, d):
257 | self.__dict__ = d
258 |
259 | def __getattr__(self, name):
260 | try:
261 | return self.__dict__[name]
262 | except KeyError:
263 | raise AttributeError(f"'DictAsObject' object has no attribute '{name}'")
264 |
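# Usage sketch: expose a plain hyper-parameter dict through attribute access:
#
#   hparams = DictAsObject({"learning_rate": 1e-3, "batch_size": 32})
#   hparams.learning_rate   # -> 0.001
#   hparams.momentum        # -> AttributeError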
265 |
266 | # For DIRT-T: exponential moving average (EMA) of model parameters
267 | class EMA:
268 | def __init__(self, decay):
269 | self.decay = decay
270 | self.shadow = {}
271 |
272 | def register(self, model):
273 | for name, param in model.named_parameters():
274 | if param.requires_grad:
275 | self.shadow[name] = param.data.clone()
276 | self.params = self.shadow.keys()
277 |
278 | def __call__(self, model):
279 | if self.decay > 0:
280 | for name, param in model.named_parameters():
281 | if name in self.params and param.requires_grad:
282 | self.shadow[name] -= (1 - self.decay) * (self.shadow[name] - param.data)
283 | param.data = self.shadow[name]
284 |
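# Usage sketch of the intended DIRT-T pattern: register once, then call the
# EMA after each optimizer step; the smoothed weights are written back into
# the model. `model`, `criterion`, `optimizer`, and `loader` are placeholders:
#
#   ema = EMA(decay=0.998)
#   ema.register(model)
#   for x, y in loader:
#       criterion(model(x), y).backward()
#       optimizer.step()
#       optimizer.zero_grad()
#       ema(model)   # shadow <- decay * shadow + (1 - decay) * param; param <- shadow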
285 |
286 |
--------------------------------------------------------------------------------