├── .DS_Store ├── .gitignore ├── .idea ├── .gitignore ├── AdaTime.iml ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml ├── modules.xml └── vcs.xml ├── LICENSE ├── README.md ├── algorithms ├── __pycache__ │ ├── algorithms.cpython-310.pyc │ ├── algorithms.cpython-38.pyc │ └── algorithms.cpython-39.pyc └── algorithms.py ├── configs ├── __pycache__ │ ├── data_model_configs.cpython-310.pyc │ ├── data_model_configs.cpython-38.pyc │ ├── data_model_configs.cpython-39.pyc │ ├── hparams.cpython-310.pyc │ ├── hparams.cpython-38.pyc │ ├── hparams.cpython-39.pyc │ ├── sweep_params.cpython-310.pyc │ ├── sweep_params.cpython-38.pyc │ └── sweep_params.cpython-39.pyc ├── data_model_configs.py ├── hparams.py └── sweep_params.py ├── dataloader ├── __pycache__ │ ├── dataloader.cpython-310.pyc │ ├── dataloader.cpython-38.pyc │ └── dataloader.cpython-39.pyc └── dataloader.py ├── main.py ├── main_sweep.py ├── misc ├── adatime.PNG └── results.PNG ├── models ├── __pycache__ │ ├── loss.cpython-310.pyc │ ├── loss.cpython-38.pyc │ ├── loss.cpython-39.pyc │ ├── models.cpython-310.pyc │ ├── models.cpython-38.pyc │ ├── models.cpython-39.pyc │ ├── resnet18.cpython-310.pyc │ ├── resnet18.cpython-38.pyc │ └── resnet18.cpython-39.pyc ├── loss.py ├── models.py └── resnet18.py ├── trainers ├── abstract_trainer.py ├── sweep.py └── train.py └── utils.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | wandb/ 3 | experiments_logs/ 4 | data/ 5 | *.pyc 6 | main.py -------------------------------------------------------------------------------- /.idea/.gitignore: 
-------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/AdaTime.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 116 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Emadeldeen Eldele 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the 
Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [TKDD 2023] AdaTime: A Benchmarking Suite for Domain Adaptation on Time Series Data [[Paper](https://arxiv.org/abs/2203.08321)] [[Cite](#citation)] 2 | #### *by: Mohamed Ragab\*, Emadeldeen Eldele\*, Wee Ling Tan, Chuan-Sheng Foo, Zhenghua Chen, Min Wu, Chee Kwoh, Xiaoli Li*
* Equal contribution
☨ Corresponding author 3 | 4 | ## Published in the [ACM Transactions on Knowledge Discovery from Data (TKDD)](https://dl.acm.org/doi/10.1145/3587937). 5 | **AdaTime** is a PyTorch suite to systematically and fairly evaluate different domain adaptation methods on time series data. 6 | 7 |

8 | 9 |

10 | 11 | ## Requirements: 12 | - Python3 13 | - Pytorch==1.7 14 | - Numpy==1.20.1 15 | - scikit-learn==0.24.1 16 | - Pandas==1.2.4 17 | - skorch==0.10.0 (for DEV risk calculations) 18 | - openpyxl==3.0.7 (for classification reports) 19 | - Wandb==0.12.7 (for sweeps) 20 | 21 | ## Datasets 22 | 23 | ### Available Datasets 24 | We used five public datasets in this study. We also provide the **preprocessed** versions as follows: 25 | - [Sleep-EDF](https://researchdata.ntu.edu.sg/dataset.xhtml?persistentId=doi:10.21979/N9/UD1IM9) 26 | - [UCIHAR](https://researchdata.ntu.edu.sg/dataset.xhtml?persistentId=doi:10.21979/N9/0SYHTZ) 27 | - [HHAR](https://researchdata.ntu.edu.sg/dataset.xhtml?persistentId=doi:10.21979/N9/OWDFXO) 28 | - [WISDM](https://researchdata.ntu.edu.sg/dataset.xhtml?persistentId=doi:10.21979/N9/KJWE5B) 29 | - [FD](https://researchdata.ntu.edu.sg/dataset.xhtml?persistentId=doi:10.21979/N9/PU85XN) 30 | 31 | ### Adding New Dataset 32 | 33 | #### Structure of data 34 | To add a new dataset (*e.g.,* NewData), place it in a folder named NewData in the datasets directory. 35 | 36 | Since "NewData" has several domains, each domain should be split into train/test splits with the naming style 37 | "train_*x*.pt" and "test_*x*.pt". 38 | 39 | The data files should be in dictionary form as follows: 40 | `train.pt = {"samples": data, "labels": labels}`, and similarly for `test.pt`. 41 | 42 | #### Configurations 43 | Next, you have to add a class with the name NewData in the `configs/data_model_configs.py` file. 44 | You can find similar classes for existing datasets as guidelines. 45 | Also, you have to specify the cross-domain scenarios in the `self.scenarios` variable. 46 | 47 | Last, you have to add another class with the name NewData in the `configs/hparams.py` file to specify 48 | the training parameters. 
49 | 50 | 51 | ## Domain Adaptation Algorithms 52 | ### Existing Algorithms 53 | - [Deep Coral](https://arxiv.org/abs/1607.01719) 54 | - [MMDA](https://arxiv.org/abs/1901.00282) 55 | - [DANN](https://arxiv.org/abs/1505.07818) 56 | - [CDAN](https://arxiv.org/abs/1705.10667) 57 | - [DIRT-T](https://arxiv.org/abs/1802.08735) 58 | - [DSAN](https://ieeexplore.ieee.org/document/9085896) 59 | - [HoMM](https://arxiv.org/pdf/1912.11976.pdf) 60 | - [DDC](https://arxiv.org/abs/1412.3474) 61 | - [CoDATS](https://arxiv.org/pdf/2005.10996.pdf) 62 | - [AdvSKM](https://www.ijcai.org/proceedings/2021/0378.pdf) 63 | - [SASA](https://ojs.aaai.org/index.php/AAAI/article/view/16846/16653) 64 | - [CoTMix](https://arxiv.org/pdf/2212.01555.pdf) 65 | 66 | 67 | ### Adding New Algorithm 68 | To add a new algorithm, implement it as a subclass of `Algorithm` (overriding `training_epoch()`) in the `algorithms/algorithms.py` file; algorithms are looked up by their class name. 69 | 70 | 71 | ## Training procedure 72 | 73 | The experiments are organised in a hierarchical way such that: 74 | - Several experiments are collected under one directory assigned by `--experiment_description`. 75 | - Each experiment can have different trials, each specified by `--run_description`. 76 | - For example, if we want to experiment with different UDA methods using a CNN backbone, we can assign 77 | `--experiment_description CNN_backbones --run_description DANN` and `--experiment_description CNN_backbones --run_description DDC` and so on. 
78 | 79 | ### Training a model 80 | 81 | To train a model: 82 | 83 | ``` 84 | python main.py --phase train \ 85 | --experiment_description exp1 \ 86 | --da_method DANN \ 87 | --dataset HHAR \ 88 | --backbone CNN \ 89 | --num_runs 5 90 | ``` 91 | To test a model: 92 | 93 | ``` 94 | python main.py --phase test \ 95 | --experiment_description exp1 \ 96 | --da_method DANN \ 97 | --dataset HHAR \ 98 | --backbone CNN \ 99 | --num_runs 5 100 | ``` 101 | ### Launching a sweep 102 | Sweeps here are deployed on [Wandb](https://wandb.ai/), which makes it easier to visualize and follow the training progress, organize sweeps, and collect results. 103 | 104 | ``` 105 | python main_sweep.py --experiment_description exp1_sweep \ 106 | --run_description sweep_over_lr \ 107 | --da_method DANN \ 108 | --dataset HHAR \ 109 | --backbone CNN \ 110 | --num_runs 5 \ 111 | --sweep_project_wandb TEST \ 112 | --num_sweeps 50 113 | ``` 114 | Once the sweep starts, you can follow its progress on the specified project page in Wandb. 115 | 116 | `Note:` If you get a CUDA out-of-memory error during testing, it is probably due to the DEV risk calculations. 117 | 118 | 119 | ### Upper and Lower bounds 120 | - To obtain the source-only (lower bound) results, i.e., train on source and test on target, set the da_method to `NO_ADAPT`. 121 | - To obtain the target-only (upper bound) results, i.e., train on target and test on target, set the da_method to `TARGET_ONLY`. 122 | 123 | ## Results 124 | - Each run will have all the cross-domain scenarios' results in the format `src_to_trg_run_x`, where `x` 125 | is the run_id (you can have multiple runs by setting the `--num_runs` arg). 126 | - Under each directory, you will find the classification report, a log file, a checkpoint, 127 | and the different risk scores. 128 | - At the end of all the runs, you will find the overall average and std results in the run directory. 129 | 130 | 

132 | 133 |

134 | 135 | 136 | ## Citation 137 | If you found this work useful, please consider citing it. 138 | ``` 139 | @article{adatime, 140 | author = {Ragab, Mohamed and Eldele, Emadeldeen and Tan, Wee Ling and Foo, Chuan-Sheng and Chen, Zhenghua and Wu, Min and Kwoh, Chee-Keong and Li, Xiaoli}, 141 | title = {ADATIME: A Benchmarking Suite for Domain Adaptation on Time Series Data}, 142 | year = {2023}, 143 | publisher = {Association for Computing Machinery}, 144 | address = {New York, NY, USA}, 145 | issn = {1556-4681}, 146 | url = {https://doi.org/10.1145/3587937}, 147 | doi = {10.1145/3587937}, 148 | journal = {ACM Trans. Knowl. Discov. Data}, 149 | month = {mar} 150 | } 151 | ``` 152 | 153 | 154 | ## Contact 155 | For any issues/questions regarding the paper or reproducing the results, please contact any of the following. 156 | 157 | Mohamed Ragab: *mohamedr002{at}e.ntu.edu.sg* 158 | 159 | Emadeldeen Eldele: *emad0002{at}e.ntu.edu.sg* 160 | 161 | School of Computer Science and Engineering (SCSE), 162 | Nanyang Technological University (NTU), Singapore. 
163 | -------------------------------------------------------------------------------- /algorithms/__pycache__/algorithms.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/algorithms/__pycache__/algorithms.cpython-310.pyc -------------------------------------------------------------------------------- /algorithms/__pycache__/algorithms.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/algorithms/__pycache__/algorithms.cpython-38.pyc -------------------------------------------------------------------------------- /algorithms/__pycache__/algorithms.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/algorithms/__pycache__/algorithms.cpython-39.pyc -------------------------------------------------------------------------------- /algorithms/algorithms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import itertools 5 | 6 | from models.models import classifier, ReverseLayerF, Discriminator, RandomLayer, Discriminator_CDAN, \ 7 | codats_classifier, AdvSKM_Disc, CNN_ATTN 8 | from models.loss import MMD_loss, CORAL, ConditionalEntropyLoss, VAT, LMMD_loss, HoMM_loss, NTXentLoss, SupConLoss 9 | from utils import EMA 10 | from torch.optim.lr_scheduler import StepLR 11 | from copy import deepcopy 12 | import torch.nn. 
functional as F 13 | 14 | def get_algorithm_class(algorithm_name): 15 | """Return the algorithm class with the given name.""" 16 | if algorithm_name not in globals(): 17 | raise NotImplementedError("Algorithm not found: {}".format(algorithm_name)) 18 | return globals()[algorithm_name] 19 | 20 | 21 | class Algorithm(torch.nn.Module): 22 | """ 23 | A subclass of Algorithm implements a domain adaptation algorithm. 24 | Subclasses should implement the update() method. 25 | """ 26 | 27 | def __init__(self, configs, backbone): 28 | super(Algorithm, self).__init__() 29 | self.configs = configs 30 | 31 | self.cross_entropy = nn.CrossEntropyLoss() 32 | self.feature_extractor = backbone(configs) 33 | self.classifier = classifier(configs) 34 | self.network = nn.Sequential(self.feature_extractor, self.classifier) 35 | 36 | 37 | # update function is common to all algorithms 38 | def update(self, src_loader, trg_loader, avg_meter, logger): 39 | # defining best and last model 40 | best_src_risk = float('inf') 41 | best_model = None 42 | 43 | for epoch in range(1, self.hparams["num_epochs"] + 1): 44 | 45 | # training loop 46 | self.training_epoch(src_loader, trg_loader, avg_meter, epoch) 47 | 48 | # saving the best model based on src risk 49 | if (epoch + 1) % 10 == 0 and avg_meter['Src_cls_loss'].avg < best_src_risk: 50 | best_src_risk = avg_meter['Src_cls_loss'].avg 51 | best_model = deepcopy(self.network.state_dict()) 52 | 53 | 54 | logger.debug(f'[Epoch : {epoch}/{self.hparams["num_epochs"]}]') 55 | for key, val in avg_meter.items(): 56 | logger.debug(f'{key}\t: {val.avg:2.4f}') 57 | logger.debug(f'-------------------------------------') 58 | 59 | last_model = self.network.state_dict() 60 | 61 | return last_model, best_model 62 | 63 | # train loop vary from one method to another 64 | def training_epoch(self, *args, **kwargs): 65 | raise NotImplementedError 66 | 67 | 68 | class NO_ADAPT(Algorithm): 69 | """ 70 | Lower bound: train on source and test on target. 
71 | """ 72 | def __init__(self, backbone, configs, hparams, device): 73 | super().__init__(configs, backbone) 74 | 75 | # optimizer and scheduler 76 | self.optimizer = torch.optim.Adam( 77 | self.network.parameters(), 78 | lr=hparams["learning_rate"], 79 | weight_decay=hparams["weight_decay"] 80 | ) 81 | self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay']) 82 | # hparams 83 | self.hparams = hparams 84 | # device 85 | self.device = device 86 | 87 | def training_epoch(self,src_loader, trg_loader, avg_meter, epoch): 88 | for src_x, src_y in src_loader: 89 | 90 | src_x, src_y = src_x.to(self.device), src_y.to(self.device) 91 | src_feat = self.feature_extractor(src_x) 92 | src_pred = self.classifier(src_feat) 93 | 94 | src_cls_loss = self.cross_entropy(src_pred, src_y) 95 | 96 | loss = src_cls_loss 97 | 98 | self.optimizer.zero_grad() 99 | loss.backward() 100 | self.optimizer.step() 101 | 102 | losses = {'Src_cls_loss': src_cls_loss.item()} 103 | 104 | for key, val in losses.items(): 105 | avg_meter[key].update(val, 32) 106 | 107 | self.lr_scheduler.step() 108 | 109 | 110 | class TARGET_ONLY(Algorithm): 111 | """ 112 | Upper bound: train on target and test on target. 
113 | """ 114 | 115 | def __init__(self, backbone, configs, hparams, device): 116 | super().__init__(configs, backbone) 117 | 118 | # optimizer and scheduler 119 | self.optimizer = torch.optim.Adam( 120 | self.network.parameters(), 121 | lr=hparams["learning_rate"], 122 | weight_decay=hparams["weight_decay"] 123 | ) 124 | self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay']) 125 | # hparams 126 | self.hparams = hparams 127 | # device 128 | self.device = device 129 | 130 | def training_epoch(self, src_loader, trg_loader, avg_meter, epoch): 131 | 132 | for trg_x, trg_y in trg_loader: 133 | 134 | trg_x, trg_y = trg_x.to(self.device), trg_y.to(self.device) 135 | 136 | trg_feat = self.feature_extractor(trg_x) 137 | trg_pred = self.classifier(trg_feat) 138 | 139 | trg_cls_loss = self.cross_entropy(trg_pred, trg_y) 140 | 141 | loss = trg_cls_loss 142 | 143 | self.optimizer.zero_grad() 144 | loss.backward() 145 | self.optimizer.step() 146 | 147 | losses = {'Trg_cls_loss': trg_cls_loss.item()} 148 | 149 | for key, val in losses.items(): 150 | avg_meter[key].update(val, 32) 151 | 152 | self.lr_scheduler.step() 153 | 154 | 155 | class Deep_Coral(Algorithm): 156 | """ 157 | Deep Coral: https://arxiv.org/abs/1607.01719 158 | """ 159 | def __init__(self, backbone, configs, hparams, device): 160 | super().__init__(configs, backbone) 161 | 162 | # optimizer and scheduler 163 | self.optimizer = torch.optim.Adam( 164 | self.network.parameters(), 165 | lr=hparams["learning_rate"], 166 | weight_decay=hparams["weight_decay"] 167 | ) 168 | self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay']) 169 | # hparams 170 | self.hparams = hparams 171 | # device 172 | self.device = device 173 | 174 | # correlation alignment loss 175 | self.coral = CORAL() 176 | 177 | 178 | def training_epoch(self,src_loader, trg_loader, avg_meter, epoch): 179 | 180 | # Construct Joint Loaders 181 | # add if statement 
182 | 183 | if len(src_loader) > len(trg_loader): 184 | joint_loader =enumerate(zip(src_loader, itertools.cycle(trg_loader))) 185 | else: 186 | joint_loader =enumerate(zip(itertools.cycle(src_loader), trg_loader)) 187 | 188 | for step, ((src_x, src_y), (trg_x, _)) in joint_loader: 189 | src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(self.device) 190 | 191 | src_feat = self.feature_extractor(src_x) 192 | src_pred = self.classifier(src_feat) 193 | 194 | src_cls_loss = self.cross_entropy(src_pred, src_y) 195 | 196 | trg_feat = self.feature_extractor(trg_x) 197 | 198 | coral_loss = self.coral(src_feat, trg_feat) 199 | 200 | loss = self.hparams["coral_wt"] * coral_loss + \ 201 | self.hparams["src_cls_loss_wt"] * src_cls_loss 202 | 203 | self.optimizer.zero_grad() 204 | loss.backward() 205 | self.optimizer.step() 206 | 207 | losses = {'Total_loss': loss.item(), 'Src_cls_loss': src_cls_loss.item(), 208 | 'coral_loss': coral_loss.item()} 209 | 210 | for key, val in losses.items(): 211 | avg_meter[key].update(val, 32) 212 | 213 | self.lr_scheduler.step() 214 | 215 | class MMDA(Algorithm): 216 | """ 217 | MMDA: https://arxiv.org/abs/1901.00282 218 | """ 219 | 220 | def __init__(self, backbone, configs, hparams, device): 221 | super().__init__(configs, backbone) 222 | 223 | # optimizer and scheduler 224 | self.optimizer = torch.optim.Adam( 225 | self.network.parameters(), 226 | lr=hparams["learning_rate"], 227 | weight_decay=hparams["weight_decay"] 228 | ) 229 | self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay']) 230 | # hparams 231 | self.hparams = hparams 232 | # device 233 | self.device = device 234 | 235 | # Aligment losses 236 | self.mmd = MMD_loss() 237 | self.coral = CORAL() 238 | self.cond_ent = ConditionalEntropyLoss() 239 | 240 | 241 | def training_epoch(self,src_loader, trg_loader, avg_meter, epoch): 242 | 243 | # Construct Joint Loaders 244 | joint_loader =enumerate(zip(src_loader, 
itertools.cycle(trg_loader))) 245 | 246 | for step, ((src_x, src_y), (trg_x, _)) in joint_loader: 247 | src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(self.device) 248 | 249 | src_feat = self.feature_extractor(src_x) 250 | src_pred = self.classifier(src_feat) 251 | 252 | src_cls_loss = self.cross_entropy(src_pred, src_y) 253 | 254 | trg_feat = self.feature_extractor(trg_x) 255 | src_feat = self.feature_extractor(src_x) 256 | src_pred = self.classifier(src_feat) 257 | 258 | src_cls_loss = self.cross_entropy(src_pred, src_y) 259 | 260 | trg_feat = self.feature_extractor(trg_x) 261 | 262 | coral_loss = self.coral(src_feat, trg_feat) 263 | mmd_loss = self.mmd(src_feat, trg_feat) 264 | cond_ent_loss = self.cond_ent(trg_feat) 265 | 266 | loss = self.hparams["coral_wt"] * coral_loss + \ 267 | self.hparams["mmd_wt"] * mmd_loss + \ 268 | self.hparams["cond_ent_wt"] * cond_ent_loss + \ 269 | self.hparams["src_cls_loss_wt"] * src_cls_loss 270 | 271 | self.optimizer.zero_grad() 272 | loss.backward() 273 | self.optimizer.step() 274 | 275 | losses = {'Total_loss': loss.item(), 'Coral_loss': coral_loss.item(), 'MMD_loss': mmd_loss.item(), 276 | 'cond_ent_wt': cond_ent_loss.item(), 'Src_cls_loss': src_cls_loss.item()} 277 | 278 | for key, val in losses.items(): 279 | avg_meter[key].update(val, 32) 280 | 281 | self.lr_scheduler.step() 282 | 283 | 284 | class DANN(Algorithm): 285 | """ 286 | DANN: https://arxiv.org/abs/1505.07818 287 | """ 288 | 289 | def __init__(self, backbone, configs, hparams, device): 290 | super().__init__(configs, backbone) 291 | 292 | 293 | # optimizer and scheduler 294 | self.optimizer = torch.optim.Adam( 295 | self.network.parameters(), 296 | lr=hparams["learning_rate"], 297 | weight_decay=hparams["weight_decay"] 298 | ) 299 | self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay']) 300 | # hparams 301 | self.hparams = hparams 302 | # device 303 | self.device = device 304 | 305 | # 
Domain Discriminator 306 | self.domain_classifier = Discriminator(configs) 307 | self.optimizer_disc = torch.optim.Adam( 308 | self.domain_classifier.parameters(), 309 | lr=hparams["learning_rate"], 310 | weight_decay=hparams["weight_decay"], betas=(0.5, 0.99) 311 | ) 312 | 313 | def training_epoch(self,src_loader, trg_loader, avg_meter, epoch): 314 | # Combine dataloaders 315 | # Method 1 (min len of both domains) 316 | # joint_loader = enumerate(zip(src_loader, trg_loader)) 317 | 318 | # Method 2 (max len of both domains) 319 | # joint_loader =enumerate(zip(src_loader, itertools.cycle(trg_loader))) 320 | joint_loader =enumerate(zip(src_loader, itertools.cycle(trg_loader))) 321 | num_batches = max(len(src_loader), len(trg_loader)) 322 | 323 | for step, ((src_x, src_y), (trg_x, _)) in joint_loader: 324 | 325 | src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(self.device) 326 | 327 | p = float(step + epoch * num_batches) / self.hparams["num_epochs"] + 1 / num_batches 328 | alpha = 2. / (1. 
+ np.exp(-10 * p)) - 1 329 | 330 | # zero grad 331 | self.optimizer.zero_grad() 332 | self.optimizer_disc.zero_grad() 333 | 334 | domain_label_src = torch.ones(len(src_x)).to(self.device) 335 | domain_label_trg = torch.zeros(len(trg_x)).to(self.device) 336 | 337 | src_feat = self.feature_extractor(src_x) 338 | src_pred = self.classifier(src_feat) 339 | 340 | trg_feat = self.feature_extractor(trg_x) 341 | 342 | # Task classification Loss 343 | src_cls_loss = self.cross_entropy(src_pred.squeeze(), src_y) 344 | 345 | # Domain classification loss 346 | # source 347 | src_feat_reversed = ReverseLayerF.apply(src_feat, alpha) 348 | src_domain_pred = self.domain_classifier(src_feat_reversed) 349 | src_domain_loss = self.cross_entropy(src_domain_pred, domain_label_src.long()) 350 | 351 | # target 352 | trg_feat_reversed = ReverseLayerF.apply(trg_feat, alpha) 353 | trg_domain_pred = self.domain_classifier(trg_feat_reversed) 354 | trg_domain_loss = self.cross_entropy(trg_domain_pred, domain_label_trg.long()) 355 | 356 | # Total domain loss 357 | domain_loss = src_domain_loss + trg_domain_loss 358 | 359 | loss = self.hparams["src_cls_loss_wt"] * src_cls_loss + \ 360 | self.hparams["domain_loss_wt"] * domain_loss 361 | 362 | loss.backward() 363 | self.optimizer.step() 364 | self.optimizer_disc.step() 365 | 366 | losses = {'Total_loss': loss.item(), 'Domain_loss': domain_loss.item(), 'Src_cls_loss': src_cls_loss.item()} 367 | 368 | for key, val in losses.items(): 369 | avg_meter[key].update(val, 32) 370 | 371 | self.lr_scheduler.step() 372 | 373 | class CDAN(Algorithm): 374 | """ 375 | CDAN: https://arxiv.org/abs/1705.10667 376 | """ 377 | 378 | def __init__(self, backbone, configs, hparams, device): 379 | super().__init__(configs, backbone) 380 | 381 | 382 | # optimizer and scheduler 383 | self.optimizer = torch.optim.Adam( 384 | self.network.parameters(), 385 | lr=hparams["learning_rate"], 386 | weight_decay=hparams["weight_decay"] 387 | ) 388 | self.lr_scheduler = 
StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay']) 389 | # hparams 390 | self.hparams = hparams 391 | # device 392 | self.device = device 393 | 394 | # Aligment Losses 395 | self.criterion_cond = ConditionalEntropyLoss().to(device) 396 | 397 | self.domain_classifier = Discriminator_CDAN(configs) 398 | self.random_layer = RandomLayer([configs.features_len * configs.final_out_channels, configs.num_classes], 399 | configs.features_len * configs.final_out_channels) 400 | self.optimizer_disc = torch.optim.Adam( 401 | self.domain_classifier.parameters(), 402 | lr=hparams["learning_rate"], 403 | weight_decay=hparams["weight_decay"]) 404 | 405 | def training_epoch(self,src_loader, trg_loader, avg_meter, epoch): 406 | 407 | # Construct Joint Loaders 408 | joint_loader =enumerate(zip(src_loader, itertools.cycle(trg_loader))) 409 | 410 | for step, ((src_x, src_y), (trg_x, _)) in joint_loader: 411 | src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(self.device) 412 | # prepare true domain labels 413 | domain_label_src = torch.ones(len(src_x)).to(self.device) 414 | domain_label_trg = torch.zeros(len(trg_x)).to(self.device) 415 | domain_label_concat = torch.cat((domain_label_src, domain_label_trg), 0).long() 416 | 417 | # source features and predictions 418 | src_feat = self.feature_extractor(src_x) 419 | src_pred = self.classifier(src_feat) 420 | 421 | # target features and predictions 422 | trg_feat = self.feature_extractor(trg_x) 423 | trg_pred = self.classifier(trg_feat) 424 | 425 | # concatenate features and predictions 426 | feat_concat = torch.cat((src_feat, trg_feat), dim=0) 427 | pred_concat = torch.cat((src_pred, trg_pred), dim=0) 428 | 429 | # Domain classification loss 430 | feat_x_pred = torch.bmm(pred_concat.unsqueeze(2), feat_concat.unsqueeze(1)).detach() 431 | disc_prediction = self.domain_classifier(feat_x_pred.view(-1, pred_concat.size(1) * feat_concat.size(1))) 432 | disc_loss = 
self.cross_entropy(disc_prediction, domain_label_concat) 433 | 434 | # update Domain classification 435 | self.optimizer_disc.zero_grad() 436 | disc_loss.backward() 437 | self.optimizer_disc.step() 438 | 439 | # prepare fake domain labels for training the feature extractor 440 | domain_label_src = torch.zeros(len(src_x)).long().to(self.device) 441 | domain_label_trg = torch.ones(len(trg_x)).long().to(self.device) 442 | domain_label_concat = torch.cat((domain_label_src, domain_label_trg), 0) 443 | 444 | # Repeat predictions after updating discriminator 445 | feat_x_pred = torch.bmm(pred_concat.unsqueeze(2), feat_concat.unsqueeze(1)) 446 | disc_prediction = self.domain_classifier(feat_x_pred.view(-1, pred_concat.size(1) * feat_concat.size(1))) 447 | # loss of domain discriminator according to fake labels 448 | 449 | domain_loss = self.cross_entropy(disc_prediction, domain_label_concat) 450 | 451 | # Task classification Loss 452 | src_cls_loss = self.cross_entropy(src_pred.squeeze(), src_y) 453 | 454 | # conditional entropy loss. 
455 | loss_trg_cent = self.criterion_cond(trg_pred) 456 | 457 | # total loss 458 | loss = self.hparams["src_cls_loss_wt"] * src_cls_loss + self.hparams["domain_loss_wt"] * domain_loss + \ 459 | self.hparams["cond_ent_wt"] * loss_trg_cent 460 | 461 | # update feature extractor 462 | self.optimizer.zero_grad() 463 | loss.backward() 464 | self.optimizer.step() 465 | 466 | losses = {'Total_loss': loss.item(), 'Domain_loss': domain_loss.item(), 'Src_cls_loss': src_cls_loss.item(), 467 | 'cond_ent_loss': loss_trg_cent.item()} 468 | 469 | for key, val in losses.items(): 470 | avg_meter[key].update(val, 32) 471 | self.lr_scheduler.step() 472 | 473 | class DIRT(Algorithm): 474 | """ 475 | DIRT-T: https://arxiv.org/abs/1802.08735 476 | """ 477 | 478 | def __init__(self, backbone, configs, hparams, device): 479 | super().__init__(configs, backbone) 480 | 481 | # optimizer and scheduler 482 | self.optimizer = torch.optim.Adam( 483 | self.network.parameters(), 484 | lr=hparams["learning_rate"], 485 | weight_decay=hparams["weight_decay"] 486 | ) 487 | self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay']) 488 | # hparams 489 | self.hparams = hparams 490 | # device 491 | self.device = device 492 | 493 | 494 | # Alignment losses 495 | self.criterion_cond = ConditionalEntropyLoss().to(device) 496 | self.vat_loss = VAT(self.network, device).to(device) 497 | self.ema = EMA(0.998) 498 | self.ema.register(self.network) 499 | 500 | # Discriminator 501 | self.domain_classifier = Discriminator(configs) 502 | self.optimizer_disc = torch.optim.Adam( 503 | self.domain_classifier.parameters(), 504 | lr=hparams["learning_rate"], 505 | weight_decay=hparams["weight_decay"] 506 | ) 507 | 508 | def training_epoch(self,src_loader, trg_loader, avg_meter, epoch): 509 | 510 | # Construct Joint Loaders 511 | joint_loader =enumerate(zip(src_loader, itertools.cycle(trg_loader))) 512 | 513 | for step, ((src_x, src_y), (trg_x, _)) in joint_loader: 514 | src_x,
src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(self.device) 515 | # prepare true domain labels 516 | domain_label_src = torch.ones(len(src_x)).to(self.device) 517 | domain_label_trg = torch.zeros(len(trg_x)).to(self.device) 518 | domain_label_concat = torch.cat((domain_label_src, domain_label_trg), 0).long() 519 | 520 | src_feat = self.feature_extractor(src_x) 521 | src_pred = self.classifier(src_feat) 522 | 523 | # target features and predictions 524 | trg_feat = self.feature_extractor(trg_x) 525 | trg_pred = self.classifier(trg_feat) 526 | 527 | # concatenate source and target features 528 | feat_concat = torch.cat((src_feat, trg_feat), dim=0) 529 | 530 | # Domain classification loss 531 | disc_prediction = self.domain_classifier(feat_concat.detach()) 532 | disc_loss = self.cross_entropy(disc_prediction, domain_label_concat) 533 | 534 | # update Domain classification 535 | self.optimizer_disc.zero_grad() 536 | disc_loss.backward() 537 | self.optimizer_disc.step() 538 | 539 | # prepare fake domain labels for training the feature extractor 540 | domain_label_src = torch.zeros(len(src_x)).long().to(self.device) 541 | domain_label_trg = torch.ones(len(trg_x)).long().to(self.device) 542 | domain_label_concat = torch.cat((domain_label_src, domain_label_trg), 0) 543 | 544 | # Repeat predictions after updating discriminator 545 | disc_prediction = self.domain_classifier(feat_concat) 546 | 547 | # loss of domain discriminator according to fake labels 548 | domain_loss = self.cross_entropy(disc_prediction, domain_label_concat) 549 | 550 | # Task classification Loss 551 | src_cls_loss = self.cross_entropy(src_pred.squeeze(), src_y) 552 | 553 | # conditional entropy loss.
554 | loss_trg_cent = self.criterion_cond(trg_pred) 555 | 556 | # Virtual adversarial training (VAT) loss 557 | loss_src_vat = self.vat_loss(src_x, src_pred) 558 | loss_trg_vat = self.vat_loss(trg_x, trg_pred) 559 | total_vat = loss_src_vat + loss_trg_vat 560 | # total loss 561 | loss = self.hparams["src_cls_loss_wt"] * src_cls_loss + self.hparams["domain_loss_wt"] * domain_loss + \ 562 | self.hparams["cond_ent_wt"] * loss_trg_cent + self.hparams["vat_loss_wt"] * total_vat 563 | 564 | # update exponential moving average (NOTE(review): this call happens before optimizer.step() below, so it sees the weights from before this batch's update -- confirm the ordering is intended) 565 | self.ema(self.network) 566 | 567 | # update feature extractor 568 | self.optimizer.zero_grad() 569 | loss.backward() 570 | self.optimizer.step() 571 | # NOTE(review): total_vat contributes to loss but is not logged below 572 | losses = {'Total_loss': loss.item(), 'Domain_loss': domain_loss.item(), 'Src_cls_loss': src_cls_loss.item(), 573 | 'cond_ent_loss': loss_trg_cent.item()} 574 | 575 | for key, val in losses.items(): 576 | avg_meter[key].update(val, 32) 577 | 578 | self.lr_scheduler.step() 579 | 580 | class DSAN(Algorithm): 581 | """ 582 | DSAN: https://ieeexplore.ieee.org/document/9085896 583 | """ 584 | 585 | def __init__(self, backbone, configs, hparams, device): 586 | super().__init__(configs, backbone) 587 | 588 | # optimizer and scheduler 589 | self.optimizer = torch.optim.Adam( 590 | self.network.parameters(), 591 | lr=hparams["learning_rate"], 592 | weight_decay=hparams["weight_decay"] 593 | ) 594 | self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay']) 595 | # hparams 596 | self.hparams = hparams 597 | # device 598 | self.device = device 599 | 600 | # Alignment losses 601 | self.loss_LMMD = LMMD_loss(device=device, class_num=configs.num_classes).to(device) 602 | 603 | def training_epoch(self,src_loader, trg_loader, avg_meter, epoch): 604 | 605 | # Construct Joint Loaders 606 | joint_loader =enumerate(zip(src_loader, itertools.cycle(trg_loader))) 607 | 608 | for step, ((src_x, src_y), (trg_x, _)) in joint_loader: 609 | src_x, src_y, trg_x =
src_x.to(self.device), src_y.to(self.device), trg_x.to(self.device) # move the batch to the device, then extract source features 610 | src_feat = self.feature_extractor(src_x) 611 | src_pred = self.classifier(src_feat) 612 | 613 | # extract target features 614 | trg_feat = self.feature_extractor(trg_x) 615 | trg_pred = self.classifier(trg_feat) 616 | 617 | # calculate lmmd loss 618 | domain_loss = self.loss_LMMD.get_loss(src_feat, trg_feat, src_y, torch.nn.functional.softmax(trg_pred, dim=1)) 619 | 620 | # calculate source classification loss 621 | src_cls_loss = self.cross_entropy(src_pred, src_y) 622 | 623 | # calculate the total loss 624 | loss = self.hparams["domain_loss_wt"] * domain_loss + \ 625 | self.hparams["src_cls_loss_wt"] * src_cls_loss 626 | 627 | self.optimizer.zero_grad() 628 | loss.backward() 629 | self.optimizer.step() 630 | 631 | losses = {'Total_loss': loss.item(), 'LMMD_loss': domain_loss.item(), 'Src_cls_loss': src_cls_loss.item()} 632 | 633 | for key, val in losses.items(): 634 | avg_meter[key].update(val, 32) 635 | 636 | self.lr_scheduler.step() 637 | 638 | class HoMM(Algorithm): 639 | """ 640 | HoMM: https://arxiv.org/pdf/1912.11976.pdf 641 | """ 642 | 643 | def __init__(self, backbone, configs, hparams, device): 644 | super().__init__(configs, backbone) 645 | 646 | # optimizer and scheduler 647 | self.optimizer = torch.optim.Adam( 648 | self.network.parameters(), 649 | lr=hparams["learning_rate"], 650 | weight_decay=hparams["weight_decay"] 651 | ) 652 | self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay']) 653 | # hparams 654 | self.hparams = hparams 655 | # device 656 | self.device = device 657 | 658 | # alignment losses (NOTE(review): self.coral is not used by training_epoch below) 659 | self.coral = CORAL() 660 | self.HoMM_loss = HoMM_loss() 661 | 662 | def training_epoch(self,src_loader, trg_loader, avg_meter, epoch): 663 | 664 | # Construct Joint Loaders 665 | joint_loader =enumerate(zip(src_loader, itertools.cycle(trg_loader))) 666 | 667 | for step, ((src_x, src_y), (trg_x, _)) in
joint_loader: 668 | src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(self.device) # move the batch to the device, then extract source features 669 | 670 | src_feat = self.feature_extractor(src_x) 671 | src_pred = self.classifier(src_feat) 672 | 673 | # extract target features 674 | trg_feat = self.feature_extractor(trg_x) 675 | trg_pred = self.classifier(trg_feat) 676 | # NOTE(review): trg_pred is not used below 677 | # calculate source classification loss 678 | src_cls_loss = self.cross_entropy(src_pred, src_y) 679 | 680 | # calculate HoMM loss 681 | domain_loss = self.HoMM_loss(src_feat, trg_feat) 682 | 683 | # calculate the total loss 684 | loss = self.hparams["domain_loss_wt"] * domain_loss + \ 685 | self.hparams["src_cls_loss_wt"] * src_cls_loss 686 | 687 | self.optimizer.zero_grad() 688 | loss.backward() 689 | self.optimizer.step() 690 | 691 | losses = {'Total_loss': loss.item(), 'HoMM_loss': domain_loss.item(), 'Src_cls_loss': src_cls_loss.item()} 692 | 693 | for key, val in losses.items(): 694 | avg_meter[key].update(val, 32) 695 | 696 | self.lr_scheduler.step() 697 | 698 | 699 | class DDC(Algorithm): 700 | """ 701 | DDC: https://arxiv.org/abs/1412.3474 702 | """ 703 | 704 | def __init__(self, backbone, configs, hparams, device): 705 | super().__init__(configs, backbone) 706 | 707 | # optimizer and scheduler 708 | self.optimizer = torch.optim.Adam( 709 | self.network.parameters(), 710 | lr=hparams["learning_rate"], 711 | weight_decay=hparams["weight_decay"] 712 | ) 713 | self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay']) 714 | # hparams 715 | self.hparams = hparams 716 | # device 717 | self.device = device 718 | 719 | # Alignment losses 720 | self.mmd_loss = MMD_loss() 721 | 722 | def training_epoch(self, src_loader, trg_loader, avg_meter, epoch): 723 | 724 | # Construct Joint Loaders 725 | joint_loader =enumerate(zip(src_loader, itertools.cycle(trg_loader))) 726 | 727 | for step, ((src_x, src_y), (trg_x, _)) in joint_loader: 728 | src_x, src_y, trg_x =
src_x.to(self.device), src_y.to(self.device), trg_x.to(self.device) # move the batch to the device 729 | # extract source features 730 | src_feat = self.feature_extractor(src_x) 731 | src_pred = self.classifier(src_feat) 732 | 733 | # extract target features 734 | trg_feat = self.feature_extractor(trg_x) 735 | 736 | # calculate source classification loss 737 | src_cls_loss = self.cross_entropy(src_pred, src_y) 738 | 739 | # calculate mmd loss 740 | domain_loss = self.mmd_loss(src_feat, trg_feat) 741 | 742 | # calculate the total loss 743 | loss = self.hparams["domain_loss_wt"] * domain_loss + \ 744 | self.hparams["src_cls_loss_wt"] * src_cls_loss 745 | 746 | self.optimizer.zero_grad() 747 | loss.backward() 748 | self.optimizer.step() 749 | 750 | losses = {'Total_loss': loss.item(), 'MMD_loss': domain_loss.item(), 'Src_cls_loss': src_cls_loss.item()} 751 | 752 | for key, val in losses.items(): 753 | avg_meter[key].update(val, 32) 754 | 755 | self.lr_scheduler.step() 756 | 757 | class CoDATS(Algorithm): 758 | """ 759 | CoDATS: https://arxiv.org/pdf/2005.10996.pdf 760 | """ 761 | 762 | def __init__(self, backbone, configs, hparams, device): 763 | super().__init__(configs, backbone) 764 | 765 | # we replace the original classifier with the CoDATS classifier 766 | # keep the attribute name self.classifier, as it is used for the model evaluation 767 | self.classifier = codats_classifier(configs) 768 | self.network = nn.Sequential(self.feature_extractor, self.classifier) 769 | 770 | # optimizer and scheduler 771 | self.optimizer = torch.optim.Adam( 772 | self.network.parameters(), 773 | lr=hparams["learning_rate"], 774 | weight_decay=hparams["weight_decay"] 775 | ) 776 | self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay']) 777 | # hparams 778 | self.hparams = hparams 779 | # device 780 | self.device = device 781 | 782 | 783 | # Domain classifier 784 | self.domain_classifier = Discriminator(configs) 785 | 786 |
self.optimizer_disc = torch.optim.Adam( 787 | self.domain_classifier.parameters(), 788 | lr=hparams["learning_rate"], 789 | weight_decay=hparams["weight_decay"], betas=(0.5, 0.99) 790 | ) 791 | 792 | def training_epoch(self,src_loader, trg_loader, avg_meter, epoch): 793 | 794 | # Construct Joint Loaders 795 | joint_loader =enumerate(zip(src_loader, itertools.cycle(trg_loader))) 796 | num_batches = max(len(src_loader), len(trg_loader)) 797 | for step, ((src_x, src_y), (trg_x, _)) in joint_loader: 798 | src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(self.device) # move the batch to the device 799 | # NOTE(review): by operator precedence this computes (step + epoch * num_batches) / num_epochs + 1 / num_batches, which can exceed 1; the usual DANN-style progress fraction is (step + epoch * num_batches) / (num_epochs * num_batches) -- confirm 800 | p = float(step + epoch * num_batches) / self.hparams["num_epochs"] + 1 / num_batches 801 | alpha = 2. / (1. + np.exp(-10 * p)) - 1 802 | 803 | # zero grad 804 | self.optimizer.zero_grad() 805 | self.optimizer_disc.zero_grad() 806 | 807 | domain_label_src = torch.ones(len(src_x)).to(self.device) 808 | domain_label_trg = torch.zeros(len(trg_x)).to(self.device) 809 | 810 | src_feat = self.feature_extractor(src_x) 811 | src_pred = self.classifier(src_feat) 812 | 813 | trg_feat = self.feature_extractor(trg_x) 814 | 815 | # Task classification Loss 816 | src_cls_loss = self.cross_entropy(src_pred.squeeze(), src_y) 817 | 818 | # Domain classification loss 819 | # source 820 | src_feat_reversed = ReverseLayerF.apply(src_feat, alpha) 821 | src_domain_pred = self.domain_classifier(src_feat_reversed) 822 | src_domain_loss = self.cross_entropy(src_domain_pred, domain_label_src.long()) 823 | 824 | # target 825 | trg_feat_reversed = ReverseLayerF.apply(trg_feat, alpha) 826 | trg_domain_pred = self.domain_classifier(trg_feat_reversed) 827 | trg_domain_loss = self.cross_entropy(trg_domain_pred, domain_label_trg.long()) 828 | 829 | # Total domain loss 830 | domain_loss = src_domain_loss + trg_domain_loss 831 | 832 | loss = self.hparams["src_cls_loss_wt"] * src_cls_loss + \ 833 | self.hparams["domain_loss_wt"] * domain_loss 834 | 835 | loss.backward() 836 |
self.optimizer.step() 837 | self.optimizer_disc.step() 838 | 839 | losses = {'Total_loss': loss.item(), 'Domain_loss': domain_loss.item(), 'Src_cls_loss': src_cls_loss.item()} 840 | for key, val in losses.items(): 841 | avg_meter[key].update(val, 32) 842 | 843 | self.lr_scheduler.step() 844 | 845 | class AdvSKM(Algorithm): 846 | """ 847 | AdvSKM: https://www.ijcai.org/proceedings/2021/0378.pdf 848 | """ 849 | 850 | def __init__(self, backbone, configs, hparams, device): 851 | super().__init__(configs, backbone) 852 | 853 | # optimizer and scheduler 854 | self.optimizer = torch.optim.Adam( 855 | self.network.parameters(), 856 | lr=hparams["learning_rate"], 857 | weight_decay=hparams["weight_decay"] 858 | ) 859 | self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay']) 860 | # hparams 861 | self.hparams = hparams 862 | # device 863 | self.device = device 864 | 865 | # Alignment losses 866 | self.mmd_loss = MMD_loss() 867 | self.AdvSKM_embedder = AdvSKM_Disc(configs).to(device) 868 | self.optimizer_disc = torch.optim.Adam( 869 | self.AdvSKM_embedder.parameters(), 870 | lr=hparams["learning_rate"], 871 | weight_decay=hparams["weight_decay"] 872 | ) 873 | 874 | def training_epoch(self,src_loader, trg_loader, avg_meter, epoch): 875 | 876 | # Construct Joint Loaders 877 | joint_loader =enumerate(zip(src_loader, itertools.cycle(trg_loader))) 878 | for step, ((src_x, src_y), (trg_x, _)) in joint_loader: 879 | src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(self.device) # move the batch to the device, then extract source features 880 | 881 | src_feat = self.feature_extractor(src_x) 882 | src_pred = self.classifier(src_feat) 883 | 884 | # extract target features 885 | trg_feat = self.feature_extractor(trg_x) 886 | # embedder step: minimize the negated MMD (i.e. maximize MMD) computed on detached features 887 | source_embedding_disc = self.AdvSKM_embedder(src_feat.detach()) 888 | target_embedding_disc = self.AdvSKM_embedder(trg_feat.detach()) 889 | mmd_loss = - self.mmd_loss(source_embedding_disc, target_embedding_disc) 890 |
mmd_loss.requires_grad = True 891 | # NOTE(review): PyTorch only permits assigning requires_grad on leaf tensors, and the assignment cannot add autograd history to an already-computed value; confirm what this (and the same assignment on mmd_loss_adv below) is meant to achieve 892 | # update discriminator 893 | self.optimizer_disc.zero_grad() 894 | mmd_loss.backward() 895 | self.optimizer_disc.step() 896 | 897 | # calculate source classification loss 898 | src_cls_loss = self.cross_entropy(src_pred, src_y) 899 | 900 | # domain loss: MMD between embeddings of the non-detached features, minimized together with the classification loss by self.optimizer 901 | source_embedding_disc = self.AdvSKM_embedder(src_feat) 902 | target_embedding_disc = self.AdvSKM_embedder(trg_feat) 903 | 904 | mmd_loss_adv = self.mmd_loss(source_embedding_disc, target_embedding_disc) 905 | mmd_loss_adv.requires_grad = True 906 | 907 | # calculate the total loss 908 | loss = self.hparams["domain_loss_wt"] * mmd_loss_adv + \ 909 | self.hparams["src_cls_loss_wt"] * src_cls_loss 910 | 911 | # update optimizer 912 | self.optimizer.zero_grad() 913 | loss.backward() 914 | self.optimizer.step() 915 | 916 | losses = {'Total_loss': loss.item(), 'MMD_loss': mmd_loss_adv.item(), 'Src_cls_loss': src_cls_loss.item()} 917 | for key, val in losses.items(): 918 | avg_meter[key].update(val, 32) 919 | 920 | self.lr_scheduler.step() 921 | 922 | class SASA(Algorithm): 923 | 924 | def __init__(self, backbone, configs, hparams, device): 925 | super().__init__(configs, backbone) 926 | 927 | # feature_length for classifier 928 | configs.features_len = 1 929 | self.classifier = classifier(configs) 930 | # feature length for feature extractor (NOTE(review): same value as above; both assignments mutate the configs object passed in) 931 | configs.features_len = 1 932 | self.feature_extractor = CNN_ATTN(configs) 933 | self.network = nn.Sequential(self.feature_extractor, self.classifier) 934 | 935 | # optimizer and scheduler 936 | self.optimizer = torch.optim.Adam( 937 | self.network.parameters(), 938 | lr=hparams["learning_rate"], 939 | weight_decay=hparams["weight_decay"] 940 | ) 941 | self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay']) 942 | # hparams 943 | self.hparams = hparams 944 | # device 945 | self.device = device 946 | 947 | 948 | def training_epoch(self,src_loader, trg_loader, avg_meter, epoch): 949 | 950
| # Construct Joint Loaders 951 | joint_loader =enumerate(zip(src_loader, itertools.cycle(trg_loader))) 952 | for step, ((src_x, src_y), (trg_x, _)) in joint_loader: 953 | src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(self.device) # move the batch to the device 954 | 955 | # Extract features 956 | src_feature = self.feature_extractor(src_x) 957 | tgt_feature = self.feature_extractor(trg_x) 958 | 959 | # source classification loss 960 | y_pred = self.classifier(src_feature) 961 | src_cls_loss = self.cross_entropy(y_pred, src_y) 962 | 963 | # MMD-style loss: L2 norm of the mean (over dim -2) source-target feature difference, scaled by domain_loss_wt (NOTE(review): the subtraction needs source and target batches of matching/broadcastable shape) 964 | domain_loss_intra = self.mmd_loss(src_struct=src_feature, 965 | tgt_struct=tgt_feature, weight=self.hparams['domain_loss_wt']) 966 | 967 | # total loss 968 | total_loss = self.hparams['src_cls_loss_wt'] * src_cls_loss + domain_loss_intra 969 | 970 | # remove old gradients 971 | self.optimizer.zero_grad() 972 | # calculate gradients 973 | total_loss.backward() 974 | # update the weights 975 | self.optimizer.step() 976 | 977 | losses = {'Total_loss': total_loss.item(), 'MMD_loss': domain_loss_intra.item(), 978 | 'Src_cls_loss': src_cls_loss.item()} 979 | for key, val in losses.items(): 980 | avg_meter[key].update(val, 32) 981 | 982 | self.lr_scheduler.step() 983 | def mmd_loss(self, src_struct, tgt_struct, weight): 984 | delta = torch.mean(src_struct - tgt_struct, dim=-2) 985 | loss_value = torch.norm(delta, 2) * weight 986 | return loss_value 987 | 988 | 989 | class CoTMix(Algorithm): 990 | def __init__(self, backbone, configs, hparams, device): 991 | super().__init__(configs, backbone) 992 | 993 | # optimizer and scheduler 994 | self.optimizer = torch.optim.Adam( 995 | self.network.parameters(), 996 | lr=hparams["learning_rate"], 997 | weight_decay=hparams["weight_decay"] 998 | ) 999 | self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay']) 1000 | # hparams 1001 | self.hparams = hparams 1002 | # device 1003 | self.device = device 1004 | 1005 |
# Alignment losses 1006 | self.contrastive_loss = NTXentLoss(device, hparams["batch_size"], 0.2, True) 1007 | self.entropy_loss = ConditionalEntropyLoss() 1008 | self.sup_contrastive_loss = SupConLoss(device) 1009 | 1010 | def training_epoch(self,src_loader, trg_loader, avg_meter, epoch): 1011 | 1012 | # Construct Joint Loaders 1013 | joint_loader =enumerate(zip(src_loader, itertools.cycle(trg_loader))) 1014 | for step, ((src_x, src_y), (trg_x, _)) in joint_loader: 1015 | src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(self.device) # move the batch to the device 1016 | 1017 | # ====== Temporal Mixup ===================== 1018 | src_dominant, trg_dominant = self.temporal_mixup(src_x, trg_x) 1019 | 1020 | # ====== Source ===================== 1021 | self.optimizer.zero_grad() 1022 | 1023 | # Src original features 1024 | src_orig_feat = self.feature_extractor(src_x) 1025 | src_orig_logits = self.classifier(src_orig_feat) 1026 | 1027 | # Target original features 1028 | trg_orig_feat = self.feature_extractor(trg_x) 1029 | trg_orig_logits = self.classifier(trg_orig_feat) 1030 | 1031 | # ----------- The two main losses 1032 | # Cross-Entropy loss 1033 | src_cls_loss = self.cross_entropy(src_orig_logits, src_y) 1034 | loss = src_cls_loss * round(self.hparams["src_cls_weight"], 2) 1035 | 1036 | # Target Entropy loss 1037 | trg_entropy_loss = self.entropy_loss(trg_orig_logits) 1038 | loss += trg_entropy_loss * round(self.hparams["trg_entropy_weight"], 2) 1039 | 1040 | # ----------- Auxiliary losses 1041 | # Extract source-dominant mixup features.
1042 | src_dominant_feat = self.feature_extractor(src_dominant) 1043 | src_dominant_logits = self.classifier(src_dominant_feat) 1044 | 1045 | # supervised contrastive loss on source domain side 1046 | src_concat = torch.cat([src_orig_logits.unsqueeze(1), src_dominant_logits.unsqueeze(1)], dim=1) 1047 | src_supcon_loss = self.sup_contrastive_loss(src_concat, src_y) 1048 | loss += src_supcon_loss * round(self.hparams["src_supCon_weight"], 2) 1049 | 1050 | # Extract target-dominant mixup features. 1051 | trg_dominant_feat = self.feature_extractor(trg_dominant) 1052 | trg_dominant_logits = self.classifier(trg_dominant_feat) 1053 | 1054 | # Unsupervised contrastive loss on target domain side 1055 | trg_con_loss = self.contrastive_loss(trg_orig_logits, trg_dominant_logits) 1056 | loss += trg_con_loss * round(self.hparams["trg_cont_weight"], 2) 1057 | 1058 | loss.backward() 1059 | self.optimizer.step() 1060 | 1061 | losses = {'Total_loss': loss.item(), 1062 | 'src_cls_loss': src_cls_loss.item(), 1063 | 'trg_entropy_loss': trg_entropy_loss.item(), 1064 | 'src_supcon_loss': src_supcon_loss.item(), 1065 | 'trg_con_loss': trg_con_loss.item() 1066 | } 1067 | for key, val in losses.items(): 1068 | avg_meter[key].update(val, 32) 1069 | 1070 | self.lr_scheduler.step() 1071 | 1072 | def temporal_mixup(self,src_x, trg_x): 1073 | # blend each batch (weight mix_ratio) with the mean of the other domain's batch rolled along dim 2 by -i for each i in range(-h, h) (NOTE(review): range(-h, h) is asymmetric, and it is empty when temporal_shift < 2, which would make torch.stack fail -- confirm) 1074 | mix_ratio = round(self.hparams["mix_ratio"], 2) 1075 | temporal_shift = self.hparams["temporal_shift"] 1076 | h = temporal_shift // 2 # half 1077 | 1078 | src_dominant = mix_ratio * src_x + (1 - mix_ratio) * \ 1079 | torch.mean(torch.stack([torch.roll(trg_x, -i, 2) for i in range(-h, h)], 2), 2) 1080 | 1081 | trg_dominant = mix_ratio * trg_x + (1 - mix_ratio) * \ 1082 | torch.mean(torch.stack([torch.roll(src_x, -i, 2) for i in range(-h, h)], 2), 2) 1083 | 1084 | return src_dominant, trg_dominant 1085 | 1086 | 1087 | 1088 | # Untied Approaches: (MCD) 1089 | class MCD(Algorithm): 1090 | """ 1091 | Maximum Classifier Discrepancy for Unsupervised Domain
Adaptation 1092 | MCD: https://arxiv.org/pdf/1712.02560.pdf 1093 | """ 1094 | 1095 | def __init__(self, backbone, configs, hparams, device): 1096 | super().__init__(configs, backbone) 1097 | 1098 | self.feature_extractor = backbone(configs) 1099 | self.classifier = classifier(configs) 1100 | self.classifier2 = classifier(configs) 1101 | 1102 | self.network = nn.Sequential(self.feature_extractor, self.classifier) 1103 | 1104 | 1105 | # optimizer for the feature extractor 1106 | self.optimizer_fe = torch.optim.Adam( 1107 | self.feature_extractor.parameters(), 1108 | lr=hparams["learning_rate"], 1109 | weight_decay=hparams["weight_decay"] 1110 | ) 1111 | # optimizer for the first classifier 1112 | self.optimizer_c1 = torch.optim.Adam( 1113 | self.classifier.parameters(), 1114 | lr=hparams["learning_rate"], 1115 | weight_decay=hparams["weight_decay"] 1116 | ) 1117 | # optimizer for the second classifier 1118 | self.optimizer_c2 = torch.optim.Adam( 1119 | self.classifier2.parameters(), 1120 | lr=hparams["learning_rate"], 1121 | weight_decay=hparams["weight_decay"] 1122 | ) 1123 | # one StepLR scheduler per optimizer 1124 | self.lr_scheduler_fe = StepLR(self.optimizer_fe, step_size=hparams['step_size'], gamma=hparams['lr_decay']) 1125 | self.lr_scheduler_c1 = StepLR(self.optimizer_c1, step_size=hparams['step_size'], gamma=hparams['lr_decay']) 1126 | self.lr_scheduler_c2 = StepLR(self.optimizer_c2, step_size=hparams['step_size'], gamma=hparams['lr_decay']) 1127 | 1128 | # hparams 1129 | self.hparams = hparams 1130 | # device 1131 | self.device = device 1132 | 1133 | # Alignment losses (NOTE(review): self.mmd_loss is not used by training_epoch, which uses self.discrepancy) 1134 | self.mmd_loss = MMD_loss() 1135 | 1136 | def update(self, src_loader, trg_loader, avg_meter, logger): 1137 | # defining best and last model 1138 | best_src_risk = float('inf') 1139 | best_model = None 1140 | 1141 | for epoch in range(1, self.hparams["num_epochs"] + 1): 1142 | 1143 | # source pretraining loop 1144 | self.pretrain_epoch(src_loader, avg_meter) 1145 | 1146 | # training loop 1147 | self.training_epoch(src_loader, trg_loader, avg_meter, epoch)
1148 | 1149 | # saving the best model based on src risk (only checked when (epoch + 1) is a multiple of 10) 1150 | if (epoch + 1) % 10 == 0 and avg_meter['Src_cls_loss'].avg < best_src_risk: 1151 | best_src_risk = avg_meter['Src_cls_loss'].avg 1152 | best_model = deepcopy(self.network.state_dict()) 1153 | 1154 | 1155 | logger.debug(f'[Epoch : {epoch}/{self.hparams["num_epochs"]}]') 1156 | for key, val in avg_meter.items(): 1157 | logger.debug(f'{key}\t: {val.avg:2.4f}') 1158 | logger.debug(f'-------------------------------------') 1159 | # NOTE(review): last_model is not deep-copied, unlike best_model 1160 | last_model = self.network.state_dict() 1161 | 1162 | return last_model, best_model 1163 | 1164 | def pretrain_epoch(self, src_loader,avg_meter): 1165 | for src_x, src_y in src_loader: 1166 | src_x, src_y = src_x.to(self.device), src_y.to(self.device) 1167 | 1168 | src_feat = self.feature_extractor(src_x) 1169 | src_pred1 = self.classifier(src_feat) 1170 | src_pred2 = self.classifier2(src_feat) 1171 | 1172 | src_cls_loss1 = self.cross_entropy(src_pred1, src_y) 1173 | src_cls_loss2 = self.cross_entropy(src_pred2, src_y) 1174 | 1175 | loss = src_cls_loss1 + src_cls_loss2 1176 | 1177 | self.optimizer_c1.zero_grad() 1178 | self.optimizer_c2.zero_grad() 1179 | self.optimizer_fe.zero_grad() 1180 | 1181 | loss.backward() 1182 | 1183 | self.optimizer_c1.step() 1184 | self.optimizer_c2.step() 1185 | self.optimizer_fe.step() 1186 | 1187 | # the logged Src_cls_loss is the sum of both classifiers' losses; update() uses it for best-model selection 1188 | losses = {'Src_cls_loss': loss.item()} 1189 | 1190 | for key, val in losses.items(): 1191 | avg_meter[key].update(val, 32) 1192 | 1193 | def training_epoch(self, src_loader, trg_loader, avg_meter, epoch): 1194 | 1195 | # Construct Joint Loaders 1196 | joint_loader =enumerate(zip(src_loader, itertools.cycle(trg_loader))) 1197 | 1198 | for step, ((src_x, src_y), (trg_x, _)) in joint_loader: 1199 | src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(self.device) # move the batch to the device 1200 | 1201 | 1202 | # extract source features 1203 | src_feat = self.feature_extractor(src_x) 1204 | src_pred1 =
self.classifier(src_feat) 1205 | src_pred2 = self.classifier2(src_feat) 1206 | 1207 | # source losses 1208 | src_cls_loss1 = self.cross_entropy(src_pred1, src_y) 1209 | src_cls_loss2 = self.cross_entropy(src_pred2, src_y) 1210 | loss_s = src_cls_loss1 + src_cls_loss2 1211 | 1212 | 1213 | # Freeze the feature extractor 1214 | for k, v in self.feature_extractor.named_parameters(): 1215 | v.requires_grad = False 1216 | # update C1 and C2 to maximize their discrepancy on target samples 1217 | trg_feat = self.feature_extractor(trg_x) 1218 | trg_pred1 = self.classifier(trg_feat.detach()) 1219 | trg_pred2 = self.classifier2(trg_feat.detach()) 1220 | 1221 | 1222 | loss_dis = self.discrepancy(trg_pred1, trg_pred2) 1223 | 1224 | loss = loss_s - loss_dis 1225 | # NOTE(review): on the first iteration nothing has cleared the gradients left by pretrain_epoch's last batch, so they are accumulated into this backward() before optimizer_c1/optimizer_c2 step -- confirm 1226 | loss.backward() 1227 | self.optimizer_c1.step() 1228 | self.optimizer_c2.step() 1229 | 1230 | self.optimizer_c1.zero_grad() 1231 | self.optimizer_c2.zero_grad() 1232 | self.optimizer_fe.zero_grad() 1233 | 1234 | # Freeze the classifiers (NOTE(review): nothing in MCD sets their requires_grad back to True, so from the next iteration on, and in later pretrain_epoch calls, C1 and C2 receive no gradients -- confirm) 1235 | for k, v in self.classifier.named_parameters(): 1236 | v.requires_grad = False 1237 | for k, v in self.classifier2.named_parameters(): 1238 | v.requires_grad = False 1239 | # Unfreeze the feature extractor 1240 | for k, v in self.feature_extractor.named_parameters(): 1241 | v.requires_grad = True 1242 | # update feature extractor to minimize the discrepancy on target samples 1243 | trg_feat = self.feature_extractor(trg_x) 1244 | trg_pred1 = self.classifier(trg_feat) 1245 | trg_pred2 = self.classifier2(trg_feat) 1246 | 1247 | 1248 | loss_dis_t = self.discrepancy(trg_pred1, trg_pred2) 1249 | domain_loss = self.hparams["domain_loss_wt"] * loss_dis_t 1250 | 1251 | domain_loss.backward() 1252 | self.optimizer_fe.step() 1253 | 1254 | self.optimizer_fe.zero_grad() 1255 | self.optimizer_c1.zero_grad() 1256 | self.optimizer_c2.zero_grad() 1257 | 1258 | # NOTE(review): domain_loss is the weighted classifier discrepancy, logged under the 'MMD_loss' key 1259 | losses = {'Total_loss': loss.item(), 'MMD_loss': domain_loss.item()} 1260 | 1261 | for key, val in losses.items(): 1262 |
avg_meter[key].update(val, 32) 1263 | 1264 | self.lr_scheduler_fe.step() 1265 | self.lr_scheduler_c1.step() 1266 | self.lr_scheduler_c2.step() 1267 | 1268 | def discrepancy(self, out1, out2): 1269 | # mean absolute difference between the two classifiers' softmax outputs (NOTE(review): F.softmax is called without dim=, relying on PyTorch's deprecated implicit-dim choice; dim=1 is presumably intended -- confirm) 1270 | return torch.mean(torch.abs(F.softmax(out1) - F.softmax(out2))) 1271 | -------------------------------------------------------------------------------- /configs/__pycache__/data_model_configs.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/configs/__pycache__/data_model_configs.cpython-310.pyc -------------------------------------------------------------------------------- /configs/__pycache__/data_model_configs.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/configs/__pycache__/data_model_configs.cpython-38.pyc -------------------------------------------------------------------------------- /configs/__pycache__/data_model_configs.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/configs/__pycache__/data_model_configs.cpython-39.pyc -------------------------------------------------------------------------------- /configs/__pycache__/hparams.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/configs/__pycache__/hparams.cpython-310.pyc -------------------------------------------------------------------------------- /configs/__pycache__/hparams.cpython-38.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/configs/__pycache__/hparams.cpython-38.pyc -------------------------------------------------------------------------------- /configs/__pycache__/hparams.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/configs/__pycache__/hparams.cpython-39.pyc -------------------------------------------------------------------------------- /configs/__pycache__/sweep_params.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/configs/__pycache__/sweep_params.cpython-310.pyc -------------------------------------------------------------------------------- /configs/__pycache__/sweep_params.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/configs/__pycache__/sweep_params.cpython-38.pyc -------------------------------------------------------------------------------- /configs/__pycache__/sweep_params.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/configs/__pycache__/sweep_params.cpython-39.pyc -------------------------------------------------------------------------------- /configs/data_model_configs.py: -------------------------------------------------------------------------------- 1 | def get_dataset_class(dataset_name): 2 | """Return the dataset configuration class named dataset_name; raise NotImplementedError if this module's globals() has no such name.""" 3 | if dataset_name not in globals(): 4 | raise NotImplementedError("Dataset not found: {}".format(dataset_name)) 5 | return globals()[dataset_name] 6
| 7 | class HAR(): 8 | def __init__(self): 9 | super(HAR, self) # NOTE(review): unlike the other config classes this does not call .__init__(); harmless here because HAR's only base is object 10 | self.scenarios = [("2", "11"), ("6", "23"), ("7", "13"), ("9", "18"), ("12", "16"), ("18", "27"), ("20", "5"), ("24", "8"), ("28", "27"), ("30", "20")] 11 | self.class_names = ['walk', 'upstairs', 'downstairs', 'sit', 'stand', 'lie'] 12 | self.sequence_len = 128 13 | self.shuffle = True 14 | self.drop_last = True 15 | self.normalize = True 16 | 17 | # model configs 18 | self.input_channels = 9 19 | self.kernel_size = 5 20 | self.stride = 1 21 | self.dropout = 0.5 22 | self.num_classes = 6 23 | 24 | # CNN and RESNET features 25 | self.mid_channels = 64 26 | self.final_out_channels = 128 27 | self.features_len = 1 28 | 29 | # TCN features 30 | self.tcn_layers = [75, 150] 31 | self.tcn_final_out_channles = self.tcn_layers[-1] 32 | self.tcn_kernel_size = 17 33 | self.tcn_dropout = 0.0 34 | 35 | # lstm features 36 | self.lstm_hid = 128 37 | self.lstm_n_layers = 1 38 | self.lstm_bid = False 39 | 40 | # discriminator 41 | self.disc_hid_dim = 64 42 | self.hidden_dim = 500 43 | self.DSKN_disc_hid = 128 44 | 45 | 46 | class EEG(): 47 | def __init__(self): 48 | super(EEG, self).__init__() 49 | # data parameters 50 | self.num_classes = 5 51 | self.class_names = ['W', 'N1', 'N2', 'N3', 'REM'] 52 | self.sequence_len = 3000 53 | self.scenarios = [("0", "11"), ("7", "18"), ("9", "14"), ("12", "5"), ("16", "1"), 54 | ("3", "19"), ("18", "12"), ("13", "17"), ("5", "15"), ("6", "2")] 55 | self.shuffle = True 56 | self.drop_last = True 57 | self.normalize = True 58 | 59 | # model configs 60 | self.input_channels = 1 61 | self.kernel_size = 25 62 | self.stride = 6 63 | self.dropout = 0.2 64 | 65 | # features 66 | self.mid_channels = 32 67 | self.final_out_channels = 128 68 | self.features_len = 1 69 | 70 | # TCN features 71 | self.tcn_layers = [32,64] 72 | self.tcn_final_out_channles = self.tcn_layers[-1] 73 | self.tcn_kernel_size = 15# 25 74 | self.tcn_dropout = 0.0 75 | 76 | # lstm features 77 | self.lstm_hid = 128
78 | self.lstm_n_layers = 1 79 | self.lstm_bid = False 80 | 81 | # discriminator 82 | self.DSKN_disc_hid = 128 83 | self.hidden_dim = 500 84 | self.disc_hid_dim = 100 85 | 86 | 87 | class WISDM(object): 88 | def __init__(self): 89 | super(WISDM, self).__init__() 90 | self.class_names = ['walk', 'jog', 'sit', 'stand', 'upstairs', 'downstairs'] 91 | self.sequence_len = 128 92 | self.scenarios = [("7", "18"), ("20", "30"), ("35", "31"), ("17", "23"), ("6", "19"), 93 | ("2", "11"), ("33", "12"), ("5", "26"), ("28", "4"), ("23", "32")] 94 | self.num_classes = 6 95 | self.shuffle = True 96 | self.drop_last = False 97 | self.normalize = True 98 | 99 | # model configs 100 | self.input_channels = 3 101 | self.kernel_size = 5 102 | self.stride = 1 103 | self.dropout = 0.5 104 | self.num_classes = 6 # NOTE(review): duplicates the assignment on line 94 105 | 106 | # features 107 | self.mid_channels = 64 108 | self.final_out_channels = 128 109 | self.features_len = 1 110 | 111 | # TCN features 112 | self.tcn_layers = [75,150,300] 113 | self.tcn_final_out_channles = self.tcn_layers[-1] 114 | self.tcn_kernel_size = 17 115 | self.tcn_dropout = 0.0 116 | 117 | # lstm features 118 | self.lstm_hid = 128 119 | self.lstm_n_layers = 1 120 | self.lstm_bid = False 121 | 122 | # discriminator 123 | self.disc_hid_dim = 64 124 | self.DSKN_disc_hid = 128 125 | self.hidden_dim = 500 126 | 127 | 128 | class HHAR(object): ## HHAR dataset, SAMSUNG device.
129 | def __init__(self): 130 | super(HHAR, self).__init__() 131 | self.sequence_len = 128 132 | self.scenarios = [("0", "6"), ("1", "6"), ("2", "7"), ("3", "8"), ("4", "5"), 133 | ("5", "0"), ("6", "1"), ("7", "4"), ("8", "3"), ("0", "2")] 134 | self.class_names = ['bike', 'sit', 'stand', 'walk', 'stairs_up', 'stairs_down'] 135 | self.num_classes = 6 136 | self.shuffle = True 137 | self.drop_last = True 138 | self.normalize = True 139 | 140 | # model configs 141 | self.input_channels = 3 142 | self.kernel_size = 5 143 | self.stride = 1 144 | self.dropout = 0.5 145 | 146 | # features 147 | self.mid_channels = 64 148 | self.final_out_channels = 128 149 | self.features_len = 1 150 | 151 | # TCN features 152 | self.tcn_layers = [75,150] 153 | self.tcn_final_out_channles = self.tcn_layers[-1] 154 | self.tcn_kernel_size = 17 155 | self.tcn_dropout = 0.0 156 | 157 | # lstm features 158 | self.lstm_hid = 128 159 | self.lstm_n_layers = 1 160 | self.lstm_bid = False 161 | 162 | # discriminator 163 | self.disc_hid_dim = 64 164 | self.DSKN_disc_hid = 128 165 | self.hidden_dim = 500 166 | 167 | 168 | 169 | class FD(object): 170 | def __init__(self): 171 | super(FD, self).__init__() 172 | self.sequence_len = 5120 173 | self.scenarios = [("0", "1"), ("0", "3"), ("1", "0"), ("1", "2"),("1", "3"), 174 | ("2", "1"),("2", "3"), ("3", "0"), ("3", "1"), ("3", "2")] 175 | self.class_names = ['Healthy', 'D1', 'D2'] 176 | self.num_classes = 3 177 | self.shuffle = True 178 | self.drop_last = True 179 | self.normalize = True 180 | 181 | # Model configs 182 | self.input_channels = 1 183 | self.kernel_size = 32 184 | self.stride = 6 185 | self.dropout = 0.5 186 | # features 187 | self.mid_channels = 64 188 | self.final_out_channels = 128 189 | self.features_len = 1 190 | 191 | # TCN features 192 | self.tcn_layers = [75, 150] 193 | self.tcn_final_out_channles = self.tcn_layers[-1] 194 | self.tcn_kernel_size = 17 195 | self.tcn_dropout = 0.0 196 | 197 | # lstm features 198 | self.lstm_hid = 128 199 |
self.lstm_n_layers = 1 200 | self.lstm_bid = False 201 | 202 | # discriminator 203 | self.disc_hid_dim = 64 204 | self.DSKN_disc_hid = 128 205 | self.hidden_dim = 500 206 | -------------------------------------------------------------------------------- /configs/hparams.py: -------------------------------------------------------------------------------- 1 | ## The current hyper-parameter values are not necessarily the best ones for a specific risk. 2 | def get_hparams_class(dataset_name): 3 | """Return the hyper-parameter class named dataset_name; raise NotImplementedError if this module's globals() has no such name.""" 4 | if dataset_name not in globals(): 5 | raise NotImplementedError("Dataset not found: {}".format(dataset_name)) 6 | return globals()[dataset_name] 7 | 8 | 9 | class HAR(): 10 | def __init__(self): 11 | super(HAR, self).__init__() 12 | self.train_params = { 13 | 'num_epochs': 40, 14 | 'batch_size': 32, 15 | 'weight_decay': 1e-4, 16 | 'step_size': 50, 17 | 'lr_decay': 0.5 18 | 19 | } 20 | self.alg_hparams = { 21 | 'NO_ADAPT': {'learning_rate': 1e-3, 'src_cls_loss_wt': 1}, 22 | 'TARGET_ONLY': {'learning_rate': 1e-3, 'trg_cls_loss_wt': 1}, 23 | "SASA": { 24 | "domain_loss_wt": 7.3937939938562, 25 | "learning_rate": 0.005, 26 | "src_cls_loss_wt": 4.185814373345016, 27 | "weight_decay": 0.0001 28 | }, 29 | "DDC": { 30 | "learning_rate": 0.001, 31 | "mmd_wt": 3.7991920933520342, 32 | "src_cls_loss_wt": 6.286301875125623, 33 | "domain_loss_wt": 6.36, 34 | "weight_decay": 0.0001 35 | }, 36 | "CoDATS": { 37 | "domain_loss_wt": 3.2750474868706925, 38 | "learning_rate": 0.001, 39 | "src_cls_loss_wt": 6.335109786953256, 40 | "weight_decay": 0.0001 41 | }, 42 | "DANN": { 43 | "domain_loss_wt": 2.943729820531079, 44 | "learning_rate": 0.001, 45 | "src_cls_loss_wt": 5.1390077646202, 46 | "weight_decay": 0.0001 47 | }, 48 | "DIRT": { 49 | "cond_ent_wt": 1.20721518968644, 50 | "domain_loss_wt": 1.9012145515129044, 51 | "learning_rate": 0.005, 52 | "src_cls_loss_wt": 9.67861021290254, 53 | "vat_loss_wt": 7.7102843136045855, 54 |
"weight_decay": 0.0001 55 | }, 56 | "DSAN": { 57 | "learning_rate": 0.001, 58 | "mmd_wt": 2.0872340713147786, 59 | "src_cls_loss_wt": 1.8744909939900247, 60 | "domain_loss_wt": 1.59, 61 | "weight_decay": 0.0001 62 | }, 63 | "MMDA": { 64 | "cond_ent_wt": 1.383002023133561, 65 | "coral_wt": 8.36810764913737, 66 | "learning_rate": 0.001, 67 | "mmd_wt": 3.964042918489996, 68 | "src_cls_loss_wt": 6.794522068759213, 69 | "weight_decay": 0.0001 70 | }, 71 | "Deep_Coral": { 72 | "coral_wt": 4.23035475456397, 73 | "learning_rate": 0.0005, 74 | "src_cls_loss_wt": 0.1013209750429822, 75 | "weight_decay": 0.0001 76 | }, 77 | "CDAN": { 78 | "cond_ent_wt": 1.2920143348777362, 79 | "domain_loss_wt": 9.545761950873414, 80 | "learning_rate": 0.001, 81 | "src_cls_loss_wt": 9.430292987535724, 82 | "weight_decay": 0.0001 83 | }, 84 | "AdvSKM": { 85 | "domain_loss_wt": 1.338788378230754, 86 | "learning_rate": 0.0005, 87 | "src_cls_loss_wt": 2.468525942065072, 88 | "weight_decay": 0.0001 89 | }, 90 | "HoMM": { 91 | "hommd_wt": 2.8305712579412683, 92 | "learning_rate": 0.0005, 93 | "src_cls_loss_wt": 0.1282520874653523, 94 | "domain_loss_wt": 9.13, 95 | "weight_decay": 0.0001 96 | }, 97 | 'CoTMix': {'learning_rate': 0.001, 'mix_ratio': 0.9, 'temporal_shift': 14, 98 | 'src_cls_weight': 0.78, 'src_supCon_weight': 0.1, 'trg_cont_weight': 0.1, 99 | 'trg_entropy_weight': 0.05}, 100 | 'MCD': {'learning_rate': 1e-2, 'src_cls_loss_wt': 9.74, 'domain_loss_wt': 5.43}, 101 | 102 | } 103 | 104 | 105 | class EEG(): 106 | def __init__(self): 107 | super(EEG, self).__init__() 108 | self.train_params = { 109 | 'num_epochs': 40, 110 | 'batch_size': 128, 111 | 'weight_decay': 1e-4, 112 | 113 | } 114 | self.alg_hparams = { 115 | 'NO_ADAPT': {'learning_rate': 1e-3, 'src_cls_loss_wt': 1}, 116 | 'TARGET_ONLY': {'learning_rate': 1e-3, 'trg_cls_loss_wt': 1}, 117 | "SASA": { 118 | "domain_loss_wt": 5.8045319155819515, 119 | "learning_rate": 0.005, 120 | "src_cls_loss_wt": 4.438490884851632, 121 | "weight_decay": 
0.0001 122 | }, 123 | "CoDATS": { 124 | "domain_loss_wt": 0.3551260369189456, 125 | "learning_rate": 0.005, 126 | "src_cls_loss_wt": 1.2534327517723889, 127 | "weight_decay": 0.0001 128 | }, 129 | "AdvSKM": { 130 | "domain_loss_wt": 5.600818539370264, 131 | "learning_rate": 0.0005, 132 | "src_cls_loss_wt": 4.231231335081738, 133 | "weight_decay": 0.0001 134 | }, 135 | "Deep_Coral": { 136 | "coral_wt": 9.50224286095279, 137 | "learning_rate": 0.0005, 138 | "src_cls_loss_wt": 0.8149666724969482, 139 | "weight_decay": 0.0001 140 | }, 141 | "DANN": { 142 | "domain_loss_wt": 0.27634197975549135, 143 | "learning_rate": 0.0005, 144 | "src_cls_loss_wt": 8.441929209893459, 145 | "weight_decay": 0.0001 146 | }, 147 | "DDC": { 148 | "learning_rate": 0.0005, 149 | "mmd_wt": 5.900770246907044, 150 | "src_cls_loss_wt": 1.979307877348751, 151 | "domain_loss_wt": 8.923, 152 | "weight_decay": 0.0001 153 | }, 154 | "DIRT": { 155 | "cond_ent_wt": 1.7021814402136783, 156 | "domain_loss_wt": 1.6488583075821344, 157 | "learning_rate": 0.01, 158 | "src_cls_loss_wt": 6.427127521674593, 159 | "vat_loss_wt": 5.078600240648073, 160 | "weight_decay": 0.0001 161 | }, 162 | "MMDA": { 163 | "cond_ent_wt": 9.177841626283191, 164 | "coral_wt": 2.768290045896212, 165 | "learning_rate": 0.0005, 166 | "mmd_wt": 2.25231504738171, 167 | "src_cls_loss_wt": 8.64418208100774, 168 | "weight_decay": 0.0001 169 | }, 170 | "DSAN": { 171 | "learning_rate": 0.001, 172 | "mmd_wt": 5.01196798268099, 173 | "src_cls_loss_wt": 7.774381653453339, 174 | "domain_loss_wt": 6.708, 175 | "weight_decay": 0.0001 176 | }, 177 | "HoMM": { 178 | "hommd_wt": 3.843851397373747, 179 | "learning_rate": 0.001, 180 | "src_cls_loss_wt": 1.8311375304849091, 181 | "domain_loss_wt": 1.102, 182 | "weight_decay": 0.0001 183 | }, 184 | "CDAN": { 185 | "cond_ent_wt": 0.7559091229767906, 186 | "domain_loss_wt": 0.17693531166083065, 187 | "learning_rate": 0.0005, 188 | "src_cls_loss_wt": 7.764624556216286, 189 | "weight_decay": 0.0001 190 | 
}, 191 | 'CoTMix': {'learning_rate': 0.001, 'mix_ratio': 0.79, 'temporal_shift': 300, 192 | 'src_cls_weight': 0.96, 'src_supCon_weight': 0.1, 'trg_cont_weight': 0.1, 193 | 'trg_entropy_weight': 0.05} 194 | 195 | } 196 | 197 | 198 | class WISDM(): 199 | def __init__(self): 200 | super().__init__() 201 | self.train_params = { 202 | 'num_epochs': 40, 203 | 'batch_size': 32, 204 | 'weight_decay': 1e-4, 205 | 206 | } 207 | self.alg_hparams = { 208 | 'NO_ADAPT': {'learning_rate': 1e-3, 'src_cls_loss_wt': 1}, 209 | 'TARGET_ONLY': {'learning_rate': 1e-3, 'trg_cls_loss_wt': 1}, 210 | "SASA": { 211 | "domain_loss_wt": 1.2632988839197083, 212 | "learning_rate": 0.005, 213 | "src_cls_loss_wt": 9.898676755625807, 214 | "weight_decay": 0.0001 215 | }, 216 | "CDAN": { 217 | "cond_ent_wt": 0.837129024245748, 218 | "domain_loss_wt": 5.9197207530729266, 219 | "learning_rate": 0.001, 220 | "src_cls_loss_wt": 6.983963629299826, 221 | "weight_decay": 0.0001 222 | }, 223 | "HoMM": { 224 | "hommd_wt": 6.799448304230478, 225 | "learning_rate": 0.005, 226 | "src_cls_loss_wt": 0.2563533185103576, 227 | "domain_loss_wt": 4.239, 228 | "weight_decay": 0.0001 229 | }, 230 | "DANN": { 231 | "domain_loss_wt": 2.6051391453662873, 232 | "learning_rate": 0.005, 233 | "src_cls_loss_wt": 5.272383517138417, 234 | "weight_decay": 0.0001 235 | }, 236 | "DIRT": { 237 | "cond_ent_wt": 1.6935884891647972, 238 | "domain_loss_wt": 7.774841143071709, 239 | "learning_rate": 0.005, 240 | "src_cls_loss_wt": 9.62463958771893, 241 | "vat_loss_wt": 4.644539486962429, 242 | "weight_decay": 0.0001 243 | }, 244 | "AdvSKM": { 245 | "domain_loss_wt": 0.17573022784621156, 246 | "learning_rate": 0.001, 247 | "src_cls_loss_wt": 7.656694101023234, 248 | "weight_decay": 0.0001 249 | }, 250 | "MMDA": { 251 | "cond_ent_wt": 7.555540424691775, 252 | "coral_wt": 5.254400971297628, 253 | "learning_rate": 0.005, 254 | "mmd_wt": 2.295549751091742, 255 | "src_cls_loss_wt": 6.653513071102565, 256 | "weight_decay": 0.0001 257 | }, 258 
| "Deep_Coral": { 259 | "coral_wt": 6.4881104202861755, 260 | "learning_rate": 0.001, 261 | "src_cls_loss_wt": 6.66305608395703, 262 | "weight_decay": 0.0001 263 | }, 264 | "CoDATS": { 265 | "domain_loss_wt": 4.574872968982744, 266 | "learning_rate": 0.001, 267 | "src_cls_loss_wt": 5.860885469514424, 268 | "weight_decay": 0.0001 269 | }, 270 | "DSAN": { 271 | "learning_rate": 0.005, 272 | "mmd_wt": 1.5468030830413808, 273 | "src_cls_loss_wt": 1.2981011362021273, 274 | "domain_loss_wt": 0.1, 275 | "weight_decay": 0.0001 276 | }, 277 | "DDC": { 278 | "learning_rate": 0.001, 279 | "mmd_wt": 1.9901164953952095, 280 | "src_cls_loss_wt": 4.881899626451807, 281 | "domain_loss_wt": 7.595, 282 | "weight_decay": 0.0001 283 | }, 284 | "CoTMix": { 285 | 'learning_rate': 0.001, 286 | 'mix_ratio': 0.72, 287 | 'temporal_shift': 14, 288 | 'src_cls_weight': 0.98, 289 | 'src_supCon_weight': 0.1, 290 | 'trg_cont_weight': 0.1, 291 | 'trg_entropy_weight': 0.05} 292 | } 293 | 294 | 295 | class HHAR(): 296 | def __init__(self): 297 | super().__init__() 298 | self.train_params = { 299 | 'num_epochs': 40, 300 | 'batch_size': 32, 301 | 'weight_decay': 1e-4, 302 | } 303 | self.alg_hparams = { 304 | 'NO_ADAPT': {'learning_rate': 1e-3, 'src_cls_loss_wt': 1}, 305 | 'TARGET_ONLY': {'learning_rate': 1e-3, 'trg_cls_loss_wt': 1}, 306 | 307 | "SASA": { 308 | "domain_loss_wt": 5.760124609738364, 309 | "learning_rate": 0.001, 310 | "src_cls_loss_wt": 4.130742585941761, 311 | "weight_decay": 0.0001 312 | }, 313 | "DSAN": { 314 | "learning_rate": 0.0005, 315 | "mmd_wt": 0.5993593617252002, 316 | "src_cls_loss_wt": 0.386167577207679, 317 | "domain_loss_wt": 0.16, 318 | "weight_decay": 0.0001 319 | }, 320 | "CoDATS": { 321 | "domain_loss_wt": 9.314114040099962, 322 | "learning_rate": 0.0005, 323 | "src_cls_loss_wt": 7.700018679383289, 324 | "weight_decay": 0.0001 325 | }, 326 | "HoMM": { 327 | "hommd_wt": 7.172430927893522, 328 | "learning_rate": 0.0005, 329 | "src_cls_loss_wt": 0.20121211752349172, 330 | 
"domain_loss_wt": 0.9824, 331 | "weight_decay": 0.0001 332 | }, 333 | "DIRT": { 334 | "cond_ent_wt": 1.329734510542011, 335 | "domain_loss_wt": 6.632293308809388, 336 | "learning_rate": 0.001, 337 | "src_cls_loss_wt": 7.729881324550688, 338 | "vat_loss_wt": 6.912258476982827, 339 | "weight_decay": 0.0001 340 | }, 341 | "AdvSKM": { 342 | "domain_loss_wt": 1.8649335076712072, 343 | "learning_rate": 0.001, 344 | "src_cls_loss_wt": 3.961611563054495, 345 | "weight_decay": 0.0001 346 | }, 347 | "DDC": { 348 | "learning_rate": 0.0005, 349 | "mmd_wt": 8.355791702302787, 350 | "src_cls_loss_wt": 1.2079058664226126, 351 | "domain_loss_wt": 0.2048, 352 | "weight_decay": 0.0001 353 | }, 354 | "CDAN": { 355 | "cond_ent_wt": 0.1841898900507932, 356 | "domain_loss_wt": 1.9307294194382076, 357 | "learning_rate": 0.0005, 358 | "src_cls_loss_wt": 4.15410157776963, 359 | "weight_decay": 0.0001 360 | }, 361 | "DANN": { 362 | "domain_loss_wt": 1.0296390274908802, 363 | "learning_rate": 0.0005, 364 | "src_cls_loss_wt": 2.038458138479581, 365 | "weight_decay": 0.0001 366 | }, 367 | "Deep_Coral": { 368 | "coral_wt": 5.9357031653707475, 369 | "learning_rate": 0.0005, 370 | "src_cls_loss_wt": 0.43859323168654, 371 | "weight_decay": 0.0001 372 | }, 373 | "MMDA": { 374 | "cond_ent_wt": 6.707871745810609, 375 | "coral_wt": 5.903714930042433, 376 | "learning_rate": 0.005, 377 | "mmd_wt": 6.480169289397163, 378 | "src_cls_loss_wt": 0.18878476669902317, 379 | "weight_decay": 0.0001 380 | }, 381 | 'CoTMix': {'learning_rate': 0.001, 'mix_ratio': 0.52, 'temporal_shift': 14, 382 | 'src_cls_weight': 0.8, 'src_supCon_weight': 0.1, 'trg_cont_weight': 0.1, 383 | 'trg_entropy_weight': 0.05} 384 | 385 | } 386 | 387 | 388 | class FD(): 389 | def __init__(self): 390 | super().__init__() 391 | self.train_params = { 392 | 'num_epochs': 40, 393 | 'batch_size': 32, 394 | 'weight_decay': 1e-4, 395 | } 396 | self.alg_hparams = { 397 | 'NO_ADAPT': {'learning_rate': 1e-3, 'src_cls_loss_wt': 1}, 398 | 'TARGET_ONLY': 
{'learning_rate': 1e-3, 'trg_cls_loss_wt': 1}, 399 | "SASA": { 400 | "domain_loss_wt": 0.7821851095870519, 401 | "learning_rate": 0.005, 402 | "src_cls_loss_wt": 7.680225091930735, 403 | "weight_decay": 0.0001 404 | }, 405 | "MMDA": { 406 | "cond_ent_wt": 8.12868726468387, 407 | "coral_wt": 7.2734249221691005, 408 | "learning_rate": 0.0005, 409 | "mmd_wt": 4.967077206689191, 410 | "src_cls_loss_wt": 0.30259189730747005, 411 | "weight_decay": 0.0001 412 | }, 413 | "AdvSKM": { 414 | "domain_loss_wt": 9.377024659182622, 415 | "learning_rate": 0.001, 416 | "src_cls_loss_wt": 0.7569318345582794, 417 | "weight_decay": 0.0001 418 | }, 419 | "HoMM": { 420 | "hommd_wt": 6.719563315664067, 421 | "learning_rate": 0.001, 422 | "src_cls_loss_wt": 1.5584167741262964, 423 | "domain_loss_wt": 0.9824, 424 | "weight_decay": 0.0001 425 | }, 426 | "Deep_Coral": { 427 | "coral_wt": 7.493856538302936, 428 | "learning_rate": 0.001, 429 | "src_cls_loss_wt": 1.452466194151791, 430 | "weight_decay": 0.0001 431 | }, 432 | "DIRT": { 433 | "cond_ent_wt": 4.753485587751647, 434 | "domain_loss_wt": 7.427507171955081, 435 | "learning_rate": 0.001, 436 | "src_cls_loss_wt": 9.818770948448943, 437 | "vat_loss_wt": 9.609164719194178, 438 | "weight_decay": 0.0001 439 | }, 440 | "DSAN": { 441 | "learning_rate": 0.005, 442 | "mmd_wt": 7.278792967879357, 443 | "src_cls_loss_wt": 2.5146121077752395, 444 | "domain_loss_wt": 0.16, 445 | "weight_decay": 0.0001 446 | }, 447 | "CDAN": { 448 | "cond_ent_wt": 0.553637609557987, 449 | "domain_loss_wt": 6.759045461432962, 450 | "learning_rate": 0.001, 451 | "src_cls_loss_wt": 6.854042579661701, 452 | "weight_decay": 0.0001 453 | }, 454 | "DDC": { 455 | "learning_rate": 0.005, 456 | "mmd_wt": 6.701050990813831, 457 | "src_cls_loss_wt": 1.1626428404763771, 458 | "domain_loss_wt": 0.2048, 459 | "weight_decay": 0.0001 460 | }, 461 | "CoDATS": { 462 | "domain_loss_wt": 0.6990097136753354, 463 | "learning_rate": 0.005, 464 | "src_cls_loss_wt": 9.57338373194037, 465 | 
"weight_decay": 0.0001 466 | }, 467 | "DANN": { 468 | "domain_loss_wt": 5.221878412210977, 469 | "learning_rate": 0.001, 470 | "src_cls_loss_wt": 4.233865748743297, 471 | "weight_decay": 0.0001 472 | }, 473 | 'CoTMix': {'learning_rate': 0.001, 'mix_ratio': 0.52, 'temporal_shift': 14, 474 | 'src_cls_weight': 0.8, 'src_supCon_weight': 0.1, 'trg_cont_weight': 0.1, 475 | 'trg_entropy_weight': 0.05} 476 | } 477 | -------------------------------------------------------------------------------- /configs/sweep_params.py: -------------------------------------------------------------------------------- 1 | sweep_train_hparams = { 2 | 'num_epochs': {'values': [3, 4, 5, 6]}, 3 | 'batch_size': {'values': [32, 64]}, 4 | 'learning_rate':{'values': [1e-2, 5e-3, 1e-3, 5e-4]}, 5 | 'disc_lr': {'values': [1e-2, 5e-3, 1e-3, 5e-4]}, 6 | 'weight_decay': {'values': [1e-4, 1e-5, 1e-6]}, 7 | 'step_size': {'values': [5, 10, 30]}, 8 | 'gamma': {'values': [5, 10, 15, 20, 25]}, 9 | 'optimizer': {'values': ['adam']}, 10 | } 11 | sweep_alg_hparams = { 12 | 'DANN': { 13 | 'learning_rate': {'values': [1e-2, 5e-3, 1e-3, 5e-4]}, 14 | 'src_cls_loss_wt': {'distribution': 'uniform', 'min': 1e-1, 'max': 10}, 15 | 'domain_loss_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10}, 16 | }, 17 | 18 | 'AdvSKM': { 19 | 'learning_rate': {'values': [1e-2, 5e-3, 1e-3, 5e-4]}, 20 | 'src_cls_loss_wt': {'distribution': 'uniform', 'min': 1e-1, 'max': 10}, 21 | 'domain_loss_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10}, 22 | }, 23 | 24 | 'CoDATS': { 25 | 'learning_rate': {'values': [1e-2, 5e-3, 1e-3, 5e-4]}, 26 | 'src_cls_loss_wt': {'distribution': 'uniform', 'min': 1e-1, 'max': 10}, 27 | 'domain_loss_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10}, 28 | }, 29 | 30 | 'CDAN': { 31 | 'learning_rate': {'values': [1e-2, 5e-3, 1e-3, 5e-4]}, 32 | 'src_cls_loss_wt': {'distribution': 'uniform', 'min': 1e-1, 'max': 10}, 33 | 'domain_loss_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10}, 34 | 
'cond_ent_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10}, 35 | }, 36 | 37 | 'Deep_Coral': { 38 | 'learning_rate': {'values': [1e-2, 5e-3, 1e-3, 5e-4]}, 39 | 'src_cls_loss_wt': {'distribution': 'uniform', 'min': 1e-1, 'max': 10}, 40 | 'coral_wt': {'distribution': 'uniform', 'min': 1e-1, 'max': 10}, 41 | }, 42 | 43 | 'DIRT': { 44 | 'learning_rate': {'values': [1e-2, 5e-3, 1e-3, 5e-4]}, 45 | 'src_cls_loss_wt': {'distribution': 'uniform', 'min': 1e-1, 'max': 10}, 46 | 'domain_loss_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10}, 47 | 'cond_ent_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10}, 48 | 'vat_loss_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10}, 49 | }, 50 | 51 | 'HoMM': { 52 | 'learning_rate': {'values': [1e-2, 5e-3, 1e-3, 5e-4]}, 53 | 'src_cls_loss_wt': {'distribution': 'uniform', 'min': 1e-1, 'max': 10}, 54 | 'hommd_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10}, 55 | }, 56 | 57 | 'MMDA': { 58 | 'learning_rate': {'values': [1e-2, 5e-3, 1e-3, 5e-4]}, 59 | 'src_cls_loss_wt': {'distribution': 'uniform', 'min': 1e-1, 'max': 10}, 60 | 'coral_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10}, 61 | 'cond_ent_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10}, 62 | 'mmd_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10}, 63 | }, 64 | 65 | 'DSAN': { 66 | 'learning_rate': {'values': [1e-2, 5e-3, 1e-3, 5e-4]}, 67 | 'src_cls_loss_wt': {'distribution': 'uniform', 'min': 1e-1, 'max': 10}, 68 | 'mmd_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10}, 69 | }, 70 | 71 | 'DDC': { 72 | 'learning_rate': {'values': [1e-2, 5e-3, 1e-3, 5e-4]}, 73 | 'src_cls_loss_wt': {'distribution': 'uniform', 'min': 1e-1, 'max': 10}, 74 | 'mmd_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10}, 75 | }, 76 | 77 | 'SASA': { 78 | 'learning_rate': {'values': [1e-2, 5e-3, 1e-3, 5e-4]}, 79 | 'src_cls_loss_wt': {'distribution': 'uniform', 'min': 1e-1, 'max': 10}, 80 | 'domain_loss_wt': {'distribution': 'uniform', 'min': 
1e-2, 'max': 10}, 81 | }, 82 | 83 | 'CoTMix': { 84 | 'learning_rate': {'values': [1e-2, 5e-3, 1e-3, 5e-4]}, 85 | 'temporal_shift': {'values': [5, 10, 15, 20, 30, 50]}, 86 | 'src_cls_weight': {'distribution': 'uniform', 'min': 1e-1, 'max': 1}, 87 | 'mix_ratio': {'distribution': 'uniform', 'min': 0.5, 'max': 0.99}, 88 | 'src_supCon_weight': {'distribution': 'uniform', 'min': 1e-3, 'max': 1}, 89 | 'trg_cont_weight': {'distribution': 'uniform', 'min': 1e-3, 'max': 1}, 90 | 'trg_entropy_weight': {'distribution': 'uniform', 'min': 1e-3, 'max': 1}, 91 | }, 92 | } 93 | 94 | -------------------------------------------------------------------------------- /dataloader/__pycache__/dataloader.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/dataloader/__pycache__/dataloader.cpython-310.pyc -------------------------------------------------------------------------------- /dataloader/__pycache__/dataloader.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/dataloader/__pycache__/dataloader.cpython-38.pyc -------------------------------------------------------------------------------- /dataloader/__pycache__/dataloader.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/dataloader/__pycache__/dataloader.cpython-39.pyc -------------------------------------------------------------------------------- /dataloader/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torch.utils.data import Dataset 4 | from torchvision import transforms 5 | 6 | from 
sklearn.model_selection import train_test_split 7 | 8 | import os, sys 9 | import numpy as np 10 | import random 11 | 12 | 13 | class Load_Dataset(Dataset): 14 | def __init__(self, dataset, dataset_configs): 15 | super().__init__() 16 | self.num_channels = dataset_configs.input_channels 17 | 18 | # Load samples 19 | x_data = dataset["samples"] 20 | 21 | # Load labels 22 | y_data = dataset.get("labels") 23 | if y_data is not None and isinstance(y_data, np.ndarray): 24 | y_data = torch.from_numpy(y_data) 25 | 26 | # Convert to torch tensor 27 | if isinstance(x_data, np.ndarray): 28 | x_data = torch.from_numpy(x_data) 29 | 30 | # Check samples dimensions. 31 | # The dimension of the data is expected to be (N, C, L) 32 | # where N is the #samples, C: #channels, and L is the sequence length 33 | if len(x_data.shape) == 2: 34 | x_data = x_data.unsqueeze(1) 35 | elif len(x_data.shape) == 3 and x_data.shape[1] != self.num_channels: 36 | x_data = x_data.transpose(1, 2) 37 | 38 | # Normalize data 39 | if dataset_configs.normalize: 40 | data_mean = torch.mean(x_data, dim=(0, 2)) 41 | data_std = torch.std(x_data, dim=(0, 2)) 42 | self.transform = transforms.Normalize(mean=data_mean, std=data_std) 43 | else: 44 | self.transform = None 45 | self.x_data = x_data.float() 46 | self.y_data = y_data.long() if y_data is not None else None 47 | self.len = x_data.shape[0] 48 | 49 | 50 | def __getitem__(self, index): 51 | x = self.x_data[index] 52 | if self.transform: 53 | x = self.transform(self.x_data[index].reshape(self.num_channels, -1, 1)).reshape(self.x_data[index].shape) 54 | y = self.y_data[index] if self.y_data is not None else None 55 | return x, y 56 | 57 | def __len__(self): 58 | return self.len 59 | 60 | 61 | def data_generator(data_path, domain_id, dataset_configs, hparams, dtype): 62 | # loading dataset file from path 63 | dataset_file = torch.load(os.path.join(data_path, f"{dtype}_{domain_id}.pt")) 64 | 65 | # Loading datasets 66 | dataset = Load_Dataset(dataset_file, 
dataset_configs) 67 | 68 | if dtype == "test": # you don't need to shuffle or drop last batch while testing 69 | shuffle = False 70 | drop_last = False 71 | else: 72 | shuffle = dataset_configs.shuffle 73 | drop_last = dataset_configs.drop_last 74 | 75 | # Dataloaders 76 | data_loader = torch.utils.data.DataLoader(dataset=dataset, 77 | batch_size=hparams["batch_size"], 78 | shuffle=shuffle, 79 | drop_last=drop_last, 80 | num_workers=0) 81 | 82 | return data_loader 83 | 84 | 85 | 86 | def data_generator_old(data_path, domain_id, dataset_configs, hparams): 87 | # loading path 88 | train_dataset = torch.load(os.path.join(data_path, "train_" + domain_id + ".pt")) 89 | test_dataset = torch.load(os.path.join(data_path, "test_" + domain_id + ".pt")) 90 | 91 | # Loading datasets 92 | train_dataset = Load_Dataset(train_dataset, dataset_configs) 93 | test_dataset = Load_Dataset(test_dataset, dataset_configs) 94 | 95 | # Dataloaders 96 | batch_size = hparams["batch_size"] 97 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, 98 | shuffle=True, drop_last=True, num_workers=0) 99 | 100 | test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, 101 | shuffle=False, drop_last=dataset_configs.drop_last, num_workers=0) 102 | return train_loader, test_loader 103 | 104 | 105 | 106 | def few_shot_data_generator(data_loader, dataset_configs, num_samples=5): 107 | x_data = data_loader.dataset.x_data 108 | y_data = data_loader.dataset.y_data 109 | 110 | NUM_SAMPLES_PER_CLASS = num_samples 111 | NUM_CLASSES = len(torch.unique(y_data)) 112 | 113 | counts = [y_data.eq(i).sum().item() for i in range(NUM_CLASSES)] 114 | samples_count_dict = {i: min(counts[i], NUM_SAMPLES_PER_CLASS) for i in range(NUM_CLASSES)} 115 | 116 | samples_ids = {i: torch.where(y_data == i)[0] for i in range(NUM_CLASSES)} 117 | selected_ids = {i: torch.randperm(samples_ids[i].size(0))[:samples_count_dict[i]] for i in range(NUM_CLASSES)} 118 | 119 | 
selected_x = torch.cat([x_data[samples_ids[i][selected_ids[i]]] for i in range(NUM_CLASSES)], dim=0) 120 | selected_y = torch.cat([y_data[samples_ids[i][selected_ids[i]]] for i in range(NUM_CLASSES)], dim=0) 121 | 122 | few_shot_dataset = {"samples": selected_x, "labels": selected_y} 123 | few_shot_dataset = Load_Dataset(few_shot_dataset, dataset_configs) 124 | 125 | few_shot_loader = torch.utils.data.DataLoader(dataset=few_shot_dataset, batch_size=len(few_shot_dataset), 126 | shuffle=False, drop_last=False, num_workers=0) 127 | 128 | return few_shot_loader 129 | 130 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from trainers.train import Trainer 2 | 3 | import argparse 4 | parser = argparse.ArgumentParser() 5 | 6 | if __name__ == "__main__": 7 | 8 | # ======== Experiments Phase ================ 9 | parser.add_argument('--phase', default='train', type=str, help='train, test') 10 | 11 | # ======== Experiments Name ================ 12 | parser.add_argument('--save_dir', default='experiments_logs', type=str, help='Directory containing all experiments') 13 | parser.add_argument('--exp_name', default='EXP1', type=str, help='experiment name') 14 | 15 | # ========= Select the DA methods ============ 16 | parser.add_argument('--da_method', default='MCD', type=str, help='NO_ADAPT, Deep_Coral, MMDA, DANN, CDAN, DIRT, DSAN, HoMM, CoDATS, AdvSKM, SASA, CoTMix, TARGET_ONLY') 17 | 18 | # ========= Select the DATASET ============== 19 | parser.add_argument('--data_path', default=r'../ADATIME_data', type=str, help='Path containing datase2t') 20 | parser.add_argument('--dataset', default='HAR', type=str, help='Dataset of choice: (WISDM - EEG - HAR - HHAR_SA)') 21 | 22 | # ========= Select the BACKBONE ============== 23 | parser.add_argument('--backbone', default='CNN', type=str, help='Backbone of choice: (CNN - RESNET18 - TCN)') 24 | 25 | # ========= 
Experiment settings =============== 26 | parser.add_argument('--num_runs', default=1, type=int, help='Number of consecutive run with different seeds') 27 | parser.add_argument('--device', default= "cuda", type=str, help='cpu or cuda') 28 | 29 | # arguments 30 | args = parser.parse_args() 31 | 32 | # create trainier object 33 | trainer = Trainer(args) 34 | 35 | # train and test 36 | if args.phase == 'train': 37 | trainer.fit() 38 | elif args.phase == 'test': 39 | trainer.test() 40 | 41 | 42 | 43 | #TODO: 44 | # 1- Change the naming of the functions ---> ( Done) 45 | # 2- Change the algorithms following DCORAL --> (Done) 46 | # 3- Keep one trainer for both train and test -->(Done) 47 | # 4- Create the new joint loader that consider the all possible batches --> Done 48 | # 5- Implement Lower/Upper Bound Approach --> Done 49 | # 6- Add the best hparams --> Done 50 | # 7- Add pretrain based methods (ADDA, MCD, MDD) 51 | -------------------------------------------------------------------------------- /main_sweep.py: -------------------------------------------------------------------------------- 1 | from trainers.sweep import Trainer 2 | import argparse 3 | parser = argparse.ArgumentParser() 4 | 5 | 6 | 7 | 8 | if __name__ == "__main__": 9 | # ========= Select the DA methods ============ 10 | parser.add_argument('--da_method', default='Deep_Coral', type=str, 11 | help='DANN, Deep_Coral, WDGRL, MMDA, VADA, DIRT, CDAN, ADDA, HoMM, CoDATS') 12 | 13 | # ========= Select the DATASET ============== 14 | parser.add_argument('--data_path', default=r'../ADATIME_data', type=str, help='Path containing datase2t') 15 | parser.add_argument('--dataset', default='HAR', type=str, help='Dataset of choice: (WISDM - EEG - HAR - HHAR_SA)') 16 | 17 | # ========= Select the BACKBONE ============== 18 | parser.add_argument('--backbone', default='CNN', type=str, help='Backbone of choice: (CNN - RESNET18 - TCN)') 19 | 20 | # ========= Experiment settings =============== 21 | 
parser.add_argument('--num_runs', default=1, type=int, help='Number of consecutive run with different seeds') 22 | parser.add_argument('--device', default="cuda", type=str, help='cpu or cuda') 23 | parser.add_argument('--exp_name', default='sweep_EXP1', type=str, help='experiment name') 24 | 25 | # ======== sweep settings ===================== 26 | parser.add_argument('--num_sweeps', default=1, type=str, help='Number of sweep runs') 27 | 28 | # We run sweeps using wandb plateform, so next parameters are for wandb. 29 | parser.add_argument('--sweep_project_wandb', default='ADATIME_refactor', type=str, help='Project name in Wandb') 30 | parser.add_argument('--wandb_entity', type=str, 31 | help='Entity name in Wandb (can be left blank if there is a default entity)') 32 | parser.add_argument('--hp_search_strategy', default="random", type=str, 33 | help='The way of selecting hyper-parameters (random-grid-bayes). in wandb see:https://docs.wandb.ai/guides/sweeps/configuration') 34 | parser.add_argument('--metric_to_minimize', default="src_risk", type=str, 35 | help='select one of: (src_risk - trg_risk - few_shot_trg_risk - dev_risk)') 36 | 37 | # ======== Experiments Name ================ 38 | parser.add_argument('--save_dir', default='experiments_logs/sweep_logs', type=str, 39 | help='Directory containing all experiments') 40 | 41 | args = parser.parse_args() 42 | 43 | trainer = Trainer(args) 44 | 45 | trainer.train() 46 | -------------------------------------------------------------------------------- /misc/adatime.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/misc/adatime.PNG -------------------------------------------------------------------------------- /misc/results.PNG: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/misc/results.PNG -------------------------------------------------------------------------------- /models/__pycache__/loss.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/models/__pycache__/loss.cpython-310.pyc -------------------------------------------------------------------------------- /models/__pycache__/loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/models/__pycache__/loss.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/loss.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/models/__pycache__/loss.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/models.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/models/__pycache__/models.cpython-310.pyc -------------------------------------------------------------------------------- /models/__pycache__/models.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/models/__pycache__/models.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/models.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/models/__pycache__/models.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet18.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/models/__pycache__/resnet18.cpython-310.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet18.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/models/__pycache__/resnet18.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet18.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emadeldeen24/AdaTime/2d9be6bd0542e150518afaf8aa0c6ea0823b4a3d/models/__pycache__/resnet18.cpython-39.pyc -------------------------------------------------------------------------------- /models/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class ConditionalEntropyLoss(torch.nn.Module): 8 | def __init__(self): 9 | super(ConditionalEntropyLoss, self).__init__() 10 | 11 | def forward(self, x): 12 | b = F.softmax(x, dim=1) * F.log_softmax(x, dim=1) 13 | b = b.sum(dim=1) 14 | return -1.0 * b.mean(dim=0) 15 | 16 | 17 | class VAT(nn.Module): 18 | def __init__(self, model, device): 19 | super(VAT, self).__init__() 20 | self.n_power = 1 21 | self.XI = 1e-6 22 | self.model = model 23 | 
self.epsilon = 3.5 24 | self.device = device 25 | 26 | def forward(self, X, logit): 27 | vat_loss = self.virtual_adversarial_loss(X, logit) 28 | return vat_loss 29 | 30 | def generate_virtual_adversarial_perturbation(self, x, logit): 31 | d = torch.randn_like(x, device=self.device) 32 | 33 | for _ in range(self.n_power): 34 | d = self.XI * self.get_normalized_vector(d).requires_grad_() 35 | logit_m = self.model(x + d) 36 | dist = self.kl_divergence_with_logit(logit, logit_m) 37 | grad = torch.autograd.grad(dist, [d])[0] 38 | d = grad.detach() 39 | 40 | return self.epsilon * self.get_normalized_vector(d) 41 | 42 | def kl_divergence_with_logit(self, q_logit, p_logit): 43 | q = F.softmax(q_logit, dim=1) 44 | qlogq = torch.mean(torch.sum(q * F.log_softmax(q_logit, dim=1), dim=1)) 45 | qlogp = torch.mean(torch.sum(q * F.log_softmax(p_logit, dim=1), dim=1)) 46 | return qlogq - qlogp 47 | 48 | def get_normalized_vector(self, d): 49 | return F.normalize(d.view(d.size(0), -1), p=2, dim=1).reshape(d.size()) 50 | 51 | def virtual_adversarial_loss(self, x, logit): 52 | r_vadv = self.generate_virtual_adversarial_perturbation(x, logit) 53 | logit_p = logit.detach() 54 | logit_m = self.model(x + r_vadv) 55 | loss = self.kl_divergence_with_logit(logit_p, logit_m) 56 | return loss 57 | 58 | 59 | class MMD_loss(nn.Module): 60 | def __init__(self, kernel_type='rbf', kernel_mul=2.0, kernel_num=5): 61 | super(MMD_loss, self).__init__() 62 | self.kernel_num = kernel_num 63 | self.kernel_mul = kernel_mul 64 | self.fix_sigma = None 65 | self.kernel_type = kernel_type 66 | 67 | def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None): 68 | n_samples = int(source.size()[0]) + int(target.size()[0]) 69 | total = torch.cat([source, target], dim=0) 70 | total0 = total.unsqueeze(0).expand( 71 | int(total.size(0)), int(total.size(0)), int(total.size(1))) 72 | total1 = total.unsqueeze(1).expand( 73 | int(total.size(0)), int(total.size(0)), int(total.size(1))) 74 | 
L2_distance = ((total0 - total1) ** 2).sum(2) 75 | if fix_sigma: 76 | bandwidth = fix_sigma 77 | else: 78 | bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples) 79 | bandwidth /= kernel_mul ** (kernel_num // 2) 80 | bandwidth_list = [bandwidth * (kernel_mul ** i) 81 | for i in range(kernel_num)] 82 | kernel_val = [torch.exp(-L2_distance / bandwidth_temp) 83 | for bandwidth_temp in bandwidth_list] 84 | return sum(kernel_val) 85 | 86 | def linear_mmd2(self, f_of_X, f_of_Y): 87 | loss = 0.0 88 | delta = f_of_X.float().mean(0) - f_of_Y.float().mean(0) 89 | loss = delta.dot(delta.T) 90 | return loss 91 | 92 | def forward(self, source, target): 93 | if self.kernel_type == 'linear': 94 | return self.linear_mmd2(source, target) 95 | elif self.kernel_type == 'rbf': 96 | batch_size = int(source.size()[0]) 97 | kernels = self.guassian_kernel( 98 | source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma) 99 | with torch.no_grad(): 100 | XX = torch.mean(kernels[:batch_size, :batch_size]) 101 | YY = torch.mean(kernels[batch_size:, batch_size:]) 102 | XY = torch.mean(kernels[:batch_size, batch_size:]) 103 | YX = torch.mean(kernels[batch_size:, :batch_size]) 104 | loss = torch.mean(XX + YY - XY - YX) 105 | torch.cuda.empty_cache() 106 | return loss 107 | 108 | 109 | class CORAL(nn.Module): 110 | def __init__(self): 111 | super(CORAL, self).__init__() 112 | 113 | def forward(self, source, target): 114 | d = source.size(1) 115 | 116 | # source covariance 117 | xm = torch.mean(source, 0, keepdim=True) - source 118 | xc = xm.t() @ xm 119 | 120 | # target covariance 121 | xmt = torch.mean(target, 0, keepdim=True) - target 122 | xct = xmt.t() @ xmt 123 | 124 | # frobenius norm between source and target 125 | loss = torch.mean(torch.mul((xc - xct), (xc - xct))) 126 | loss = loss / (4 * d * d) 127 | return loss 128 | 129 | 130 | ### FOR DCAN ####################### 131 | def EntropyLoss(input_): 132 | mask = input_.ge(0.0000001) 
133 | mask_out = torch.masked_select(input_, mask) 134 | entropy = - (torch.sum(mask_out * torch.log(mask_out))) 135 | return entropy / float(input_.size(0)) 136 | 137 | 138 | def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None): 139 | n_samples = int(source.size()[0]) + int(target.size()[0]) 140 | total = torch.cat([source, target], dim=0) 141 | total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1))) 142 | total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1))) 143 | L2_distance = ((total0 - total1) ** 2).sum(2) 144 | if fix_sigma: 145 | bandwidth = fix_sigma 146 | else: 147 | bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples) 148 | bandwidth /= kernel_mul ** (kernel_num // 2) 149 | bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)] 150 | kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list] 151 | return sum(kernel_val) # /len(kernel_val) 152 | 153 | 154 | def MMD(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None): 155 | batch_size = int(source.size()[0]) 156 | kernels = guassian_kernel(source, target, 157 | kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma) 158 | loss = 0 159 | for i in range(batch_size): 160 | s1, s2 = i, (i + 1) % batch_size 161 | t1, t2 = s1 + batch_size, s2 + batch_size 162 | loss += kernels[s1, s2] + kernels[t1, t2] 163 | loss -= kernels[s1, t2] + kernels[s2, t1] 164 | return loss / float(batch_size) 165 | 166 | 167 | def MMD_reg(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None): 168 | batch_size_source = int(source.size()[0]) 169 | batch_size_target = int(target.size()[0]) 170 | kernels = guassian_kernel(source, target, 171 | kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma) 172 | loss = 0 173 | for i in range(batch_size_source): 174 | s1, s2 = i, (i + 1) % batch_size_source 175 | t1, t2 = s1 + 
batch_size_target, s2 + batch_size_target 176 | loss += kernels[s1, s2] + kernels[t1, t2] 177 | loss -= kernels[s1, t2] + kernels[s2, t1] 178 | return loss / float(batch_size_source + batch_size_target) 179 | 180 | 181 | ### FOR HoMM ####################### 182 | class HoMM_loss(nn.Module): 183 | def __init__(self): 184 | super(HoMM_loss, self).__init__() 185 | 186 | def forward(self, xs, xt): 187 | xs = xs - torch.mean(xs, axis=0) 188 | xt = xt - torch.mean(xt, axis=0) 189 | xs = torch.unsqueeze(xs, axis=-1) 190 | xs = torch.unsqueeze(xs, axis=-1) 191 | xt = torch.unsqueeze(xt, axis=-1) 192 | xt = torch.unsqueeze(xt, axis=-1) 193 | xs_1 = xs.permute(0, 2, 1, 3) 194 | xs_2 = xs.permute(0, 2, 3, 1) 195 | xt_1 = xt.permute(0, 2, 1, 3) 196 | xt_2 = xt.permute(0, 2, 3, 1) 197 | HR_Xs = xs * xs_1 * xs_2 # dim: b*L*L*L 198 | HR_Xs = torch.mean(HR_Xs, axis=0) # dim: L*L*L 199 | HR_Xt = xt * xt_1 * xt_2 200 | HR_Xt = torch.mean(HR_Xt, axis=0) 201 | return torch.mean((HR_Xs - HR_Xt) ** 2) 202 | 203 | 204 | ### FOR DSAN ####################### 205 | class LMMD_loss(nn.Module): 206 | def __init__(self, device, class_num=3, kernel_type='rbf', kernel_mul=2.0, kernel_num=5, fix_sigma=None): 207 | super(LMMD_loss, self).__init__() 208 | self.class_num = class_num 209 | self.kernel_num = kernel_num 210 | self.kernel_mul = kernel_mul 211 | self.fix_sigma = fix_sigma 212 | self.kernel_type = kernel_type 213 | self.device = device 214 | 215 | def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None): 216 | n_samples = int(source.size()[0]) + int(target.size()[0]) 217 | total = torch.cat([source, target], dim=0) 218 | total0 = total.unsqueeze(0).expand( 219 | int(total.size(0)), int(total.size(0)), int(total.size(1))) 220 | total1 = total.unsqueeze(1).expand( 221 | int(total.size(0)), int(total.size(0)), int(total.size(1))) 222 | L2_distance = ((total0 - total1) ** 2).sum(2) 223 | if fix_sigma: 224 | bandwidth = fix_sigma 225 | else: 226 | bandwidth = 
torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples) 227 | bandwidth /= kernel_mul ** (kernel_num // 2) 228 | bandwidth_list = [bandwidth * (kernel_mul ** i) 229 | for i in range(kernel_num)] 230 | kernel_val = [torch.exp(-L2_distance / bandwidth_temp) 231 | for bandwidth_temp in bandwidth_list] 232 | return sum(kernel_val) 233 | 234 | def get_loss(self, source, target, s_label, t_label): 235 | batch_size = source.size()[0] 236 | weight_ss, weight_tt, weight_st = self.cal_weight( 237 | s_label, t_label, batch_size=batch_size, class_num=self.class_num) 238 | weight_ss = torch.from_numpy(weight_ss).to(self.device) 239 | weight_tt = torch.from_numpy(weight_tt).to(self.device) 240 | weight_st = torch.from_numpy(weight_st).to(self.device) 241 | 242 | kernels = self.guassian_kernel(source, target, 243 | kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma) 244 | loss = torch.Tensor([0]).to(self.device) 245 | if torch.sum(torch.isnan(sum(kernels))): 246 | return loss 247 | SS = kernels[:batch_size, :batch_size] 248 | TT = kernels[batch_size:, batch_size:] 249 | ST = kernels[:batch_size, batch_size:] 250 | 251 | loss += torch.sum(weight_ss * SS + weight_tt * TT - 2 * weight_st * ST) 252 | return loss 253 | 254 | def convert_to_onehot(self, sca_label, class_num=31): 255 | return np.eye(class_num)[sca_label] 256 | 257 | def cal_weight(self, s_label, t_label, batch_size=32, class_num=4): 258 | batch_size = s_label.size()[0] 259 | s_sca_label = s_label.cpu().data.numpy() 260 | s_vec_label = self.convert_to_onehot(s_sca_label, class_num=self.class_num) 261 | s_sum = np.sum(s_vec_label, axis=0).reshape(1, class_num) 262 | s_sum[s_sum == 0] = 100 263 | s_vec_label = s_vec_label / s_sum 264 | 265 | t_sca_label = t_label.cpu().data.max(1)[1].numpy() 266 | t_vec_label = t_label.cpu().data.numpy() 267 | t_sum = np.sum(t_vec_label, axis=0).reshape(1, class_num) 268 | t_sum[t_sum == 0] = 100 269 | t_vec_label = t_vec_label / t_sum 270 | 271 | index = 
list(set(s_sca_label) & set(t_sca_label)) 272 | mask_arr = np.zeros((batch_size, class_num)) 273 | mask_arr[:, index] = 1 274 | t_vec_label = t_vec_label * mask_arr 275 | s_vec_label = s_vec_label * mask_arr 276 | 277 | weight_ss = np.matmul(s_vec_label, s_vec_label.T) 278 | weight_tt = np.matmul(t_vec_label, t_vec_label.T) 279 | weight_st = np.matmul(s_vec_label, t_vec_label.T) 280 | 281 | length = len(index) 282 | if length != 0: 283 | weight_ss = weight_ss / length 284 | weight_tt = weight_tt / length 285 | weight_st = weight_st / length 286 | else: 287 | weight_ss = np.array([0]) 288 | weight_tt = np.array([0]) 289 | weight_st = np.array([0]) 290 | return weight_ss.astype('float32'), weight_tt.astype('float32'), weight_st.astype('float32') 291 | 292 | 293 | 294 | class NTXentLoss(torch.nn.Module): 295 | 296 | def __init__(self, device, batch_size, temperature, use_cosine_similarity): 297 | super(NTXentLoss, self).__init__() 298 | self.batch_size = batch_size 299 | self.temperature = temperature 300 | self.device = device 301 | self.softmax = torch.nn.Softmax(dim=-1) 302 | self.mask_samples_from_same_repr = self._get_correlated_mask().type(torch.bool) 303 | self.similarity_function = self._get_similarity_function(use_cosine_similarity) 304 | self.criterion = torch.nn.CrossEntropyLoss(reduction="sum") 305 | 306 | def _get_similarity_function(self, use_cosine_similarity): 307 | if use_cosine_similarity: 308 | self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1) 309 | return self._cosine_simililarity 310 | else: 311 | return self._dot_simililarity 312 | 313 | def _get_correlated_mask(self): 314 | diag = np.eye(2 * self.batch_size) 315 | l1 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=-self.batch_size) 316 | l2 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=self.batch_size) 317 | mask = torch.from_numpy((diag + l1 + l2)) 318 | mask = (1 - mask).type(torch.bool) 319 | return mask.to(self.device) 320 | 321 | @staticmethod 322 | def 
_dot_simililarity(x, y): 323 | v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2) 324 | # x shape: (N, 1, C) 325 | # y shape: (1, C, 2N) 326 | # v shape: (N, 2N) 327 | return v 328 | 329 | def _cosine_simililarity(self, x, y): 330 | # x shape: (N, 1, C) 331 | # y shape: (1, 2N, C) 332 | # v shape: (N, 2N) 333 | v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0)) 334 | return v 335 | 336 | def forward(self, zis, zjs): 337 | representations = torch.cat([zjs, zis], dim=0) 338 | 339 | similarity_matrix = self.similarity_function(representations, representations) 340 | 341 | # filter out the scores from the positive samples 342 | l_pos = torch.diag(similarity_matrix, self.batch_size) 343 | r_pos = torch.diag(similarity_matrix, -self.batch_size) 344 | positives = torch.cat([l_pos, r_pos]).view(2 * self.batch_size, 1) 345 | 346 | negatives = similarity_matrix[self.mask_samples_from_same_repr].view(2 * self.batch_size, -1) 347 | 348 | logits = torch.cat((positives, negatives), dim=1) 349 | logits /= self.temperature 350 | 351 | labels = torch.zeros(2 * self.batch_size).to(self.device).long() 352 | loss = self.criterion(logits, labels) 353 | 354 | return loss / (2 * self.batch_size) 355 | 356 | 357 | class SupConLoss(torch.nn.Module): 358 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. 359 | It also supports the unsupervised contrastive loss in SimCLR""" 360 | 361 | def __init__(self, device, temperature=0.2, contrast_mode='all'): 362 | super(SupConLoss, self).__init__() 363 | self.temperature = temperature 364 | self.contrast_mode = contrast_mode 365 | self.device = device 366 | 367 | def forward(self, features, labels=None, mask=None): 368 | """Compute loss for model. If both `labels` and `mask` are None, 369 | it degenerates to SimCLR unsupervised loss: 370 | https://arxiv.org/pdf/2002.05709.pdf 371 | Args: 372 | features: hidden vector of shape [bsz, n_views, ...]. 373 | labels: ground truth of shape [bsz]. 
374 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j 375 | has the same class as sample i. Can be asymmetric. 376 | Returns: 377 | A loss scalar. 378 | """ 379 | device = self.device # 'cuda' #(torch.device('cuda') 380 | # if features.is_cuda 381 | # else torch.device('cpu')) 382 | 383 | if len(features.shape) < 3: 384 | raise ValueError('`features` needs to be [bsz, n_views, ...],' 385 | 'at least 3 dimensions are required') 386 | if len(features.shape) > 3: 387 | features = features.view(features.shape[0], features.shape[1], -1) 388 | 389 | batch_size = features.shape[0] 390 | if labels is not None and mask is not None: 391 | raise ValueError('Cannot define both `labels` and `mask`') 392 | elif labels is None and mask is None: 393 | mask = torch.eye(batch_size, dtype=torch.float32).to(device) 394 | elif labels is not None: 395 | labels = labels.contiguous().view(-1, 1) 396 | if labels.shape[0] != batch_size: 397 | raise ValueError('Num of labels does not match num of features') 398 | mask = torch.eq(labels, labels.T).float().to(device) 399 | else: 400 | mask = mask.float().to(device) 401 | 402 | contrast_count = features.shape[1] 403 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) 404 | if self.contrast_mode == 'one': 405 | anchor_feature = features[:, 0] 406 | anchor_count = 1 407 | elif self.contrast_mode == 'all': 408 | anchor_feature = contrast_feature 409 | anchor_count = contrast_count 410 | else: 411 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) 412 | 413 | # compute logits 414 | anchor_dot_contrast = torch.div( 415 | torch.matmul(anchor_feature, contrast_feature.T), 416 | self.temperature) 417 | # for numerical stability 418 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 419 | logits = anchor_dot_contrast - logits_max.detach() 420 | 421 | # tile mask 422 | mask = mask.repeat(anchor_count, contrast_count) 423 | # mask-out self-contrast cases 424 | logits_mask = 
torch.scatter( 425 | torch.ones_like(mask), 426 | 1, 427 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 428 | 0 429 | ) 430 | mask = mask * logits_mask 431 | 432 | # compute log_prob 433 | exp_logits = torch.exp(logits) * logits_mask 434 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 435 | 436 | # compute mean of log-likelihood over positive 437 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) 438 | 439 | # loss 440 | # loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos 441 | loss = - self.temperature * mean_log_prob_pos 442 | loss = loss.view(anchor_count, batch_size).mean() 443 | 444 | return loss 445 | -------------------------------------------------------------------------------- /models/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import math 4 | from torch.autograd import Function 5 | from torch.nn.utils import weight_norm 6 | import torch.nn.functional as F 7 | from .resnet18 import resnet18 8 | 9 | 10 | # from utils import weights_init 11 | 12 | def get_backbone_class(backbone_name): 13 | """Return the algorithm class with the given name.""" 14 | if backbone_name not in globals(): 15 | raise NotImplementedError("Algorithm not found: {}".format(backbone_name)) 16 | return globals()[backbone_name] 17 | 18 | 19 | ################################################## 20 | ########## BACKBONE NETWORKS ################### 21 | ################################################## 22 | 23 | ########## CNN ############################# 24 | class CNN(nn.Module): 25 | def __init__(self, configs): 26 | super(CNN, self).__init__() 27 | 28 | self.conv_block1 = nn.Sequential( 29 | nn.Conv1d(configs.input_channels, configs.mid_channels, kernel_size=configs.kernel_size, 30 | stride=configs.stride, bias=False, padding=(configs.kernel_size // 2)), 31 | nn.BatchNorm1d(configs.mid_channels), 32 | nn.ReLU(), 33 | 
nn.MaxPool1d(kernel_size=2, stride=2, padding=1), 34 | nn.Dropout(configs.dropout) 35 | ) 36 | 37 | self.conv_block2 = nn.Sequential( 38 | nn.Conv1d(configs.mid_channels, configs.mid_channels * 2, kernel_size=8, stride=1, bias=False, padding=4), 39 | nn.BatchNorm1d(configs.mid_channels * 2), 40 | nn.ReLU(), 41 | nn.MaxPool1d(kernel_size=2, stride=2, padding=1) 42 | ) 43 | 44 | self.conv_block3 = nn.Sequential( 45 | nn.Conv1d(configs.mid_channels * 2, configs.final_out_channels, kernel_size=8, stride=1, bias=False, 46 | padding=4), 47 | nn.BatchNorm1d(configs.final_out_channels), 48 | nn.ReLU(), 49 | nn.MaxPool1d(kernel_size=2, stride=2, padding=1), 50 | ) 51 | 52 | self.adaptive_pool = nn.AdaptiveAvgPool1d(configs.features_len) 53 | 54 | def forward(self, x_in): 55 | x = self.conv_block1(x_in) 56 | x = self.conv_block2(x) 57 | x = self.conv_block3(x) 58 | x = self.adaptive_pool(x) 59 | 60 | x_flat = x.reshape(x.shape[0], -1) 61 | return x_flat 62 | 63 | 64 | 65 | class classifier(nn.Module): 66 | def __init__(self, configs): 67 | super(classifier, self).__init__() 68 | self.logits = nn.Linear(configs.features_len * configs.final_out_channels, configs.num_classes) 69 | self.configs = configs 70 | 71 | def forward(self, x): 72 | 73 | predictions = self.logits(x) 74 | 75 | return predictions 76 | 77 | 78 | 79 | ########## TCN ############################# 80 | torch.backends.cudnn.benchmark = True # might be required to fasten TCN 81 | 82 | 83 | class Chomp1d(nn.Module): 84 | def __init__(self, chomp_size): 85 | super(Chomp1d, self).__init__() 86 | self.chomp_size = chomp_size 87 | 88 | def forward(self, x): 89 | return x[:, :, :-self.chomp_size].contiguous() 90 | 91 | 92 | class TCN(nn.Module): 93 | def __init__(self, configs): 94 | super(TCN, self).__init__() 95 | 96 | in_channels0 = configs.input_channels 97 | out_channels0 = configs.tcn_layers[1] 98 | kernel_size = configs.tcn_kernel_size 99 | stride = 1 100 | dilation0 = 1 101 | padding0 = (kernel_size - 1) * 
dilation0 102 | 103 | self.net0 = nn.Sequential( 104 | weight_norm(nn.Conv1d(in_channels0, out_channels0, kernel_size, stride=stride, padding=padding0, 105 | dilation=dilation0)), 106 | nn.ReLU(), 107 | weight_norm(nn.Conv1d(out_channels0, out_channels0, kernel_size, stride=stride, padding=padding0, 108 | dilation=dilation0)), 109 | nn.ReLU(), 110 | ) 111 | 112 | self.downsample0 = nn.Conv1d(in_channels0, out_channels0, 1) if in_channels0 != out_channels0 else None 113 | self.relu = nn.ReLU() 114 | 115 | in_channels1 = configs.tcn_layers[0] 116 | out_channels1 = configs.tcn_layers[1] 117 | dilation1 = 2 118 | padding1 = (kernel_size - 1) * dilation1 119 | self.net1 = nn.Sequential( 120 | nn.Conv1d(in_channels0, out_channels1, kernel_size, stride=stride, padding=padding1, dilation=dilation1), 121 | nn.ReLU(), 122 | nn.Conv1d(out_channels1, out_channels1, kernel_size, stride=stride, padding=padding1, dilation=dilation1), 123 | nn.ReLU(), 124 | ) 125 | self.downsample1 = nn.Conv1d(out_channels1, out_channels1, 1) if in_channels1 != out_channels1 else None 126 | 127 | self.conv_block1 = nn.Sequential( 128 | nn.Conv1d(in_channels0, out_channels0, kernel_size=kernel_size, stride=stride, bias=False, padding=padding0, 129 | dilation=dilation0), 130 | Chomp1d(padding0), 131 | nn.BatchNorm1d(out_channels0), 132 | nn.ReLU(), 133 | 134 | nn.Conv1d(out_channels0, out_channels0, kernel_size=kernel_size, stride=stride, bias=False, 135 | padding=padding0, dilation=dilation0), 136 | Chomp1d(padding0), 137 | nn.BatchNorm1d(out_channels0), 138 | nn.ReLU(), 139 | ) 140 | 141 | self.conv_block2 = nn.Sequential( 142 | nn.Conv1d(out_channels0, out_channels1, kernel_size=kernel_size, stride=stride, bias=False, 143 | padding=padding1, dilation=dilation1), 144 | Chomp1d(padding1), 145 | nn.BatchNorm1d(out_channels1), 146 | nn.ReLU(), 147 | 148 | nn.Conv1d(out_channels1, out_channels1, kernel_size=kernel_size, stride=stride, bias=False, 149 | padding=padding1, dilation=dilation1), 150 | 
Chomp1d(padding1), 151 | nn.BatchNorm1d(out_channels1), 152 | nn.ReLU(), 153 | ) 154 | 155 | def forward(self, inputs): 156 | """Inputs have to have dimension (N, C_in, L_in)""" 157 | x0 = self.conv_block1(inputs) 158 | res0 = inputs if self.downsample0 is None else self.downsample0(inputs) 159 | out_0 = self.relu(x0 + res0) 160 | 161 | x1 = self.conv_block2(out_0) 162 | res1 = out_0 if self.downsample1 is None else self.downsample1(out_0) 163 | out_1 = self.relu(x1 + res1) 164 | 165 | out = out_1[:, :, -1] 166 | return out 167 | 168 | 169 | ######## RESNET ############################################## 170 | 171 | class RESNET18(nn.Module): 172 | def __init__(self, configs): 173 | super(RESNET18, self).__init__() 174 | self.resnet = resnet18(configs) 175 | def forward(self, x_in): 176 | x = self.resnet(x_in) 177 | x_flat = x.reshape(x.shape[0], -1) 178 | return x_flat 179 | 180 | class BasicBlock(nn.Module): 181 | expansion = 1 182 | 183 | def __init__(self, inplanes, planes, stride=1, downsample=None): 184 | super(BasicBlock, self).__init__() 185 | self.conv1 = nn.Conv1d(inplanes, planes, kernel_size=1, stride=stride, 186 | bias=False) 187 | self.bn1 = nn.BatchNorm1d(planes) 188 | 189 | self.downsample = downsample 190 | self.stride = stride 191 | 192 | def forward(self, x): 193 | residual = x 194 | 195 | out = self.conv1(x) 196 | out = self.bn1(out) 197 | out = F.relu(out) 198 | 199 | if self.downsample is not None: 200 | residual = self.downsample(x) 201 | 202 | out += residual 203 | out = F.relu(out) 204 | 205 | return out 206 | 207 | 208 | ################################################## 209 | ########## OTHER NETWORKS ###################### 210 | ################################################## 211 | 212 | class codats_classifier(nn.Module): 213 | def __init__(self, configs): 214 | super(codats_classifier, self).__init__() 215 | model_output_dim = configs.features_len 216 | self.hidden_dim = configs.hidden_dim 217 | self.logits = nn.Sequential( 218 | 
nn.Linear(model_output_dim * configs.final_out_channels, self.hidden_dim), 219 | nn.ReLU(), 220 | nn.Linear(self.hidden_dim, self.hidden_dim), 221 | nn.ReLU(), 222 | nn.Linear(self.hidden_dim, configs.num_classes)) 223 | 224 | def forward(self, x_in): 225 | predictions = self.logits(x_in) 226 | return predictions 227 | 228 | 229 | class Discriminator(nn.Module): 230 | """Discriminator model for source domain.""" 231 | 232 | def __init__(self, configs): 233 | """Init discriminator.""" 234 | super(Discriminator, self).__init__() 235 | 236 | self.layer = nn.Sequential( 237 | nn.Linear(configs.features_len * configs.final_out_channels, configs.disc_hid_dim), 238 | nn.ReLU(), 239 | nn.Linear(configs.disc_hid_dim, configs.disc_hid_dim), 240 | nn.ReLU(), 241 | nn.Linear(configs.disc_hid_dim, 2) 242 | # nn.LogSoftmax(dim=1) 243 | ) 244 | 245 | def forward(self, input): 246 | """Forward the discriminator.""" 247 | out = self.layer(input) 248 | return out 249 | 250 | 251 | #### Codes required by DANN ############## 252 | class ReverseLayerF(Function): 253 | @staticmethod 254 | def forward(ctx, x, alpha): 255 | ctx.alpha = alpha 256 | return x.view_as(x) 257 | 258 | @staticmethod 259 | def backward(ctx, grad_output): 260 | output = grad_output.neg() * ctx.alpha 261 | return output, None 262 | 263 | 264 | #### Codes required by CDAN ############## 265 | class RandomLayer(nn.Module): 266 | def __init__(self, input_dim_list=[], output_dim=1024): 267 | super(RandomLayer, self).__init__() 268 | self.input_num = len(input_dim_list) 269 | self.output_dim = output_dim 270 | self.random_matrix = [torch.randn(input_dim_list[i], output_dim) for i in range(self.input_num)] 271 | 272 | def forward(self, input_list): 273 | return_list = [torch.mm(input_list[i], self.random_matrix[i]) for i in range(self.input_num)] 274 | return_tensor = return_list[0] / math.pow(float(self.output_dim), 1.0 / len(return_list)) 275 | for single in return_list[1:]: 276 | return_tensor = 
torch.mul(return_tensor, single) 277 | return return_tensor 278 | 279 | def cuda(self): 280 | super(RandomLayer, self).cuda() 281 | self.random_matrix = [val.cuda() for val in self.random_matrix] 282 | 283 | 284 | class Discriminator_CDAN(nn.Module): 285 | """Discriminator model for CDAN .""" 286 | 287 | def __init__(self, configs): 288 | """Init discriminator.""" 289 | super(Discriminator_CDAN, self).__init__() 290 | 291 | self.restored = False 292 | 293 | self.layer = nn.Sequential( 294 | nn.Linear(configs.features_len * configs.final_out_channels * configs.num_classes, configs.disc_hid_dim), 295 | nn.ReLU(), 296 | nn.Linear(configs.disc_hid_dim, configs.disc_hid_dim), 297 | nn.ReLU(), 298 | nn.Linear(configs.disc_hid_dim, 2) 299 | # nn.LogSoftmax(dim=1) 300 | ) 301 | 302 | def forward(self, input): 303 | """Forward the discriminator.""" 304 | out = self.layer(input) 305 | return out 306 | 307 | 308 | #### Codes required by AdvSKM ############## 309 | class Cosine_act(nn.Module): 310 | def __init__(self): 311 | super(Cosine_act, self).__init__() 312 | 313 | def forward(self, input): 314 | return torch.cos(input) 315 | 316 | 317 | cos_act = Cosine_act() 318 | 319 | class AdvSKM_Disc(nn.Module): 320 | """Discriminator model for source domain.""" 321 | 322 | def __init__(self, configs): 323 | """Init discriminator.""" 324 | super(AdvSKM_Disc, self).__init__() 325 | 326 | self.input_dim = configs.features_len * configs.final_out_channels 327 | self.hid_dim = configs.DSKN_disc_hid 328 | self.branch_1 = nn.Sequential( 329 | nn.Linear(self.input_dim, self.hid_dim), 330 | nn.Linear(self.hid_dim, self.hid_dim), 331 | nn.BatchNorm1d(self.hid_dim), 332 | cos_act, 333 | nn.Linear(self.hid_dim, self.hid_dim // 2), 334 | nn.Linear(self.hid_dim // 2, self.hid_dim // 2), 335 | nn.BatchNorm1d(self.hid_dim // 2), 336 | cos_act 337 | ) 338 | self.branch_2 = nn.Sequential( 339 | nn.Linear(configs.features_len * configs.final_out_channels, configs.disc_hid_dim), 340 | 
nn.Linear(configs.disc_hid_dim, configs.disc_hid_dim), 341 | nn.BatchNorm1d(configs.disc_hid_dim), 342 | nn.ReLU(), 343 | nn.Linear(configs.disc_hid_dim, configs.disc_hid_dim // 2), 344 | nn.Linear(configs.disc_hid_dim // 2, configs.disc_hid_dim // 2), 345 | nn.BatchNorm1d(configs.disc_hid_dim // 2), 346 | nn.ReLU()) 347 | 348 | def forward(self, input): 349 | """Forward the discriminator.""" 350 | out_cos = self.branch_1(input) 351 | out_rel = self.branch_2(input) 352 | total_out = torch.cat((out_cos, out_rel), dim=1) 353 | return total_out 354 | 355 | # SASA model 356 | class CNN_ATTN(nn.Module): 357 | def __init__(self, configs): 358 | super(CNN_ATTN, self).__init__() 359 | 360 | self.conv_block1 = nn.Sequential( 361 | nn.Conv1d(configs.input_channels, configs.mid_channels, kernel_size=configs.kernel_size, 362 | stride=configs.stride, bias=False, padding=(configs.kernel_size // 2)), 363 | nn.BatchNorm1d(configs.mid_channels), 364 | nn.ReLU(), 365 | nn.MaxPool1d(kernel_size=2, stride=2, padding=1), 366 | nn.Dropout(configs.dropout) 367 | ) 368 | 369 | self.conv_block2 = nn.Sequential( 370 | nn.Conv1d(configs.mid_channels, configs.mid_channels * 2, kernel_size=8, stride=1, bias=False, padding=4), 371 | nn.BatchNorm1d(configs.mid_channels * 2), 372 | nn.ReLU(), 373 | nn.MaxPool1d(kernel_size=2, stride=2, padding=1) 374 | ) 375 | 376 | self.conv_block3 = nn.Sequential( 377 | nn.Conv1d(configs.mid_channels * 2, configs.final_out_channels, kernel_size=8, stride=1, bias=False, 378 | padding=4), 379 | nn.BatchNorm1d(configs.final_out_channels), 380 | nn.ReLU(), 381 | nn.MaxPool1d(kernel_size=2, stride=2, padding=1), 382 | ) 383 | 384 | self.adaptive_pool = nn.AdaptiveAvgPool1d(configs.features_len) 385 | self.attn_network = attn_network(configs) 386 | self.sparse_max = Sparsemax(dim=-1) 387 | self.feat_len = configs.features_len 388 | 389 | def forward(self, x_in): 390 | x = self.conv_block1(x_in) 391 | x = self.conv_block2(x) 392 | x = self.conv_block3(x) 393 | x = 
self.adaptive_pool(x) 394 | x_flat = x.reshape(x.shape[0], -1) 395 | attentive_feat = self.calculate_attentive_feat(x_flat) 396 | return attentive_feat 397 | 398 | def self_attention(self, Q, K, scale=True, sparse=True, k=3): 399 | 400 | attention_weight = torch.bmm(Q.view(Q.shape[0], self.feat_len, -1), K.view(K.shape[0], -1, self.feat_len)) 401 | 402 | attention_weight = torch.mean(attention_weight, dim=2, keepdim=True) 403 | 404 | if scale: 405 | d_k = torch.tensor(K.shape[-1]).float() 406 | attention_weight = attention_weight / torch.sqrt(d_k) 407 | if sparse: 408 | attention_weight_sparse = self.sparse_max(torch.reshape(attention_weight, [-1, self.feat_len])) 409 | attention_weight = torch.reshape(attention_weight_sparse, [-1, attention_weight.shape[1], 410 | attention_weight.shape[2]]) 411 | else: 412 | attention_weight = self.softmax(attention_weight) 413 | 414 | return attention_weight 415 | 416 | def attention_fn(self, Q, K, scaled=False, sparse=True, k=1): 417 | 418 | attention_weight = torch.matmul(F.normalize(Q, p=2, dim=-1), 419 | F.normalize(K, p=2, dim=-1).view(K.shape[0], K.shape[1], -1, self.feat_len)) 420 | 421 | if scaled: 422 | d_k = torch.tensor(K.shape[-1]).float() 423 | attention_weight = attention_weight / torch.sqrt(d_k) 424 | attention_weight = k * torch.log(torch.tensor(self.feat_len, dtype=torch.float32)) * attention_weight 425 | 426 | if sparse: 427 | attention_weight_sparse = self.sparse_max(torch.reshape(attention_weight, [-1, self.feat_len])) 428 | 429 | attention_weight = torch.reshape(attention_weight_sparse, attention_weight.shape) 430 | else: 431 | attention_weight = self.softmax(attention_weight) 432 | 433 | return attention_weight 434 | 435 | def calculate_attentive_feat(self, candidate_representation_xi): 436 | Q_xi, K_xi, V_xi = self.attn_network(candidate_representation_xi) 437 | intra_attention_weight_xi = self.self_attention(Q=Q_xi, K=K_xi, sparse=True) 438 | Z_i = 
torch.bmm(intra_attention_weight_xi.view(intra_attention_weight_xi.shape[0], 1, -1), 439 | V_xi.view(V_xi.shape[0], self.feat_len, -1)) 440 | final_feature = F.normalize(Z_i, dim=-1).view(Z_i.shape[0],-1) 441 | 442 | return final_feature 443 | 444 | class attn_network(nn.Module): 445 | def __init__(self, configs): 446 | super(attn_network, self).__init__() 447 | 448 | self.h_dim = configs.features_len * configs.final_out_channels 449 | self.self_attn_Q = nn.Sequential(nn.Linear(in_features=self.h_dim, out_features=self.h_dim), 450 | nn.ELU() 451 | ) 452 | self.self_attn_K = nn.Sequential(nn.Linear(in_features=self.h_dim, out_features=self.h_dim), 453 | nn.LeakyReLU() 454 | ) 455 | self.self_attn_V = nn.Sequential(nn.Linear(in_features=self.h_dim, out_features=self.h_dim), 456 | nn.LeakyReLU() 457 | ) 458 | 459 | def forward(self, x): 460 | Q = self.self_attn_Q(x) 461 | K = self.self_attn_K(x) 462 | V = self.self_attn_V(x) 463 | 464 | return Q, K, V 465 | 466 | 467 | # Sparse max 468 | class Sparsemax(nn.Module): 469 | """Sparsemax function.""" 470 | 471 | def __init__(self, dim=None): 472 | """Initialize sparsemax activation 473 | 474 | Args: 475 | dim (int, optional): The dimension over which to apply the sparsemax function. 476 | """ 477 | super(Sparsemax, self).__init__() 478 | 479 | self.dim = -1 if dim is None else dim 480 | 481 | def forward(self, input): 482 | """Forward function. 483 | Args: 484 | input (torch.Tensor): Input tensor. 
First dimension should be the batch size 485 | Returns: 486 | torch.Tensor: [batch_size x number_of_logits] Output tensor 487 | """ 488 | # Sparsemax currently only handles 2-dim tensors, 489 | # so we reshape to a convenient shape and reshape back after sparsemax 490 | input = input.transpose(0, self.dim) 491 | original_size = input.size() 492 | input = input.reshape(input.size(0), -1) 493 | input = input.transpose(0, 1) 494 | dim = 1 495 | 496 | number_of_logits = input.size(dim) 497 | 498 | # Translate input by max for numerical stability 499 | input = input - torch.max(input, dim=dim, keepdim=True)[0].expand_as(input) 500 | 501 | # Sort input in descending order. 502 | # (NOTE: Can be replaced with linear time selection method described here: 503 | # http://stanford.edu/~jduchi/projects/DuchiShSiCh08.html) 504 | zs = torch.sort(input=input, dim=dim, descending=True)[0] 505 | range = torch.arange(start=1, end=number_of_logits + 1, step=1, device=input.device, dtype=input.dtype).view(1, 506 | -1) 507 | range = range.expand_as(zs) 508 | 509 | # Determine sparsity of projection 510 | bound = 1 + range * zs 511 | cumulative_sum_zs = torch.cumsum(zs, dim) 512 | is_gt = torch.gt(bound, cumulative_sum_zs).type(input.type()) 513 | k = torch.max(is_gt * range, dim, keepdim=True)[0] 514 | 515 | # Compute threshold function 516 | zs_sparse = is_gt * zs 517 | 518 | # Compute taus 519 | taus = (torch.sum(zs_sparse, dim, keepdim=True) - 1) / k 520 | taus = taus.expand_as(input) 521 | 522 | # Sparsemax 523 | self.output = torch.max(torch.zeros_like(input), input - taus) 524 | 525 | # Reshape back to original shape 526 | output = self.output 527 | output = output.transpose(0, 1) 528 | output = output.reshape(original_size) 529 | output = output.transpose(0, self.dim) 530 | 531 | return output 532 | 533 | def backward(self, grad_output): 534 | """Backward function.""" 535 | dim = 1 536 | 537 | nonzeros = torch.ne(self.output, 0) 538 | sum = torch.sum(grad_output * nonzeros, 
dim=dim) / torch.sum(nonzeros, dim=dim) 539 | self.grad_input = nonzeros * (grad_output - sum.expand_as(grad_output)) 540 | 541 | return self.grad_input 542 | -------------------------------------------------------------------------------- /models/resnet18.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Any, Callable, List, Optional, Type, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch import Tensor 7 | 8 | 9 | 10 | 11 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv1d: 12 | """3x3 convolution with padding""" 13 | return nn.Conv1d( 14 | in_planes, 15 | out_planes, 16 | kernel_size=3, 17 | stride=stride, 18 | padding=dilation, 19 | groups=groups, 20 | bias=False, 21 | dilation=dilation 22 | ) 23 | 24 | 25 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv1d: 26 | """1x1 convolution""" 27 | return nn.Conv1d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 28 | 29 | 30 | class BasicBlock(nn.Module): 31 | expansion: int = 1 32 | 33 | def __init__( 34 | self, 35 | inplanes: int, 36 | planes: int, 37 | stride: int = 1, 38 | downsample: Optional[nn.Module] = None, 39 | groups: int = 1, 40 | base_width: int = 64, 41 | dilation: int = 1, 42 | norm_layer: Optional[Callable[..., nn.Module]] = None 43 | ) -> None: 44 | super().__init__() 45 | if norm_layer is None: 46 | norm_layer = nn.BatchNorm1d 47 | if groups != 1 or base_width != 64: 48 | raise ValueError("BasicBlock only supports groups=1 and base_width=64") 49 | if dilation > 1: 50 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 51 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 52 | self.conv1 = conv3x3(inplanes, planes, stride) 53 | self.bn1 = norm_layer(planes) 54 | self.relu = nn.ReLU(inplace=True) 55 | self.conv2 = conv3x3(planes, planes) 56 | 
self.bn2 = norm_layer(planes) 57 | self.downsample = downsample 58 | self.stride = stride 59 | 60 | def forward(self, x: Tensor) -> Tensor: 61 | identity = x 62 | 63 | out = self.conv1(x) 64 | out = self.bn1(out) 65 | out = self.relu(out) 66 | 67 | out = self.conv2(out) 68 | out = self.bn2(out) 69 | 70 | if self.downsample is not None: 71 | identity = self.downsample(x) 72 | 73 | out += identity 74 | out = self.relu(out) 75 | 76 | return out 77 | 78 | 79 | class ResNet(nn.Module): 80 | def __init__( 81 | self, 82 | block: Type[BasicBlock], 83 | layers: List[int], 84 | configs, 85 | num_classes: int = 5, 86 | zero_init_residual: bool = False, 87 | groups: int = 1, 88 | width_per_group: int = 64, 89 | replace_stride_with_dilation: Optional[List[bool]] = None, 90 | norm_layer: Optional[Callable[..., nn.Module]] = None 91 | ) -> None: 92 | super().__init__() 93 | if norm_layer is None: 94 | norm_layer = nn.BatchNorm1d 95 | self._norm_layer = norm_layer 96 | 97 | self.inplanes = 64 98 | self.dilation = 1 99 | if replace_stride_with_dilation is None: 100 | # each element in the tuple indicates if we should replace 101 | # the 2x2 stride with a dilated convolution instead 102 | replace_stride_with_dilation = [False, False, False] 103 | if len(replace_stride_with_dilation) != 3: 104 | raise ValueError( 105 | "replace_stride_with_dilation should be None " 106 | f"or a 3-element tuple, got {replace_stride_with_dilation}" 107 | ) 108 | self.groups = groups 109 | self.base_width = width_per_group 110 | self.conv1 = nn.Conv1d(configs.input_channels, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 111 | self.bn1 = norm_layer(self.inplanes) 112 | self.relu = nn.ReLU(inplace=True) 113 | self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) 114 | self.layer1 = self._make_layer(block, 64, layers[0]) 115 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) 116 | self.layer3 = self._make_layer(block, 256, 
layers[2], stride=2, dilate=replace_stride_with_dilation[1]) 117 | self.layer4 = self._make_layer(block, configs.final_out_channels, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) 118 | 119 | self.avgpool = nn.AdaptiveAvgPool1d(1) 120 | self.fc = nn.Linear(512 * block.expansion, num_classes) 121 | 122 | for m in self.modules(): 123 | if isinstance(m, nn.Conv1d): 124 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 125 | elif isinstance(m, (nn.BatchNorm1d, nn.GroupNorm)): 126 | nn.init.constant_(m.weight, 1) 127 | nn.init.constant_(m.bias, 0) 128 | 129 | # Zero-initialize the last BN in each residual branch, 130 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 131 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 132 | if zero_init_residual: 133 | for m in self.modules(): 134 | if isinstance(m, Bottleneck) and m.bn3.weight is not None: 135 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] 136 | elif isinstance(m, BasicBlock) and m.bn2.weight is not None: 137 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] 138 | 139 | def _make_layer( 140 | self, 141 | block: Type[BasicBlock], 142 | planes: int, 143 | blocks: int, 144 | stride: int = 1, 145 | dilate: bool = False 146 | ) -> nn.Sequential: 147 | norm_layer = self._norm_layer 148 | downsample = None 149 | previous_dilation = self.dilation 150 | if dilate: 151 | self.dilation *= stride 152 | stride = 1 153 | if stride != 1 or self.inplanes != planes * block.expansion: 154 | downsample = nn.Sequential( 155 | conv1x1(self.inplanes, planes * block.expansion, stride), 156 | norm_layer(planes * block.expansion) 157 | ) 158 | 159 | layers = [] 160 | layers.append( 161 | block( 162 | self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer 163 | ) 164 | ) 165 | self.inplanes = planes * block.expansion 166 | for _ in range(1, 
blocks): 167 | layers.append( 168 | block( 169 | self.inplanes, 170 | planes, 171 | groups=self.groups, 172 | base_width=self.base_width, 173 | dilation=self.dilation, 174 | norm_layer=norm_layer 175 | ) 176 | ) 177 | 178 | return nn.Sequential(*layers) 179 | 180 | def _forward_impl(self, x: Tensor) -> Tensor: 181 | # See note [TorchScript super()] 182 | x = self.conv1(x) 183 | x = self.bn1(x) 184 | x = self.relu(x) 185 | x = self.maxpool(x) 186 | 187 | x = self.layer1(x) 188 | x = self.layer2(x) 189 | x = self.layer3(x) 190 | x = self.layer4(x) 191 | 192 | x = self.avgpool(x) 193 | # x = torch.flatten(x, 1) 194 | # x = self.fc(x) 195 | 196 | return x 197 | 198 | def forward(self, x: Tensor) -> Tensor: 199 | return self._forward_impl(x) 200 | 201 | 202 | def resnet18(configs) -> ResNet: 203 | layers = [2, 2, 2, 2] 204 | model = ResNet(BasicBlock, layers, configs) 205 | 206 | return model 207 | -------------------------------------------------------------------------------- /trainers/abstract_trainer.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../../ADATIME/') 3 | import torch 4 | import torch.nn.functional as F 5 | from torchmetrics import Accuracy, AUROC, F1Score 6 | import os 7 | import wandb 8 | import pandas as pd 9 | import numpy as np 10 | import warnings 11 | import sklearn.exceptions 12 | import collections 13 | 14 | from torchmetrics import Accuracy, AUROC, F1Score 15 | from dataloader.dataloader import data_generator, few_shot_data_generator 16 | from configs.data_model_configs import get_dataset_class 17 | from configs.hparams import get_hparams_class 18 | from configs.sweep_params import sweep_alg_hparams 19 | from utils import fix_randomness, starting_logs, DictAsObject,AverageMeter 20 | from algorithms.algorithms import get_algorithm_class 21 | from models.models import get_backbone_class 22 | 23 | warnings.filterwarnings("ignore", 
category=sklearn.exceptions.UndefinedMetricWarning) 24 | 25 | class AbstractTrainer(object): 26 | """ 27 | This class contain the main training functions for our AdAtime 28 | """ 29 | 30 | def __init__(self, args): 31 | self.da_method = args.da_method # Selected DA Method 32 | self.dataset = args.dataset # Selected Dataset 33 | self.backbone = args.backbone 34 | self.device = torch.device(args.device) # device 35 | 36 | # Exp Description 37 | self.experiment_description = args.dataset 38 | self.run_description = f"{args.da_method}_{args.exp_name}" 39 | 40 | # paths 41 | self.home_path = os.getcwd() #os.path.dirname(os.getcwd()) 42 | self.save_dir = args.save_dir 43 | self.data_path = os.path.join(args.data_path, self.dataset) 44 | # self.create_save_dir(os.path.join(self.home_path, self.save_dir )) 45 | self.exp_log_dir = os.path.join(self.home_path, self.save_dir, self.experiment_description, f"{self.run_description}") 46 | os.makedirs(self.exp_log_dir, exist_ok=True) 47 | 48 | 49 | 50 | 51 | # Specify runs 52 | self.num_runs = args.num_runs 53 | 54 | # get dataset and base model configs 55 | self.dataset_configs, self.hparams_class = self.get_configs() 56 | 57 | # to fix dimension of features in classifier and discriminator networks. 
58 | self.dataset_configs.final_out_channels = self.dataset_configs.tcn_final_out_channles if args.backbone == "TCN" else self.dataset_configs.final_out_channels 59 | 60 | # Specify number of hparams 61 | self.hparams = {**self.hparams_class.alg_hparams[self.da_method], 62 | **self.hparams_class.train_params} 63 | 64 | # metrics 65 | self.num_classes = self.dataset_configs.num_classes 66 | self.ACC = Accuracy(task="multiclass", num_classes=self.num_classes) 67 | self.F1 = F1Score(task="multiclass", num_classes=self.num_classes, average="macro") 68 | self.AUROC = AUROC(task="multiclass", num_classes=self.num_classes) 69 | 70 | # metrics 71 | 72 | def sweep(self): 73 | # sweep configurations 74 | pass 75 | 76 | def initialize_algorithm(self): 77 | # get algorithm class 78 | algorithm_class = get_algorithm_class(self.da_method) 79 | backbone_fe = get_backbone_class(self.backbone) 80 | 81 | # Initilaize the algorithm 82 | self.algorithm = algorithm_class(backbone_fe, self.dataset_configs, self.hparams, self.device) 83 | self.algorithm.to(self.device) 84 | 85 | def load_checkpoint(self, model_dir): 86 | checkpoint = torch.load(os.path.join(self.home_path, model_dir, 'checkpoint.pt')) 87 | last_model = checkpoint['last'] 88 | best_model = checkpoint['best'] 89 | return last_model, best_model 90 | 91 | def train_model(self): 92 | # Get the algorithm and the backbone network 93 | algorithm_class = get_algorithm_class(self.da_method) 94 | backbone_fe = get_backbone_class(self.backbone) 95 | 96 | # Initilaize the algorithm 97 | self.algorithm = algorithm_class(backbone_fe, self.dataset_configs, self.hparams, self.device) 98 | self.algorithm.to(self.device) 99 | 100 | # Training the model 101 | self.last_model, self.best_model = self.algorithm.update(self.src_train_dl, self.trg_train_dl, self.loss_avg_meters, self.logger) 102 | return self.last_model, self.best_model 103 | 104 | def evaluate(self, test_loader): 105 | feature_extractor = 
self.algorithm.feature_extractor.to(self.device) 106 | classifier = self.algorithm.classifier.to(self.device) 107 | 108 | feature_extractor.eval() 109 | classifier.eval() 110 | 111 | total_loss, preds_list, labels_list = [], [], [] 112 | 113 | with torch.no_grad(): 114 | for data, labels in test_loader: 115 | data = data.float().to(self.device) 116 | labels = labels.view((-1)).long().to(self.device) 117 | 118 | # forward pass 119 | features = feature_extractor(data) 120 | predictions = classifier(features) 121 | 122 | # compute loss 123 | loss = F.cross_entropy(predictions, labels) 124 | total_loss.append(loss.item()) 125 | pred = predictions.detach() # .argmax(dim=1) # get the index of the max log-probability 126 | 127 | # append predictions and labels 128 | preds_list.append(pred) 129 | labels_list.append(labels) 130 | 131 | self.loss = torch.tensor(total_loss).mean() # average loss 132 | self.full_preds = torch.cat((preds_list)) 133 | self.full_labels = torch.cat((labels_list)) 134 | 135 | def get_configs(self): 136 | dataset_class = get_dataset_class(self.dataset) 137 | hparams_class = get_hparams_class(self.dataset) 138 | return dataset_class(), hparams_class() 139 | 140 | def load_data(self, src_id, trg_id): 141 | self.src_train_dl = data_generator(self.data_path, src_id, self.dataset_configs, self.hparams, "train") 142 | self.src_test_dl = data_generator(self.data_path, src_id, self.dataset_configs, self.hparams, "test") 143 | 144 | self.trg_train_dl = data_generator(self.data_path, trg_id, self.dataset_configs, self.hparams, "train") 145 | self.trg_test_dl = data_generator(self.data_path, trg_id, self.dataset_configs, self.hparams, "test") 146 | 147 | self.few_shot_dl_5 = few_shot_data_generator(self.trg_test_dl, self.dataset_configs, 148 | 5) # set 5 to other value if you want other k-shot FST 149 | 150 | def create_save_dir(self, save_dir): 151 | if not os.path.exists(save_dir): 152 | os.mkdir(save_dir) 153 | 154 | def calculate_metrics_risks(self): 155 | 
# calculation based source test data 156 | self.evaluate(self.src_test_dl) 157 | src_risk = self.loss.item() 158 | # calculation based few_shot test data 159 | self.evaluate(self.few_shot_dl_5) 160 | fst_risk = self.loss.item() 161 | # calculation based target test data 162 | self.evaluate(self.trg_test_dl) 163 | trg_risk = self.loss.item() 164 | 165 | # calculate metrics 166 | acc = self.ACC(self.full_preds.argmax(dim=1).cpu(), self.full_labels.cpu()).item() 167 | # f1_torch 168 | f1 = self.F1(self.full_preds.argmax(dim=1).cpu(), self.full_labels.cpu()).item() 169 | auroc = self.AUROC(self.full_preds.cpu(), self.full_labels.cpu()).item() 170 | # f1_sk learn 171 | # f1 = f1_score(self.full_preds.argmax(dim=1).cpu().numpy(), self.full_labels.cpu().numpy(), average='macro') 172 | 173 | risks = src_risk, fst_risk, trg_risk 174 | metrics = acc, f1, auroc 175 | 176 | return risks, metrics 177 | 178 | def save_tables_to_file(self,table_results, name): 179 | # save to file if needed 180 | table_results.to_csv(os.path.join(self.exp_log_dir,f"{name}.csv")) 181 | 182 | def save_checkpoint(self, home_path, log_dir, last_model, best_model): 183 | save_dict = { 184 | "last": last_model, 185 | "best": best_model 186 | } 187 | # save classification report 188 | save_path = os.path.join(home_path, log_dir, f"checkpoint.pt") 189 | torch.save(save_dict, save_path) 190 | 191 | def calculate_avg_std_wandb_table(self, results): 192 | 193 | avg_metrics = [np.mean(results.get_column(metric)) for metric in results.columns[2:]] 194 | std_metrics = [np.std(results.get_column(metric)) for metric in results.columns[2:]] 195 | summary_metrics = {metric: np.mean(results.get_column(metric)) for metric in results.columns[2:]} 196 | 197 | results.add_data('mean', '-', *avg_metrics) 198 | results.add_data('std', '-', *std_metrics) 199 | 200 | return results, summary_metrics 201 | 202 | def log_summary_metrics_wandb(self, results, risks): 203 | 204 | # Calculate average and standard deviation for 
metrics 205 | avg_metrics = [np.mean(results.get_column(metric)) for metric in results.columns[2:]] 206 | std_metrics = [np.std(results.get_column(metric)) for metric in results.columns[2:]] 207 | 208 | avg_risks = [np.mean(risks.get_column(risk)) for risk in risks.columns[2:]] 209 | std_risks = [np.std(risks.get_column(risk)) for risk in risks.columns[2:]] 210 | 211 | # Estimate summary metrics 212 | summary_metrics = {metric: np.mean(results.get_column(metric)) for metric in results.columns[2:]} 213 | summary_risks = {risk: np.mean(risks.get_column(risk)) for risk in risks.columns[2:]} 214 | 215 | 216 | # append avg and std values to metrics 217 | results.add_data('mean', '-', *avg_metrics) 218 | results.add_data('std', '-', *std_metrics) 219 | 220 | # append avg and std values to risks 221 | results.add_data('mean', '-', *avg_risks) 222 | risks.add_data('std', '-', *std_risks) 223 | 224 | def wandb_logging(self, total_results, total_risks, summary_metrics, summary_risks): 225 | # log wandb 226 | wandb.log({'results': total_results}) 227 | wandb.log({'risks': total_risks}) 228 | wandb.log({'hparams': wandb.Table(dataframe=pd.DataFrame(dict(self.hparams).items(), columns=['parameter', 'value']), allow_mixed_types=True)}) 229 | wandb.log(summary_metrics) 230 | wandb.log(summary_risks) 231 | 232 | def calculate_metrics(self): 233 | 234 | self.evaluate(self.trg_test_dl) 235 | # accuracy 236 | acc = self.ACC(self.full_preds.argmax(dim=1).cpu(), self.full_labels.cpu()).item() 237 | # f1 238 | f1 = self.F1(self.full_preds.argmax(dim=1).cpu(), self.full_labels.cpu()).item() 239 | # auroc 240 | auroc = self.AUROC(self.full_preds.cpu(), self.full_labels.cpu()).item() 241 | 242 | return acc, f1, auroc 243 | 244 | def calculate_risks(self): 245 | # calculation based source test data 246 | self.evaluate(self.src_test_dl) 247 | src_risk = self.loss.item() 248 | # calculation based few_shot test data 249 | self.evaluate(self.few_shot_dl_5) 250 | fst_risk = self.loss.item() 251 
| # calculation based target test data 252 | self.evaluate(self.trg_test_dl) 253 | trg_risk = self.loss.item() 254 | 255 | return src_risk, fst_risk, trg_risk 256 | 257 | def append_results_to_tables(self, table, scenario, run_id, metrics): 258 | 259 | # Create metrics and risks rows 260 | results_row = [scenario, run_id, *metrics] 261 | 262 | # Create new dataframes for each row 263 | results_df = pd.DataFrame([results_row], columns=table.columns) 264 | 265 | # Concatenate new dataframes with original dataframes 266 | table = pd.concat([table, results_df], ignore_index=True) 267 | 268 | return table 269 | 270 | def add_mean_std_table(self, table, columns): 271 | # Calculate average and standard deviation for metrics 272 | avg_metrics = [table[metric].mean() for metric in columns[2:]] 273 | std_metrics = [table[metric].std() for metric in columns[2:]] 274 | 275 | # Create dataframes for mean and std values 276 | mean_metrics_df = pd.DataFrame([['mean', '-', *avg_metrics]], columns=columns) 277 | std_metrics_df = pd.DataFrame([['std', '-', *std_metrics]], columns=columns) 278 | 279 | # Concatenate original dataframes with mean and std dataframes 280 | table = pd.concat([table, mean_metrics_df, std_metrics_df], ignore_index=True) 281 | 282 | # Create a formatting function to format each element in the tables 283 | format_func = lambda x: f"{x:.4f}" if isinstance(x, float) else x 284 | 285 | # Apply the formatting function to each element in the tables 286 | table = table.applymap(format_func) 287 | 288 | return table -------------------------------------------------------------------------------- /trainers/sweep.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('../') 4 | import torch 5 | import torch.nn.functional as F 6 | import os 7 | import wandb 8 | import pandas as pd 9 | import numpy as np 10 | import warnings 11 | import sklearn.exceptions 12 | import collections 13 | import argparse 14 
| import warnings 15 | import sklearn.exceptions 16 | 17 | from configs.sweep_params import sweep_alg_hparams 18 | from utils import fix_randomness, starting_logs, DictAsObject 19 | from algorithms.algorithms import get_algorithm_class 20 | from models.models import get_backbone_class 21 | from utils import AverageMeter 22 | 23 | from trainers.abstract_trainer import AbstractTrainer 24 | 25 | warnings.filterwarnings("ignore", category=sklearn.exceptions.UndefinedMetricWarning) 26 | parser = argparse.ArgumentParser() 27 | 28 | 29 | class Trainer(AbstractTrainer): 30 | """ 31 | This class contain the main training functions for our AdAtime 32 | """ 33 | 34 | def __init__(self, args): 35 | super(Trainer, self).__init__(args) 36 | 37 | # sweep parameters 38 | self.num_sweeps = args.num_sweeps 39 | self.sweep_project_wandb = args.sweep_project_wandb 40 | self.wandb_entity = args.wandb_entity 41 | self.hp_search_strategy = args.hp_search_strategy 42 | self.metric_to_minimize = args.metric_to_minimize 43 | 44 | # Logging 45 | self.exp_log_dir = os.path.join(self.home_path, self.save_dir) 46 | os.makedirs(self.exp_log_dir, exist_ok=True) 47 | 48 | def sweep(self): 49 | # sweep configurations 50 | sweep_runs_count = self.num_sweeps 51 | sweep_config = { 52 | 'method': self.hp_search_strategy, 53 | 'metric': {'name': self.metric_to_minimize, 'goal': 'minimize'}, 54 | 'name': self.da_method + '_' + self.backbone, 55 | 'parameters': {**sweep_alg_hparams[self.da_method]} 56 | } 57 | sweep_id = wandb.sweep(sweep_config, project=self.sweep_project_wandb, entity=self.wandb_entity) 58 | 59 | wandb.agent(sweep_id, self.train, count=sweep_runs_count) 60 | 61 | def train(self): 62 | run = wandb.init(config=self.hparams) 63 | self.hparams= wandb.config 64 | 65 | # create tables for results and risks 66 | columns = ["scenario", "run", "acc", "f1_score", "auroc"] 67 | table_results = wandb.Table(columns=columns, allow_mixed_types=True) 68 | columns = ["scenario", "run", "src_risk", 
"few_shot_risk", "trg_risk"] 69 | table_risks = wandb.Table(columns=columns, allow_mixed_types=True) 70 | 71 | for src_id, trg_id in self.dataset_configs.scenarios: 72 | for run_id in range(self.num_runs): 73 | # set random seed and create logger 74 | fix_randomness(run_id) 75 | self.logger, self.scenario_log_dir = starting_logs( self.dataset, self.da_method, self.exp_log_dir, src_id, trg_id, run_id ) 76 | 77 | # average meters 78 | self.loss_avg_meters = collections.defaultdict(lambda: AverageMeter()) 79 | 80 | # load data and train model 81 | self.load_data(src_id, trg_id) 82 | 83 | # initiate the domain adaptation algorithm 84 | self.initialize_algorithm() 85 | 86 | # Train the domain adaptation algorithm 87 | self.last_model, self.best_model = self.algorithm.update(self.src_train_dl, self.trg_train_dl, self.loss_avg_meters, self.logger) 88 | 89 | # calculate metrics and risks 90 | metrics = self.calculate_metrics() 91 | risks = self.calculate_risks() 92 | 93 | # append results to tables 94 | scenario = f"{src_id}_to_{trg_id}" 95 | table_results.add_data(scenario, run_id, *metrics) 96 | table_risks.add_data(scenario, run_id, *risks) 97 | 98 | # calculate overall metrics and risks 99 | total_results, summary_metrics = self.calculate_avg_std_wandb_table(table_results) 100 | total_risks, summary_risks = self.calculate_avg_std_wandb_table(table_risks) 101 | 102 | # log results to WandB 103 | self.wandb_logging(total_results, total_risks, summary_metrics, summary_risks) 104 | 105 | # finish the run 106 | run.finish() 107 | 108 | -------------------------------------------------------------------------------- /trainers/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import os 6 | import wandb 7 | import pandas as pd 8 | import numpy as np 9 | import warnings 10 | import sklearn.exceptions 11 | import collections 12 | import argparse 13 | import 
warnings 14 | import sklearn.exceptions 15 | 16 | from utils import fix_randomness, starting_logs, AverageMeter 17 | from algorithms.algorithms import get_algorithm_class 18 | from models.models import get_backbone_class 19 | from trainers.abstract_trainer import AbstractTrainer 20 | warnings.filterwarnings("ignore", category=sklearn.exceptions.UndefinedMetricWarning) 21 | parser = argparse.ArgumentParser() 22 | 23 | 24 | 25 | class Trainer(AbstractTrainer): 26 | """ 27 | This class contain the main training functions for our AdAtime 28 | """ 29 | 30 | def __init__(self, args): 31 | super().__init__(args) 32 | 33 | self.results_columns = ["scenario", "run", "acc", "f1_score", "auroc"] 34 | self.risks_columns = ["scenario", "run", "src_risk", "few_shot_risk", "trg_risk"] 35 | 36 | 37 | def fit(self): 38 | 39 | # table with metrics 40 | table_results = pd.DataFrame(columns=self.results_columns) 41 | 42 | # table with risks 43 | table_risks = pd.DataFrame(columns=self.risks_columns) 44 | 45 | 46 | # Trainer 47 | for src_id, trg_id in self.dataset_configs.scenarios: 48 | for run_id in range(self.num_runs): 49 | # fixing random seed 50 | fix_randomness(run_id) 51 | 52 | # Logging 53 | self.logger, self.scenario_log_dir = starting_logs(self.dataset, self.da_method, self.exp_log_dir, 54 | src_id, trg_id, run_id) 55 | # Average meters 56 | self.loss_avg_meters = collections.defaultdict(lambda: AverageMeter()) 57 | 58 | # Load data 59 | self.load_data(src_id, trg_id) 60 | 61 | # initiate the domain adaptation algorithm 62 | self.initialize_algorithm() 63 | 64 | # Train the domain adaptation algorithm 65 | self.last_model, self.best_model = self.algorithm.update(self.src_train_dl, self.trg_train_dl, self.loss_avg_meters, self.logger) 66 | 67 | # Save checkpoint 68 | self.save_checkpoint(self.home_path, self.scenario_log_dir, self.last_model, self.best_model) 69 | 70 | # Calculate risks and metrics 71 | metrics = self.calculate_metrics() 72 | risks = self.calculate_risks() 73 
| 74 | # Append results to tables 75 | scenario = f"{src_id}_to_{trg_id}" 76 | table_results = self.append_results_to_tables(table_results, scenario, run_id, metrics) 77 | table_risks = self.append_results_to_tables(table_risks, scenario, run_id, risks) 78 | 79 | # Calculate and append mean and std to tables 80 | table_results = self.add_mean_std_table(table_results, self.results_columns) 81 | table_risks = self.add_mean_std_table(table_risks, self.risks_columns) 82 | 83 | 84 | # Save tables to file if needed 85 | self.save_tables_to_file(table_results, 'results') 86 | self.save_tables_to_file(table_risks, 'risks') 87 | 88 | def test(self): 89 | # Results dataframes 90 | last_results = pd.DataFrame(columns=self.results_columns) 91 | best_results = pd.DataFrame(columns=self.results_columns) 92 | 93 | # Cross-domain scenarios 94 | for src_id, trg_id in self.dataset_configs.scenarios: 95 | for run_id in range(self.num_runs): 96 | # fixing random seed 97 | fix_randomness(run_id) 98 | 99 | # Logging 100 | self.scenario_log_dir = os.path.join(self.exp_log_dir, src_id + "_to_" + trg_id + "_run_" + str(run_id)) 101 | 102 | self.loss_avg_meters = collections.defaultdict(lambda: AverageMeter()) 103 | 104 | # Load data 105 | self.load_data(src_id, trg_id) 106 | 107 | # Build model 108 | self.initialize_algorithm() 109 | 110 | # Load chechpoint 111 | last_chk, best_chk = self.load_checkpoint(self.scenario_log_dir) 112 | 113 | # Testing the last model 114 | self.algorithm.network.load_state_dict(last_chk) 115 | self.evaluate(self.trg_test_dl) 116 | last_metrics = self.calculate_metrics() 117 | last_results = self.append_results_to_tables(last_results, f"{src_id}_to_{trg_id}", run_id, 118 | last_metrics) 119 | 120 | 121 | # Testing the best model 122 | self.algorithm.network.load_state_dict(best_chk) 123 | self.evaluate(self.trg_test_dl) 124 | best_metrics = self.calculate_metrics() 125 | # Append results to tables 126 | best_results = self.append_results_to_tables(best_results, 
f"{src_id}_to_{trg_id}", run_id, 127 | best_metrics) 128 | 129 | last_scenario_mean_std = last_results.groupby('scenario')[['acc', 'f1_score', 'auroc']].agg(['mean', 'std']) 130 | best_scenario_mean_std = best_results.groupby('scenario')[['acc', 'f1_score', 'auroc']].agg(['mean', 'std']) 131 | 132 | 133 | # Save tables to file if needed 134 | self.save_tables_to_file(last_scenario_mean_std, 'last_results') 135 | self.save_tables_to_file(best_scenario_mean_std, 'best_results') 136 | 137 | # printing summary 138 | summary_last = {metric: np.mean(last_results[metric]) for metric in self.results_columns[2:]} 139 | summary_best = {metric: np.mean(best_results[metric]) for metric in self.results_columns[2:]} 140 | for summary_name, summary in [('Last', summary_last), ('Best', summary_best)]: 141 | for key, val in summary.items(): 142 | print(f'{summary_name}: {key}\t: {val:2.4f}') 143 | 144 | 145 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn as nn 4 | 5 | import random 6 | import os 7 | import sys 8 | import logging 9 | import numpy as np 10 | import pandas as pd 11 | from shutil import copy 12 | from datetime import datetime 13 | 14 | from skorch import NeuralNetClassifier # for DIV Risk 15 | from sklearn.model_selection import train_test_split 16 | from sklearn.metrics import classification_report, accuracy_score 17 | 18 | 19 | class AverageMeter(object): 20 | """Computes and stores the average and current value""" 21 | 22 | def __init__(self): 23 | self.reset() 24 | 25 | def reset(self): 26 | self.val = 0 27 | self.avg = 0 28 | self.sum = 0 29 | self.count = 0 30 | 31 | def update(self, val, n=1): 32 | self.val = val 33 | self.sum += val * n 34 | self.count += n 35 | self.avg = self.sum / self.count 36 | 37 | 38 | def fix_randomness(SEED): 39 | random.seed(SEED) 40 
| np.random.seed(SEED) 41 | torch.manual_seed(SEED) 42 | torch.cuda.manual_seed(SEED) 43 | torch.backends.cudnn.deterministic = True 44 | torch.backends.cudnn.benchmark = False 45 | 46 | 47 | def _logger(logger_name, level=logging.DEBUG): 48 | """ 49 | Method to return a custom logger with the given name and level 50 | """ 51 | logger = logging.getLogger(logger_name) 52 | logger.setLevel(level) 53 | format_string = "%(message)s" 54 | log_format = logging.Formatter(format_string) 55 | # Creating and adding the console handler 56 | console_handler = logging.StreamHandler(sys.stdout) 57 | console_handler.setFormatter(log_format) 58 | logger.addHandler(console_handler) 59 | # Creating and adding the file handler 60 | file_handler = logging.FileHandler(logger_name, mode='a') 61 | file_handler.setFormatter(log_format) 62 | logger.addHandler(file_handler) 63 | return logger 64 | 65 | 66 | def starting_logs(data_type, da_method, exp_log_dir, src_id, tgt_id, run_id): 67 | log_dir = os.path.join(exp_log_dir, src_id + "_to_" + tgt_id + "_run_" + str(run_id)) 68 | os.makedirs(log_dir, exist_ok=True) 69 | log_file_name = os.path.join(log_dir, f"logs_{datetime.now().strftime('%d_%m_%Y_%H_%M_%S')}.log") 70 | logger = _logger(log_file_name) 71 | logger.debug("=" * 45) 72 | logger.debug(f'Dataset: {data_type}') 73 | logger.debug(f'Method: {da_method}') 74 | logger.debug("=" * 45) 75 | logger.debug(f'Source: {src_id} ---> Target: {tgt_id}') 76 | logger.debug(f'Run ID: {run_id}') 77 | logger.debug("=" * 45) 78 | return logger, log_dir 79 | 80 | 81 | def save_checkpoint(home_path, algorithm, log_dir, last_model, best_model): 82 | save_dict = { 83 | "last": last_model, 84 | "best": best_model 85 | } 86 | # save classification report 87 | save_path = os.path.join(home_path, log_dir, f"checkpoint.pt") 88 | 89 | torch.save(save_dict, save_path) 90 | 91 | 92 | def weights_init(m): 93 | classname = m.__class__.__name__ 94 | if classname.find('Conv') != -1: 95 | m.weight.data.normal_(0.0, 
0.02) 96 | elif classname.find('BatchNorm') != -1: 97 | m.weight.data.normal_(1.0, 0.02) 98 | m.bias.data.fill_(0) 99 | elif classname.find('Linear') != -1: 100 | m.weight.data.normal_(0.0, 0.1) 101 | m.bias.data.fill_(0) 102 | 103 | 104 | def _calc_metrics(pred_labels, true_labels, log_dir, home_path, target_names): 105 | pred_labels = np.array(pred_labels).astype(int) 106 | true_labels = np.array(true_labels).astype(int) 107 | 108 | r = classification_report(true_labels, pred_labels, target_names=target_names, digits=6, output_dict=True) 109 | 110 | df = pd.DataFrame(r) 111 | accuracy = accuracy_score(true_labels, pred_labels) 112 | df["accuracy"] = accuracy 113 | df = df * 100 114 | 115 | # save classification report 116 | file_name = "classification_report.xlsx" 117 | report_Save_path = os.path.join(home_path, log_dir, file_name) 118 | df.to_excel(report_Save_path) 119 | 120 | return accuracy * 100, r["macro avg"]["f1-score"] * 100 121 | 122 | 123 | def copy_Files(destination): 124 | destination_dir = os.path.join(destination, "MODEL_BACKUP_FILES") 125 | os.makedirs(destination_dir, exist_ok=True) 126 | copy("main.py", os.path.join(destination_dir, "main.py")) 127 | copy("utils.py", os.path.join(destination_dir, "utils.py")) 128 | copy(f"trainer.py", os.path.join(destination_dir, f"trainer.py")) 129 | copy(f"same_domain_trainer.py", os.path.join(destination_dir, f"same_domain_trainer.py")) 130 | copy("dataloader/dataloader.py", os.path.join(destination_dir, "dataloader.py")) 131 | copy(f"models/models.py", os.path.join(destination_dir, f"models.py")) 132 | copy(f"models/loss.py", os.path.join(destination_dir, f"loss.py")) 133 | copy("algorithms/algorithms.py", os.path.join(destination_dir, "algorithms.py")) 134 | copy(f"configs/data_model_configs.py", os.path.join(destination_dir, f"data_model_configs.py")) 135 | copy(f"configs/hparams.py", os.path.join(destination_dir, f"hparams.py")) 136 | copy(f"configs/sweep_params.py", os.path.join(destination_dir, 
f"sweep_params.py")) 137 | 138 | 139 | 140 | 141 | def get_iwcv_value(weight, error): 142 | N, d = weight.shape 143 | _N, _d = error.shape 144 | assert N == _N and d == _d, 'dimension mismatch!' 145 | weighted_error = weight * error 146 | return np.mean(weighted_error) 147 | 148 | 149 | def get_dev_value(weight, error): 150 | """ 151 | :param weight: shape [N, 1], the importance weight for N source samples in the validation set 152 | :param error: shape [N, 1], the error value for each source sample in the validation set 153 | (typically 0 for correct classification and 1 for wrong classification) 154 | """ 155 | N, d = weight.shape 156 | _N, _d = error.shape 157 | assert N == _N and d == _d, 'dimension mismatch!' 158 | weighted_error = weight * error 159 | cov = np.cov(np.concatenate((weighted_error, weight), axis=1), rowvar=False)[0][1] 160 | var_w = np.var(weight, ddof=1) 161 | eta = - cov / var_w 162 | return np.mean(weighted_error) + eta * np.mean(weight) - eta 163 | 164 | 165 | class simple_MLP(nn.Module): 166 | def __init__(self, inp_units, out_units=2): 167 | super(simple_MLP, self).__init__() 168 | 169 | self.dense0 = nn.Linear(inp_units, inp_units // 2) 170 | self.nonlin = nn.ReLU() 171 | self.output = nn.Linear(inp_units // 2, out_units) 172 | self.softmax = nn.Softmax(dim=-1) 173 | 174 | def forward(self, x, **kwargs): 175 | x = self.nonlin(self.dense0(x)) 176 | x = self.softmax(self.output(x)) 177 | return x 178 | 179 | 180 | def get_weight_gpu(source_feature, target_feature, validation_feature, configs, device): 181 | """ 182 | :param source_feature: shape [N_tr, d], features from training set 183 | :param target_feature: shape [N_te, d], features from test set 184 | :param validation_feature: shape [N_v, d], features from validation set 185 | :return: 186 | """ 187 | import copy 188 | N_s, d = source_feature.shape 189 | N_t, _d = target_feature.shape 190 | source_feature = copy.deepcopy(source_feature.detach().cpu()) # source_feature.clone() 191 | 
target_feature = copy.deepcopy(target_feature.detach().cpu()) # target_feature.clone() 192 | source_feature = source_feature.to(device) 193 | target_feature = target_feature.to(device) 194 | all_feature = torch.cat((source_feature, target_feature), dim=0) 195 | all_label = torch.from_numpy(np.asarray([1] * N_s + [0] * N_t, dtype=np.int32)).long() 196 | 197 | feature_for_train, feature_for_test, label_for_train, label_for_test = train_test_split(all_feature, all_label, 198 | train_size=0.8) 199 | learning_rates = [1e-1, 5e-2, 1e-2] 200 | val_acc = [] 201 | domain_classifiers = [] 202 | 203 | for lr in learning_rates: 204 | domain_classifier = NeuralNetClassifier( 205 | simple_MLP, 206 | module__inp_units=configs.final_out_channels * configs.features_len, 207 | max_epochs=30, 208 | lr=lr, 209 | device=device, 210 | # Shuffle training data on each epoch 211 | iterator_train__shuffle=True, 212 | callbacks="disable" 213 | ) 214 | domain_classifier.fit(feature_for_train.float(), label_for_train.long()) 215 | output = domain_classifier.predict(feature_for_test) 216 | acc = np.mean((label_for_test.numpy() == output).astype(np.float32)) 217 | val_acc.append(acc) 218 | domain_classifiers.append(domain_classifier) 219 | 220 | index = val_acc.index(max(val_acc)) 221 | domain_classifier = domain_classifiers[index] 222 | 223 | domain_out = domain_classifier.predict_proba(validation_feature.to(device).float()) 224 | return domain_out[:, :1] / domain_out[:, 1:] * N_s * 1.0 / N_t 225 | 226 | 227 | def calc_dev_risk(target_model, src_train_dl, tgt_train_dl, src_valid_dl, configs, device): 228 | src_train_feats = target_model.feature_extractor(src_train_dl.dataset.x_data.float().to(device)) 229 | tgt_train_feats = target_model.feature_extractor(tgt_train_dl.dataset.x_data.float().to(device)) 230 | src_valid_feats = target_model.feature_extractor(src_valid_dl.dataset.x_data.float().to(device)) 231 | src_valid_pred = target_model.classifier(src_valid_feats) 232 | 233 | dev_weights = 
get_weight_gpu(src_train_feats.to(device), tgt_train_feats.to(device), 234 | src_valid_feats.to(device), configs, device) 235 | dev_error = F.cross_entropy(src_valid_pred, src_valid_dl.dataset.y_data.long().to(device), reduction='none') 236 | dev_risk = get_dev_value(dev_weights, dev_error.unsqueeze(1).detach().cpu().numpy()) 237 | # iwcv_risk = get_iwcv_value(dev_weights, dev_error.unsqueeze(1).detach().cpu().numpy()) 238 | return dev_risk 239 | 240 | 241 | def calculate_risk(target_model, risk_dataloader, device): 242 | if type(risk_dataloader) == tuple: 243 | x_data = torch.cat((risk_dataloader[0].dataset.x_data, risk_dataloader[1].dataset.x_data), axis=0) 244 | y_data = torch.cat((risk_dataloader[0].dataset.y_data, risk_dataloader[1].dataset.y_data), axis=0) 245 | else: 246 | x_data = risk_dataloader.dataset.x_data 247 | y_data = risk_dataloader.dataset.y_data 248 | 249 | feat = target_model.feature_extractor(x_data.float().to(device)) 250 | pred = target_model.classifier(feat) 251 | cls_loss = F.cross_entropy(pred, y_data.long().to(device)) 252 | return cls_loss.item() 253 | 254 | 255 | class DictAsObject: 256 | def __init__(self, d): 257 | self.__dict__ = d 258 | 259 | def __getattr__(self, name): 260 | try: 261 | return self.__dict__[name] 262 | except KeyError: 263 | raise AttributeError(f"'DictAsObject' object has no attribute '{name}'") 264 | 265 | 266 | # For DIRT-T 267 | class EMA: 268 | def __init__(self, decay): 269 | self.decay = decay 270 | self.shadow = {} 271 | 272 | def register(self, model): 273 | for name, param in model.named_parameters(): 274 | if param.requires_grad: 275 | self.shadow[name] = param.data.clone() 276 | self.params = self.shadow.keys() 277 | 278 | def __call__(self, model): 279 | if self.decay > 0: 280 | for name, param in model.named_parameters(): 281 | if name in self.params and param.requires_grad: 282 | self.shadow[name] -= (1 - self.decay) * (self.shadow[name] - param.data) 283 | param.data = self.shadow[name] 284 | 285 | 
286 | --------------------------------------------------------------------------------