├── data └── Readme.md ├── mmd.py ├── replay_memory.py ├── README.md ├── benchmark_KDSTDA_READ.py ├── main.py ├── configs ├── sweep_params.py ├── data_model_configs.py └── hparams.py ├── dataloader └── dataloader.py ├── utils.py ├── models ├── loss.py └── models.py ├── single_domain_trainer.py ├── trainer.py ├── benchmark_Mobileda_and_AAD.py ├── benchmark_Multi_Level_Distillation.py ├── proposed_RCD_KD.py └── benchmark_Max_Cluser_Difference.py /data/Readme.md: -------------------------------------------------------------------------------- 1 | Download the data and put into this folder 2 | -------------------------------------------------------------------------------- /mmd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class MMD_loss(nn.Module): 6 | def __init__(self, kernel_mul = 2.0, kernel_num = 5): 7 | super(MMD_loss, self).__init__() 8 | self.kernel_num = kernel_num 9 | self.kernel_mul = kernel_mul 10 | self.fix_sigma = None 11 | return 12 | def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None): 13 | n_samples = int(source.size()[0])+int(target.size()[0]) 14 | total = torch.cat([source, target], dim=0) 15 | 16 | total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1))) 17 | total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1))) 18 | L2_distance = ((total0-total1)**2).sum(2) 19 | if fix_sigma: 20 | bandwidth = fix_sigma 21 | else: 22 | bandwidth = torch.sum(L2_distance.data) / (n_samples**2-n_samples) 23 | bandwidth /= kernel_mul ** (kernel_num // 2) 24 | bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)] 25 | kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list] 26 | return sum(kernel_val) 27 | 28 | def forward(self, source, target): 29 | batch_size = int(source.size()[0]) 30 | kernels = self.guassian_kernel(source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma) 31 | XX = kernels[:batch_size, :batch_size] 32 | YY = kernels[batch_size:, batch_size:] 33 | XY = kernels[:batch_size, batch_size:] 34 | YX = kernels[batch_size:, :batch_size] 35 | loss = torch.mean(XX + YY - XY -YX) 36 | return loss 37 | -------------------------------------------------------------------------------- /replay_memory.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | from collections import deque 4 | 5 | class ReplayMemory: 6 | def __init__(self, capacity): 7 | self.capacity = capacity 8 | self.buffer = [] 9 | self.position = 0 10 | 11 | def push(self, state, action, reward, next_state, done): 12 | if len(self.buffer) < self.capacity: 13 | self.buffer.append(None) 14 | self.buffer[self.position] = (state, action, reward, next_state, done) 15 | self.position = (self.position + 1) % self.capacity 16 | 17 | def sample(self, batch_size): 18 | batch = random.sample(self.buffer, batch_size) 19 | state, action, reward, next_state, done = map(np.stack, zip(*batch)) 20 | return state, action, reward, next_state, done 21 | 22 | def __len__(self): 23 | return len(self.buffer) 24 | 25 | 26 | class Memory(object): 27 | def __init__(self, memory_size: int) -> None: 28 | self.memory_size = memory_size 29 | self.buffer = deque(maxlen=self.memory_size) 30 | 31 | def add(self, experience) -> None: 32 | self.buffer.append(experience) 33 | 34 | def 
size(self): 35 | return len(self.buffer) 36 | 37 | def sample(self, batch_size: int, continuous: bool = True): 38 | if batch_size > len(self.buffer): 39 | batch_size = len(self.buffer) 40 | if continuous: 41 | rand = random.randint(0, len(self.buffer) - batch_size) 42 | return [self.buffer[i] for i in range(rand, rand + batch_size)] 43 | else: 44 | indexes = np.random.choice(np.arange(len(self.buffer)), size=batch_size, replace=False) 45 | return [self.buffer[i] for i in indexes] 46 | 47 | def clear(self): 48 | self.buffer.clear() 49 | 50 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Reinforced-Cross-Domain-Knowledge-Distillation-on-Time-Series-Data 2 | This repository is a PyTorch implementation of the NeurIPS 2024 paper "Reinforced Cross-Domain Knowledge Distillation on Time Series Data". 3 | 4 | ## Datasets 5 | 6 | ### Available Datasets 7 | We used four public datasets in this study. We also provide the **preprocessed** versions as follows: 8 | 9 | - [UCIHAR](https://researchdata.ntu.edu.sg/dataset.xhtml?persistentId=doi:10.21979/N9/0SYHTZ) 10 | - [HHAR](https://researchdata.ntu.edu.sg/dataset.xhtml?persistentId=doi:10.21979/N9/OWDFXO) 11 | - [FD](https://mb.uni-paderborn.de/en/kat/main-research/datacenter/bearing-datacenter/data-sets-and-download) 12 | - [SSC](https://researchdata.ntu.edu.sg/dataset.xhtml?persistentId=doi:10.21979/N9/UD1IM9) 13 | 14 | Please download these datasets and place each one in its respective folder under "data". 15 | 16 | 17 | ## Unsupervised Domain Adaptation Algorithms 18 | ### Existing Benchmark Algorithms 19 | - [KD-STDA](https://arxiv.org/pdf/2101.07308) 20 | - [KA-MCD](https://arxiv.org/pdf/1702.02052) 21 | - [MLD-DA](https://openaccess.thecvf.com/content/WACV2021W/AVV/papers/Kothandaraman_Domain_Adaptive_Knowledge_Distillation_for_Driving_Scene_Semantic_Segmentation_WACVW_2021_paper.pdf) 22 | - [REDA](https://junguangjiang.github.io/files/resource-efficient-domain-adaptation-acmmm20.pdf) 23 | - [AAD](https://arxiv.org/pdf/2010.11478.pdf) 24 | - [MobileDA](https://ieeexplore.ieee.org/abstract/document/9016215/) 25 | - [UNI-KD](https://arxiv.org/pdf/2307.03347) 26 | 27 | ## Running the Proposed RCD-KD Algorithm 28 | 29 | ### Teacher training 30 | Our approach requires a pre-trained teacher. We use the DANN method to train a teacher and store it in 'experiments_logs/HAR/Teacher_CNN'. 31 | For other datasets, please save the teachers into the respective dataset folders. Note that for the teacher we set 'feature_dim = 64' in 'configs/data_model_configs.py', 32 | and for the student we set 'feature_dim = 16'.
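For example, a teacher can be pre-trained with `main.py`, whose defaults already select the teacher variants ('DANN_T' method, 'CNN_T' backbone). The command below is only a sketch: the experiment/run names, dataset and '--save_dir' are illustrative and should be adjusted so that the resulting checkpoint ends up where student training expects it (e.g. 'experiments_logs/HAR/Teacher_CNN').

```
python main.py --save_dir experiments_logs \
               --experiment_description HAR \
               --run_description Teacher_CNN \
               --da_method DANN_T \
               --dataset HAR \
               --backbone CNN_T \
               --num_runs 3
```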
33 | 34 | ### Student training 35 | To train a student with our proposed approach, run: 36 | 37 | ``` 38 | python proposed_RCD_KD.py --experiment_description exp1 \ 39 | --run_description run_1 \ 40 | --da_method RL_JointADKD \ 41 | --dataset HAR \ 42 | --backbone CNN \ 43 | --num_runs 3 44 | ``` 45 | 46 | ## Claims 47 | Parts of the benchmark methods' code are adapted from [AdaTime](https://github.com/emadeldeen24/AdaTime) and [UNI-KD](https://arxiv.org/pdf/2307.03347). 48 | -------------------------------------------------------------------------------- /benchmark_KDSTDA_READ.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import warnings 4 | from trainer import cross_domain_trainer 5 | import sklearn.exceptions 6 | warnings.filterwarnings("ignore", category=sklearn.exceptions.UndefinedMetricWarning) 7 | 8 | parser = argparse.ArgumentParser() 9 | 10 | 11 | # ======== Experiments Name ================ 12 | parser.add_argument('--save_dir', default='experiments_logs_additional', type=str, help='Directory containing all experiments') 13 | parser.add_argument('--experiment_description', default='FD', type=str, help='Name of your experiment (HAR, HHAR_SA, FD, EEG)') 14 | parser.add_argument('--run_description', default='KDSTDA', type=str, help='Name of your run') 15 | 16 | # ========= Select the DA methods ============ 17 | parser.add_argument('--da_method', default='KDSTDA', type=str, help='KDSTDA, REDA') 18 | 19 | # ========= Select the DATASET ============== 20 | parser.add_argument('--data_path', default=r'./data', type=str, help='Path containing dataset') 21 | parser.add_argument('--dataset', default='FD', type=str, help='Dataset of choice: (HAR, HHAR_SA, FD, EEG)') 22 | 23 | # ========= Select the BACKBONE ============== 24 | parser.add_argument('--backbone', default='CNN', type=str, help='Backbone of choice: (CNN)') 25 | 26 | # ========= Experiment settings =============== 27 | parser.add_argument('--num_runs', default=3, type=int, help='Number of consecutive runs with different seeds') 28 | parser.add_argument('--device', default='cuda:0', type=str, help='cpu or cuda') 29 | 30 | # ======== sweep settings ===================== 31 | parser.add_argument('--is_sweep', default=False, type=bool, help='single run or sweep') 32 | parser.add_argument('--num_sweeps', default=20, type=int, help='Number of sweep runs') 33 | 34 | # We run sweeps using the wandb platform, so the next parameters are for wandb. 35 | parser.add_argument('--sweep_project_wandb', default='TEST_SOMETHING', type=str, help='Project name in Wandb') 36 | parser.add_argument('--wandb_entity', type=str, help='Entity name in Wandb (can be left blank if there is a default entity)') 37 | parser.add_argument('--hp_search_strategy', default="random", type=str, help='The way of selecting hyper-parameters (random-grid-bayes). 
in wandb see:https://docs.wandb.ai/guides/sweeps/configuration') 38 | parser.add_argument('--metric_to_minimize', default="src_risk", type=str, help='select one of: (src_risk - trg_risk - few_shot_trg_risk - dev_risk)') 39 | 40 | 41 | 42 | args = parser.parse_args() 43 | 44 | if __name__ == "__main__": 45 | 46 | trainer = cross_domain_trainer(args) 47 | 48 | if args.is_sweep: 49 | trainer.sweep() 50 | else: 51 | trainer.train() 52 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import warnings 4 | from trainer import cross_domain_trainer 5 | import sklearn.exceptions 6 | warnings.filterwarnings("ignore", category=sklearn.exceptions.UndefinedMetricWarning) 7 | 8 | parser = argparse.ArgumentParser() 9 | 10 | 11 | # ======== Experiments Name ================ 12 | parser.add_argument('--save_dir', default='experiments_logs_additional', type=str, help='Directory containing all experiments') 13 | parser.add_argument('--experiment_description', default='FD', type=str, help='Name of your experiment (HAR, HHAR_SA, FD, EEG, ') 14 | parser.add_argument('--run_description', default='DANN_CNN_teacher', type=str, help='name of your runs, ') 15 | 16 | # ========= Select the DA methods ============ 17 | parser.add_argument('--da_method', default='DANN_T', type=str, help='KDSTDA, REDA,HoMM, MMDA,DANN,CoDATS') 18 | 19 | # ========= Select the DATASET ============== 20 | parser.add_argument('--data_path', default=r'./data', type=str, help='Path containing dataset') 21 | parser.add_argument('--dataset', default='FD', type=str, help='Dataset of choice: (HAR, HHAR_SA, FD, EEG)') 22 | 23 | # ========= Select the BACKBONE ============== 24 | parser.add_argument('--backbone', default='CNN_T', type=str, help='Backbone of choice: (CNN_T') 25 | 26 | # ========= Experiment settings =============== 27 | parser.add_argument('--num_runs', default = 3, type=int, help='Number of consecutive run with different seeds') 28 | parser.add_argument('--device', default='cuda:0', type=str, help='cpu or cuda') 29 | 30 | # ======== sweep settings ===================== 31 | parser.add_argument('--is_sweep', default=False, type=bool, help='singe run or sweep') 32 | parser.add_argument('--num_sweeps', default=20, type=str, help='Number of sweep runs') 33 | 34 | # We run sweeps using wandb plateform, so next parameters are for wandb. 35 | parser.add_argument('--sweep_project_wandb', default='TEST_SOMETHING', type=str, help='Project name in Wandb') 36 | parser.add_argument('--wandb_entity', type=str, help='Entity name in Wandb (can be left blank if there is a default entity)') 37 | parser.add_argument('--hp_search_strategy', default="random", type=str, help='The way of selecting hyper-parameters (random-grid-bayes). 
in wandb see:https://docs.wandb.ai/guides/sweeps/configuration') 38 | parser.add_argument('--metric_to_minimize', default="src_risk", type=str, help='select one of: (src_risk - trg_risk - few_shot_trg_risk - dev_risk)') 39 | 40 | 41 | 42 | args = parser.parse_args() 43 | 44 | if __name__ == "__main__": 45 | 46 | trainer = cross_domain_trainer(args) 47 | 48 | if args.is_sweep: 49 | trainer.sweep() 50 | else: 51 | trainer.train() 52 | -------------------------------------------------------------------------------- /configs/sweep_params.py: -------------------------------------------------------------------------------- 1 | sweep_train_hparams = { 2 | 'num_epochs': {'values': [3, 4, 5, 6]}, 3 | 'batch_size': {'values': [32, 64]}, 4 | 'learning_rate':{'values': [1e-2, 5e-3, 1e-3, 5e-4]}, 5 | 'disc_lr': {'values': [1e-2, 5e-3, 1e-3, 5e-4]}, 6 | 'weight_decay': {'values': [1e-4, 1e-5, 1e-6]}, 7 | 'step_size': {'values': [5, 10, 30]}, 8 | 'gamma': {'values': [5, 10, 15, 20, 25]}, 9 | 'optimizer': {'values': ['adam']}, 10 | } 11 | sweep_alg_hparams = { 12 | 'DANN': { 13 | 'learning_rate': {'values': [1e-2, 5e-3, 1e-3, 5e-4]}, 14 | 'src_cls_loss_wt': {'distribution': 'uniform', 'min': 1e-1, 'max': 10}, 15 | 'domain_loss_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10}, 16 | }, 17 | 18 | 'AdvSKM': { 19 | 'learning_rate': {'values': [1e-2, 5e-3, 1e-3, 5e-4]}, 20 | 'src_cls_loss_wt': {'distribution': 'uniform', 'min': 1e-1, 'max': 10}, 21 | 'domain_loss_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10}, 22 | }, 23 | 24 | 'CoDATS': { 25 | 'learning_rate': {'values': [1e-2, 5e-3, 1e-3, 5e-4]}, 26 | 'src_cls_loss_wt': {'distribution': 'uniform', 'min': 1e-1, 'max': 10}, 27 | 'domain_loss_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10}, 28 | }, 29 | 30 | 'CDAN': { 31 | 'learning_rate': {'values': [1e-2, 5e-3, 1e-3, 5e-4]}, 32 | 'src_cls_loss_wt': {'distribution': 'uniform', 'min': 1e-1, 'max': 10}, 33 | 'domain_loss_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10}, 34 | 'cond_ent_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10}, 35 | }, 36 | 37 | 'Deep_Coral': { 38 | 'learning_rate': {'values': [1e-2, 5e-3, 1e-3, 5e-4]}, 39 | 'src_cls_loss_wt': {'distribution': 'uniform', 'min': 1e-1, 'max': 10}, 40 | 'coral_wt': {'distribution': 'uniform', 'min': 1e-1, 'max': 10}, 41 | }, 42 | 43 | 'DIRT': { 44 | 'learning_rate': {'values': [1e-2, 5e-3, 1e-3, 5e-4]}, 45 | 'src_cls_loss_wt': {'distribution': 'uniform', 'min': 1e-1, 'max': 10}, 46 | 'domain_loss_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10}, 47 | 'cond_ent_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10}, 48 | 'vat_loss_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10}, 49 | }, 50 | 51 | 'HoMM': { 52 | 'learning_rate': {'values': [1e-2, 5e-3, 1e-3, 5e-4]}, 53 | 'src_cls_loss_wt': {'distribution': 'uniform', 'min': 1e-1, 'max': 10}, 54 | 'hommd_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10}, 55 | }, 56 | 57 | 'MMDA': { 58 | 'learning_rate': {'values': [1e-2, 5e-3, 1e-3, 5e-4]}, 59 | 'src_cls_loss_wt': {'distribution': 'uniform', 'min': 1e-1, 'max': 10}, 60 | 'coral_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10}, 61 | 'cond_ent_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10}, 62 | 'mmd_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10}, 63 | }, 64 | 65 | 'DSAN': { 66 | 'learning_rate': {'values': [1e-2, 5e-3, 1e-3, 5e-4]}, 67 | 'src_cls_loss_wt': {'distribution': 'uniform', 'min': 1e-1, 'max': 10}, 68 | 'mmd_wt': {'distribution': 'uniform', 'min': 1e-2, 
'max': 10}, 69 | }, 70 | 71 | 'DDC': { 72 | 'learning_rate': {'values': [1e-2, 5e-3, 1e-3, 5e-4]}, 73 | 'src_cls_loss_wt': {'distribution': 'uniform', 'min': 1e-1, 'max': 10}, 74 | 'mmd_wt': {'distribution': 'uniform', 'min': 1e-2, 'max': 10}, 75 | }, 76 | } 77 | 78 | -------------------------------------------------------------------------------- /dataloader/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torch.utils.data import Dataset 4 | from torchvision import transforms 5 | 6 | from sklearn.model_selection import train_test_split 7 | 8 | import os 9 | import numpy as np 10 | import random 11 | 12 | 13 | class Load_Dataset(Dataset): 14 | def __init__(self, dataset, normalize): 15 | super(Load_Dataset, self).__init__() 16 | 17 | X_train = dataset["samples"] 18 | y_train = dataset["labels"] 19 | 20 | if len(X_train.shape) < 3: 21 | X_train = X_train.unsqueeze(2) 22 | 23 | if isinstance(X_train, np.ndarray): 24 | X_train = torch.from_numpy(X_train) 25 | y_train = torch.from_numpy(y_train).long() 26 | 27 | if X_train.shape.index(min(X_train.shape[1], X_train.shape[2])) != 1: # make sure the Channels in second dim 28 | X_train = X_train.permute(0, 2, 1) 29 | 30 | self.x_data = X_train 31 | self.y_data = y_train 32 | 33 | self.num_channels = X_train.shape[1] 34 | 35 | if normalize: 36 | # Assume datashape: num_samples, num_channels, seq_length 37 | data_mean = torch.FloatTensor(self.num_channels).fill_(0).tolist() # assume min= number of channels 38 | data_std = torch.FloatTensor(self.num_channels).fill_(1).tolist() # assume min= number of channels 39 | data_transform = transforms.Normalize(mean=data_mean, std=data_std) 40 | self.transform = data_transform 41 | else: 42 | self.transform = None 43 | 44 | self.len = X_train.shape[0] 45 | 46 | def __getitem__(self, index): 47 | if self.transform is not None: 48 | output = self.transform(self.x_data[index].view(self.num_channels, -1, 1)) 49 | self.x_data[index] = output.view(self.x_data[index].shape) 50 | 51 | return self.x_data[index].float(), self.y_data[index].long() 52 | 53 | def __len__(self): 54 | return self.len 55 | 56 | 57 | def data_generator(data_path, domain_id, dataset_configs, hparams): 58 | # loading path 59 | train_dataset = torch.load(os.path.join(data_path, "train_" + domain_id + ".pt")) 60 | test_dataset = torch.load(os.path.join(data_path, "test_" + domain_id + ".pt")) 61 | 62 | # Loading datasets 63 | train_dataset = Load_Dataset(train_dataset, dataset_configs.normalize) 64 | test_dataset = Load_Dataset(test_dataset, dataset_configs.normalize) 65 | 66 | # Dataloaders 67 | batch_size = hparams["batch_size"] 68 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, 69 | shuffle=True, drop_last=True, num_workers=0) 70 | 71 | test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, 72 | shuffle=False, drop_last=dataset_configs.drop_last, num_workers=0) 73 | return train_loader, test_loader 74 | 75 | 76 | def few_shot_data_generator(data_loader): 77 | x_data = data_loader.dataset.x_data 78 | y_data = data_loader.dataset.y_data 79 | if not isinstance(y_data, (np.ndarray)): 80 | y_data = y_data.numpy() 81 | 82 | NUM_SAMPLES_PER_CLASS = 5 83 | NUM_CLASSES = len(np.unique(y_data)) 84 | 85 | samples_count_dict = {id: 0 for id in range(NUM_CLASSES)} 86 | 87 | # if the min number of samples in one class is less than NUM_SAMPLES_PER_CLASS 88 | y_list = 
y_data.tolist() 89 | counts = [y_list.count(i) for i in range(NUM_CLASSES)] 90 | 91 | for i in samples_count_dict: 92 | if counts[i] < NUM_SAMPLES_PER_CLASS: 93 | samples_count_dict[i] = counts[i] 94 | else: 95 | samples_count_dict[i] = NUM_SAMPLES_PER_CLASS 96 | 97 | # if min(counts) < NUM_SAMPLES_PER_CLASS: 98 | # NUM_SAMPLES_PER_CLASS = min(counts) 99 | 100 | samples_ids = {} 101 | for i in range(NUM_CLASSES): 102 | samples_ids[i] = [np.where(y_data == i)[0]][0] 103 | 104 | selected_ids = {} 105 | for i in range(NUM_CLASSES): 106 | selected_ids[i] = random.sample(list(samples_ids[i]), samples_count_dict[i]) 107 | 108 | # select the samples according to the selected random ids 109 | y = torch.from_numpy(y_data) 110 | selected_x = x_data[list(selected_ids[0])] 111 | selected_y = y[list(selected_ids[0])] 112 | 113 | for i in range(1, NUM_CLASSES): 114 | selected_x = torch.cat((selected_x, x_data[list(selected_ids[i])]), dim=0) 115 | selected_y = torch.cat((selected_y, y[list(selected_ids[i])]), dim=0) 116 | 117 | few_shot_dataset = {"samples": selected_x, "labels": selected_y} 118 | # Loading datasets 119 | few_shot_dataset = Load_Dataset(few_shot_dataset, None) 120 | 121 | # Dataloaders 122 | few_shot_loader = torch.utils.data.DataLoader(dataset=few_shot_dataset, batch_size=len(few_shot_dataset), 123 | shuffle=False, drop_last=False, num_workers=0) 124 | return few_shot_loader 125 | 126 | 127 | def generator_percentage_of_data(data_loader): 128 | x_data = data_loader.dataset.x_data 129 | y_data = data_loader.dataset.y_data 130 | 131 | X_train, X_val, y_train, y_val = train_test_split(x_data, y_data, test_size=0.1, random_state=0) 132 | 133 | few_shot_dataset = {"samples": X_val, "labels": y_val} 134 | # Loading datasets 135 | few_shot_dataset = Load_Dataset(few_shot_dataset, None) 136 | 137 | few_shot_loader = torch.utils.data.DataLoader(dataset=few_shot_dataset, batch_size=32, 138 | shuffle=True, drop_last=True, num_workers=0) 139 | return few_shot_loader 140 | -------------------------------------------------------------------------------- /configs/data_model_configs.py: -------------------------------------------------------------------------------- 1 | def get_dataset_class(dataset_name): 2 | """Return the algorithm class with the given name.""" 3 | if dataset_name not in globals(): 4 | raise NotImplementedError("Dataset not found: {}".format(dataset_name)) 5 | return globals()[dataset_name] 6 | 7 | feature_dim = 16 #[16, 32,64] 8 | 9 | class HAR(): 10 | def __init__(self): 11 | super(HAR, self) 12 | self.scenarios = [("2", "11"), ("7", "13"), ("12", "16"), ("9", "18"), ("6", "23")] 13 | # self.scenarios = [("18", "27"), ("20", "5"), ("24", "8"), ("28", "27"), ("30", "20")] # additional 14 | self.class_names = ['walk', 'upstairs', 'downstairs', 'sit', 'stand', 'lie'] 15 | self.sequence_len = 128 16 | self.shuffle = True 17 | self.drop_last = True 18 | self.normalize = True 19 | 20 | # model configs 21 | self.input_channels = 9 22 | self.kernel_size = 5 23 | self.stride = 1 24 | self.dropout = 0.5 25 | self.num_classes = 6 26 | 27 | # CNN and RESNET features 28 | self.mid_channels = feature_dim #16 #32 #64 29 | self.final_out_channels =feature_dim * 2 #32 # 64 #128 30 | self.features_len = 1 31 | 32 | # Teacher model features 33 | self.mid_channels_t = 64 #64 34 | self.final_out_channels_t =128 #128 35 | self.features_len_t = 1 36 | 37 | # TCN features 38 | self.tcn_layers = [75, 150] 39 | self.tcn_final_out_channles = self.tcn_layers[-1] 40 | self.tcn_kernel_size = 17 41 | 
self.tcn_dropout = 0.0 42 | 43 | # lstm features 44 | self.lstm_hid = 128 45 | self.lstm_n_layers = 1 46 | self.lstm_bid = False 47 | 48 | # discriminator 49 | self.disc_hid_dim = 64 50 | self.hidden_dim = 500 51 | self.DSKN_disc_hid = 128 52 | 53 | 54 | class EEG(): 55 | def __init__(self): 56 | super(EEG, self).__init__() 57 | # data parameters 58 | self.num_classes = 5 59 | self.class_names = ['W', 'N1', 'N2', 'N3', 'REM'] 60 | self.sequence_len = 3000 61 | self.scenarios = [("0", "11"), ("12", "5"), ("7", "18"), ("16", "1"), ("9", "14")] 62 | # self.scenarios = [("3", "19"), ("18", "12"), ("13", "17"), ("5", "15"), ("6", "2")] #additional 63 | self.shuffle = True 64 | self.drop_last = True 65 | self.normalize = True 66 | 67 | # model configs 68 | self.input_channels = 1 69 | self.kernel_size = 25 70 | self.stride = 6 71 | self.dropout = 0.2 72 | 73 | # features 74 | self.mid_channels = feature_dim 75 | self.final_out_channels = feature_dim*2 76 | self.features_len = 1 77 | 78 | # Teacher model features 79 | self.mid_channels_t = 64 80 | self.final_out_channels_t =128 81 | self.features_len_t = 1 82 | 83 | # TCN features 84 | self.tcn_layers = [32,64] 85 | self.tcn_final_out_channles = self.tcn_layers[-1] 86 | self.tcn_kernel_size = 15# 25 87 | self.tcn_dropout = 0.0 88 | 89 | # lstm features 90 | self.lstm_hid = 128 91 | self.lstm_n_layers = 1 92 | self.lstm_bid = False 93 | 94 | # discriminator 95 | self.DSKN_disc_hid = 128 96 | self.hidden_dim = 500 97 | self.disc_hid_dim = 100 98 | 99 | 100 | class WISDM(object): 101 | def __init__(self): 102 | super(WISDM, self).__init__() 103 | self.class_names = ['walk', 'jog', 'sit', 'stand', 'upstairs', 'downstairs'] 104 | self.sequence_len = 128 105 | self.scenarios = [("7", "18"), ("20", "30"), ("35", "31"), ("18", "23"), ("6", "19")] 106 | self.num_classes = 6 107 | self.shuffle = True 108 | self.drop_last = False 109 | self.normalize = True 110 | 111 | # model configs 112 | self.input_channels = 3 113 | self.kernel_size = 5 114 | self.stride = 1 115 | self.dropout = 0.5 116 | self.num_classes = 6 117 | 118 | # features 119 | self.mid_channels = feature_dim 120 | self.final_out_channels = feature_dim *2 121 | self.features_len = 1 122 | 123 | # Teacher model features 124 | self.mid_channels_t = 64 125 | self.final_out_channels_t =128 126 | self.features_len_t = 1 127 | 128 | # TCN features 129 | self.tcn_layers = [75,150] 130 | self.tcn_final_out_channles = self.tcn_layers[-1] 131 | self.tcn_kernel_size = 17 132 | self.tcn_dropout = 0.0 133 | 134 | # lstm features 135 | self.lstm_hid = 128 136 | self.lstm_n_layers = 1 137 | self.lstm_bid = False 138 | 139 | # discriminator 140 | self.disc_hid_dim = 64 141 | self.DSKN_disc_hid = 128 142 | self.hidden_dim = 500 143 | 144 | 145 | class HHAR_SA(object): ## HHAR dataset, SAMSUNG device. 
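# Configuration for HHAR recorded on Samsung devices; same layout as the other dataset classes:
# the student widths (mid_channels / final_out_channels) are derived from the module-level
# feature_dim (16 for students, 64 for a teacher per the README), while the *_t fields are the
# fixed teacher widths.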
146 | def __init__(self): 147 | super(HHAR_SA, self).__init__() 148 | self.sequence_len = 128 149 | self.scenarios = [("2", "7"), ("0", "6"), ("1", "6"), ("3", "8"), ("4", "5")] 150 | # self.scenarios = [("5", "0"), ("6", "1"), ("7", "4"), ("8", "3"), ("0", "2")] # additional 151 | self.class_names = ['bike', 'sit', 'stand', 'walk', 'stairs_up', 'stairs_down'] 152 | self.num_classes = 6 153 | self.shuffle = True 154 | self.drop_last = True 155 | self.normalize = True 156 | 157 | # model configs 158 | self.input_channels = 3 159 | self.kernel_size = 5 160 | self.stride = 1 161 | self.dropout = 0.5 162 | 163 | # features 164 | self.mid_channels =feature_dim 165 | self.final_out_channels =feature_dim *2 166 | self.features_len = 1 167 | 168 | # Teacher model features 169 | self.mid_channels_t = 64 170 | self.final_out_channels_t =128 171 | self.features_len_t = 1 172 | 173 | # TCN features 174 | self.tcn_layers = [75,150] 175 | self.tcn_final_out_channles = self.tcn_layers[-1] 176 | self.tcn_kernel_size = 17 177 | self.tcn_dropout = 0.0 178 | 179 | # lstm features 180 | self.lstm_hid = 128 181 | self.lstm_n_layers = 1 182 | self.lstm_bid = False 183 | 184 | # discriminator 185 | self.disc_hid_dim = 64 186 | self.DSKN_disc_hid = 128 187 | self.hidden_dim = 500 188 | 189 | 190 | class FD(object): 191 | def __init__(self): 192 | super(FD, self).__init__() 193 | self.sequence_len = 5120 194 | self.scenarios = [ ("0", "3"), ("0", "1"),("2", "1"), ("1", "2"),("2", "3")] 195 | # self.scenarios = [ ("1", "0"), ("1", "3"), ("3", "0"), ("3", "1"), ("3", "2")] #additional 196 | self.class_names = ['Healthy', 'D1', 'D2'] 197 | self.num_classes = 3 198 | self.shuffle = True 199 | self.drop_last = True 200 | self.normalize = True 201 | 202 | # Model configs 203 | self.input_channels = 1 204 | self.kernel_size = 32 205 | self.stride = 6 206 | self.dropout = 0.5 207 | 208 | # CNN and RESNET features 209 | self.mid_channels = feature_dim 210 | self.final_out_channels = feature_dim * 2 211 | self.features_len = 1 212 | 213 | # Teacher model features 214 | self.mid_channels_t = 64 215 | self.final_out_channels_t = 128 216 | self.features_len_t = 1 217 | 218 | # TCN features 219 | self.tcn_layers = [75, 150] 220 | self.tcn_final_out_channles = self.tcn_layers[-1] 221 | self.tcn_kernel_size = 17 222 | self.tcn_dropout = 0.0 223 | 224 | # lstm features 225 | self.lstm_hid = 128 226 | self.lstm_n_layers = 1 227 | self.lstm_bid = False 228 | 229 | # discriminator 230 | self.disc_hid_dim = 64 231 | self.DSKN_disc_hid = 128 232 | self.hidden_dim = 500 233 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn as nn 4 | 5 | import random 6 | import os 7 | import sys 8 | import logging 9 | import numpy as np 10 | import pandas as pd 11 | from shutil import copy 12 | from datetime import datetime 13 | 14 | from skorch import NeuralNetClassifier # for DIV Risk 15 | from sklearn.model_selection import train_test_split 16 | from sklearn.metrics import classification_report, accuracy_score 17 | 18 | 19 | class AverageMeter(object): 20 | """Computes and stores the average and current value""" 21 | 22 | def __init__(self): 23 | self.reset() 24 | 25 | def reset(self): 26 | self.val = 0 27 | self.avg = 0 28 | self.sum = 0 29 | self.count = 0 30 | 31 | def update(self, val, n=1): 32 | self.val = val 33 | self.sum += val * n 34 | 
self.count += n 35 | self.avg = self.sum / self.count 36 | 37 | 38 | def fix_randomness(SEED): 39 | random.seed(SEED) 40 | np.random.seed(SEED) 41 | torch.manual_seed(SEED) 42 | torch.cuda.manual_seed(SEED) 43 | torch.backends.cudnn.deterministic = True 44 | torch.backends.cudnn.benchmark = False 45 | 46 | 47 | def _logger(logger_name, level=logging.DEBUG): 48 | """ 49 | Method to return a custom logger with the given name and level 50 | """ 51 | logger = logging.getLogger(logger_name) 52 | logger.setLevel(level) 53 | format_string = "%(message)s" 54 | log_format = logging.Formatter(format_string) 55 | # Creating and adding the console handler 56 | console_handler = logging.StreamHandler(sys.stdout) 57 | console_handler.setFormatter(log_format) 58 | logger.addHandler(console_handler) 59 | # Creating and adding the file handler 60 | file_handler = logging.FileHandler(logger_name, mode='a') 61 | file_handler.setFormatter(log_format) 62 | logger.addHandler(file_handler) 63 | return logger 64 | 65 | 66 | def starting_logs(data_type, da_method, exp_log_dir, src_id, tgt_id, run_id): 67 | log_dir = os.path.join(exp_log_dir, src_id + "_to_" + tgt_id + "_run_" + str(run_id)) 68 | os.makedirs(log_dir, exist_ok=True) 69 | log_file_name = os.path.join(log_dir, f"logs_{datetime.now().strftime('%d_%m_%Y_%H_%M_%S')}.log") 70 | logger = _logger(log_file_name) 71 | logger.debug("=" * 45) 72 | logger.debug(f'Dataset: {data_type}') 73 | logger.debug(f'Method: {da_method}') 74 | logger.debug("=" * 45) 75 | logger.debug(f'Source: {src_id} ---> Target: {tgt_id}') 76 | logger.debug(f'Run ID: {run_id}') 77 | logger.debug("=" * 45) 78 | return logger, log_dir 79 | 80 | 81 | def save_checkpoint(home_path, algorithm, selected_scenarios, dataset_configs, log_dir, hparams): 82 | save_dict = { 83 | "x-domains": selected_scenarios, 84 | "configs": dataset_configs.__dict__, 85 | "hparams": dict(hparams), 86 | "model_dict": algorithm.state_dict(), 87 | "network_dict": algorithm.network.state_dict(), 88 | # "discriminator": algorithm.domain_classifier.state_dict() 89 | } 90 | # save classification report 91 | save_path = os.path.join(home_path, log_dir, "checkpoint.pt") 92 | 93 | torch.save(save_dict, save_path) 94 | 95 | 96 | def weights_init(m): 97 | classname = m.__class__.__name__ 98 | if classname.find('Conv') != -1: 99 | m.weight.data.normal_(0.0, 0.02) 100 | elif classname.find('BatchNorm') != -1: 101 | m.weight.data.normal_(1.0, 0.02) 102 | m.bias.data.fill_(0) 103 | elif classname.find('Linear') != -1: 104 | m.weight.data.normal_(0.0, 0.1) 105 | m.bias.data.fill_(0) 106 | 107 | 108 | def _calc_metrics(pred_labels, true_labels, log_dir, home_path, target_names): 109 | pred_labels = np.array(pred_labels).astype(int) 110 | true_labels = np.array(true_labels).astype(int) 111 | 112 | r = classification_report(true_labels, pred_labels, target_names=target_names, digits=6, output_dict=True) 113 | 114 | df = pd.DataFrame(r) 115 | accuracy = accuracy_score(true_labels, pred_labels) 116 | df["accuracy"] = accuracy 117 | df = df * 100 118 | 119 | # save classification report 120 | file_name = "classification_report.xlsx" 121 | report_Save_path = os.path.join(home_path, log_dir, file_name) 122 | df.to_excel(report_Save_path) 123 | 124 | return accuracy * 100, r["macro avg"]["f1-score"] * 100 125 | 126 | 127 | def copy_Files(destination): 128 | destination_dir = os.path.join(destination, "MODEL_BACKUP_FILES") 129 | os.makedirs(destination_dir, exist_ok=True) 130 | copy("main.py", os.path.join(destination_dir, "main.py")) 131 | 
copy("utils.py", os.path.join(destination_dir, "utils.py")) 132 | copy(f"trainer.py", os.path.join(destination_dir, f"trainer.py")) 133 | # copy(f"same_domain_trainer.py", os.path.join(destination_dir, f"same_domain_trainer.py")) 134 | copy("dataloader/dataloader.py", os.path.join(destination_dir, "dataloader.py")) 135 | copy(f"models/models.py", os.path.join(destination_dir, f"models.py")) 136 | copy(f"models/loss.py", os.path.join(destination_dir, f"loss.py")) 137 | copy("algorithms/algorithms.py", os.path.join(destination_dir, "algorithms.py")) 138 | copy(f"configs/data_model_configs.py", os.path.join(destination_dir, f"data_model_configs.py")) 139 | copy(f"configs/hparams.py", os.path.join(destination_dir, f"hparams.py")) 140 | copy(f"configs/sweep_params.py", os.path.join(destination_dir, f"sweep_params.py")) 141 | 142 | 143 | def get_iwcv_value(weight, error): 144 | N, d = weight.shape 145 | _N, _d = error.shape 146 | assert N == _N and d == _d, 'dimension mismatch!' 147 | weighted_error = weight * error 148 | return np.mean(weighted_error) 149 | 150 | 151 | def get_dev_value(weight, error): 152 | """ 153 | :param weight: shape [N, 1], the importance weight for N source samples in the validation set 154 | :param error: shape [N, 1], the error value for each source sample in the validation set 155 | (typically 0 for correct classification and 1 for wrong classification) 156 | """ 157 | N, d = weight.shape 158 | _N, _d = error.shape 159 | assert N == _N and d == _d, 'dimension mismatch!' 160 | weighted_error = weight * error 161 | cov = np.cov(np.concatenate((weighted_error, weight), axis=1), rowvar=False)[0][1] 162 | var_w = np.var(weight, ddof=1) 163 | eta = - cov / var_w 164 | return np.mean(weighted_error) + eta * np.mean(weight) - eta 165 | 166 | 167 | class simple_MLP(nn.Module): 168 | def __init__(self, inp_units, out_units=2): 169 | super(simple_MLP, self).__init__() 170 | 171 | self.dense0 = nn.Linear(inp_units, inp_units // 2) 172 | self.nonlin = nn.ReLU() 173 | self.output = nn.Linear(inp_units // 2, out_units) 174 | self.softmax = nn.Softmax(dim=-1) 175 | 176 | def forward(self, x, **kwargs): 177 | x = self.nonlin(self.dense0(x)) 178 | x = self.softmax(self.output(x)) 179 | return x 180 | 181 | 182 | def get_weight_gpu(source_feature, target_feature, validation_feature, configs, device): 183 | """ 184 | :param source_feature: shape [N_tr, d], features from training set 185 | :param target_feature: shape [N_te, d], features from test set 186 | :param validation_feature: shape [N_v, d], features from validation set 187 | :return: 188 | """ 189 | import copy 190 | N_s, d = source_feature.shape 191 | N_t, _d = target_feature.shape 192 | source_feature = copy.deepcopy(source_feature.detach().cpu()) # source_feature.clone() 193 | target_feature = copy.deepcopy(target_feature.detach().cpu()) # target_feature.clone() 194 | source_feature = source_feature.to(device) 195 | target_feature = target_feature.to(device) 196 | all_feature = torch.cat((source_feature, target_feature), dim=0) 197 | all_label = torch.from_numpy(np.asarray([1] * N_s + [0] * N_t, dtype=np.int32)).long() 198 | 199 | feature_for_train, feature_for_test, label_for_train, label_for_test = train_test_split(all_feature, all_label, 200 | train_size=0.8) 201 | learning_rates = [1e-1, 5e-2, 1e-2] 202 | val_acc = [] 203 | domain_classifiers = [] 204 | 205 | for lr in learning_rates: 206 | domain_classifier = NeuralNetClassifier( 207 | simple_MLP, 208 | module__inp_units=configs.final_out_channels * configs.features_len, 
209 | max_epochs=30, 210 | lr=lr, 211 | device=device, 212 | # Shuffle training data on each epoch 213 | iterator_train__shuffle=True, 214 | callbacks="disable" 215 | ) 216 | domain_classifier.fit(feature_for_train.float(), label_for_train.long()) 217 | output = domain_classifier.predict(feature_for_test) 218 | acc = np.mean((label_for_test.numpy() == output).astype(np.float32)) 219 | val_acc.append(acc) 220 | domain_classifiers.append(domain_classifier) 221 | 222 | index = val_acc.index(max(val_acc)) 223 | domain_classifier = domain_classifiers[index] 224 | 225 | domain_out = domain_classifier.predict_proba(validation_feature.to(device).float()) 226 | return domain_out[:, :1] / domain_out[:, 1:] * N_s * 1.0 / N_t 227 | 228 | 229 | def calc_dev_risk(target_model, src_train_dl, tgt_train_dl, src_valid_dl, configs, device): 230 | src_train_feats = target_model.feature_extractor(src_train_dl.dataset.x_data.float().to(device)) 231 | tgt_train_feats = target_model.feature_extractor(tgt_train_dl.dataset.x_data.float().to(device)) 232 | src_valid_feats = target_model.feature_extractor(src_valid_dl.dataset.x_data.float().to(device)) 233 | src_valid_pred = target_model.classifier(src_valid_feats) 234 | 235 | dev_weights = get_weight_gpu(src_train_feats.to(device), tgt_train_feats.to(device), 236 | src_valid_feats.to(device), configs, device) 237 | dev_error = F.cross_entropy(src_valid_pred, src_valid_dl.dataset.y_data.long().to(device), reduction='none') 238 | dev_risk = get_dev_value(dev_weights, dev_error.unsqueeze(1).detach().cpu().numpy()) 239 | # iwcv_risk = get_iwcv_value(dev_weights, dev_error.unsqueeze(1).detach().cpu().numpy()) 240 | return dev_risk 241 | 242 | 243 | def calculate_risk(target_model, risk_dataloader, device): 244 | if type(risk_dataloader) == tuple: 245 | x_data = torch.cat((risk_dataloader[0].dataset.x_data, risk_dataloader[1].dataset.x_data), axis=0) 246 | y_data = torch.cat((risk_dataloader[0].dataset.y_data, risk_dataloader[1].dataset.y_data), axis=0) 247 | else: 248 | x_data = risk_dataloader.dataset.x_data 249 | y_data = risk_dataloader.dataset.y_data 250 | 251 | feat = target_model.feature_extractor(x_data.float().to(device)) 252 | pred = target_model.classifier(feat) 253 | cls_loss = F.cross_entropy(pred, y_data.long().to(device)) 254 | return cls_loss.item() 255 | 256 | 257 | # For DIRT-T 258 | class EMA: 259 | def __init__(self, decay): 260 | self.decay = decay 261 | self.shadow = {} 262 | 263 | def register(self, model): 264 | for name, param in model.named_parameters(): 265 | if param.requires_grad: 266 | self.shadow[name] = param.data.clone() 267 | self.params = self.shadow.keys() 268 | 269 | def __call__(self, model): 270 | if self.decay > 0: 271 | for name, param in model.named_parameters(): 272 | if name in self.params and param.requires_grad: 273 | self.shadow[name] -= (1 - self.decay) * (self.shadow[name] - param.data) 274 | param.data = self.shadow[name] 275 | 276 | 277 | def jitter(x, device, sigma=0.03): 278 | # https://arxiv.org/pdf/1706.00527.pdf 279 | return x + torch.from_numpy(np.random.normal(loc=0., scale=sigma, size=x.shape)).float().to(device) 280 | -------------------------------------------------------------------------------- /configs/hparams.py: -------------------------------------------------------------------------------- 1 | 2 | ## The cuurent hyper-parameters values are not necessarily the best ones for a specific risk. 
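# Illustrative usage (a sketch; not executed in this file): the trainers look up a dataset's
# hyper-parameter class by name and read its two dicts, e.g.
#   hparams_class = get_hparams_class("HAR")                    # defined just below
#   train_params = hparams_class().train_params                 # epochs, batch size, lr, weight decay
#   alg_hparams = hparams_class().alg_hparams["RL_JointADKD"]   # per-method hyper-parameters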
3 | def get_hparams_class(dataset_name): 4 | """Return the algorithm class with the given name.""" 5 | if dataset_name not in globals(): 6 | raise NotImplementedError("Dataset not found: {}".format(dataset_name)) 7 | return globals()[dataset_name] 8 | 9 | 10 | class HAR(): 11 | def __init__(self): 12 | super(HAR, self).__init__() 13 | self.train_params = { 14 | 'num_epochs': 80, 15 | 'batch_size': 32, 16 | 'weight_decay': 1e-4, 17 | 'learning_rate': 1e-2 18 | } 19 | self.alg_hparams = { 20 | 'DANN_T': {'learning_rate': 1e-2, 'src_cls_loss_wt': 1, 'domain_loss_wt': 1}, 21 | 'JointADKD': {'learning_rate': 1e-2, 'temperature': 2, 'kd_loss_wt':1, 'n_classes':6}, 22 | 'RL_JointADKD': {'learning_rate': 1e-2, 'temperature': 2, 'kd_loss_wt': 10, 'dc_loss_wt':1, 'n_classes': 6,'episode':5,'ddqn_lr':0.0001}, 23 | 24 | 'DAKD': {'learning_rate': 1e-3, 'dis_learning_rate':1e-4, 'temperature': 4}, 25 | 'KDDA': {'learning_rate': 1e-2, 'temperature': 2, "kd_loss_wt":1}, 26 | 'DANN': {'learning_rate': 0.0005, 'src_cls_loss_wt': 1, 'domain_loss_wt': 1}, 27 | 'KDSTDA': {'learning_rate': 1e-2, 'temperature': 2}, 28 | 'MCD': {'learning_rate': 1e-2}, 29 | 'MLD': {'learning_rate': 1e-2, 'tgt_loss_wt':1}, 30 | 'REDA': {'learning_rate': 1e-2, 'temperature': 2.5}, 31 | 'MobileDA': {'learning_rate': 1e-2, 'temperature': 2}, 32 | 'AAD': {'learning_rate': 1e-2, 'temperature': 4, 'src_cls_loss_wt': 1, 'soft_loss_wt': 1,'errG': 0.1}, 33 | 34 | 'AdvCDKD': {'learning_rate': 1e-2, 'temperature': 4, 'src_cls_loss_wt': 1, 'domain_loss_wt': 1, 'errG': 0.1 }, 35 | 'AdvCDKDv2': {'learning_rate': 1e-2, 'temperature': 4, 'src_cls_loss_wt': 1, 'domain_loss_wt': 1, 'errG': 1}, 36 | 'CDKD': {'learning_rate': 1e-2, 'temperature': 4, 'src_cls_loss_wt': 1, 'domain_loss_wt':1, "soft_loss_wt":1}, 37 | 'Deep_Coral': {'learning_rate': 5e-3, 'src_cls_loss_wt': 8.67, 'coral_wt': 0.44}, 38 | 'DDC': {'learning_rate': 5e-3, 'src_cls_loss_wt': 6.24, 'domain_loss_wt': 6.36}, 39 | 'HoMM': {'learning_rate': 1e-3, 'src_cls_loss_wt': 2.15, 'domain_loss_wt': 9.13}, 40 | 'CoDATS': {'learning_rate': 1e-3, 'src_cls_loss_wt': 6.21, 'domain_loss_wt': 1.72}, 41 | 'DSAN': {'learning_rate': 5e-4, 'src_cls_loss_wt': 1.76, 'domain_loss_wt': 1.59}, 42 | 'AdvSKM': {'learning_rate': 5e-3, 'src_cls_loss_wt': 3.05, 'domain_loss_wt': 2.876}, 43 | 'MMDA': {'learning_rate': 1e-3, 'src_cls_loss_wt': 6.13, 'mmd_wt': 2.37, 'coral_wt': 8.63, 'cond_ent_wt': 7.16}, 44 | 'CDAN': {'learning_rate': 1e-2, 'src_cls_loss_wt': 5.19, 'domain_loss_wt': 2.91, 'cond_ent_wt': 1.73}, 45 | 'DIRT': {'learning_rate': 5e-4, 'src_cls_loss_wt': 7.00, 'domain_loss_wt': 4.51, 'cond_ent_wt': 0.79, 'vat_loss_wt': 9.31} 46 | } 47 | 48 | 49 | class EEG(): 50 | def __init__(self): 51 | super(EEG, self).__init__() 52 | self.train_params = { 53 | 'num_epochs': 80, 54 | 'batch_size': 128, 55 | 'weight_decay': 1e-4, 56 | 'learning_rate': 1e-2 57 | } 58 | self.alg_hparams = { 59 | 'DANN_T': {'learning_rate': 1e-2, 'src_cls_loss_wt': 1, 'domain_loss_wt': 1}, 60 | 'JointADKD': {'learning_rate': 1e-2, 'temperature': 1, 'kd_loss_wt': 1,'n_classes':5}, #3 61 | 'RL_JointADKD': {'learning_rate': 1e-2, 'temperature': 1, 'kd_loss_wt': 1, 'dc_loss_wt':1,'n_classes': 5, 'episode':5,'ddqn_lr':0.0001}, 62 | 63 | 'DAKD': {'learning_rate': 1e-3, 'temperature': 1, 'dis_learning_rate':1e-4}, 64 | 'KDDA': {'learning_rate': 1e-2, 'temperature': 1, "kd_loss_wt":1}, 65 | 'DANN': {'learning_rate': 0.0005, 'src_cls_loss_wt': 1, 'domain_loss_wt': 1, }, 66 | 'KDSTDA': {'learning_rate': 1e-2, 'temperature': 2}, 67 | 'MCD': 
{'learning_rate': 1e-2}, 68 | 'MLD': {'learning_rate': 1e-2, 'tgt_loss_wt': 1}, 69 | 'REDA': {'learning_rate': 1e-2, 'temperature': 2.5}, 70 | 'MobileDA': {'learning_rate': 1e-2, 'temperature': 2}, 71 | 'AAD': {'learning_rate': 1e-2, 'temperature': 4, 'src_cls_loss_wt': 1, 'soft_loss_wt': 1, 'errG': 0.1}, 72 | 73 | 'AdvCDKD': {'learning_rate': 1e-2, 'temperature': 4, 'src_cls_loss_wt': 1, 'domain_loss_wt': 1, 'errG': 0.1 }, 74 | 'AdvCDKDv2': {'learning_rate': 1e-2, 'temperature': 4, 'src_cls_loss_wt': 1, 'domain_loss_wt': 0.01, 'errG': 0.1 }, 75 | 'CDKD': {'learning_rate': 1e-2, 'temperature': 4, 'src_cls_loss_wt': 1, 'domain_loss_wt': 0.1, "soft_loss_wt":0.1}, 76 | 'Deep_Coral': {'learning_rate': 0.0005, 'src_cls_loss_wt': 9.39, 'coral_wt': 0.19, }, 77 | 'DDC': {'learning_rate': 0.0005, 'src_cls_loss_wt': 2.951, 'domain_loss_wt': 8.923, }, 78 | 'HoMM': {'learning_rate': 0.0005, 'src_cls_loss_wt': 0.197, 'domain_loss_wt': 1.102, }, 79 | 'CoDATS': {'learning_rate': 0.01, 'src_cls_loss_wt': 9.239, 'domain_loss_wt': 1.342, }, 80 | 'DSAN': {'learning_rate': 0.001, 'src_cls_loss_wt': 6.713, 'domain_loss_wt': 6.708, }, 81 | 'AdvSKM': {'learning_rate': 0.0005, 'src_cls_loss_wt': 2.50, 'domain_loss_wt': 2.50, }, 82 | 'MMDA': {'learning_rate': 0.0005, 'src_cls_loss_wt': 4.48, 'mmd_wt': 5.951, 'coral_wt': 3.36, 'cond_ent_wt': 6.13, }, 83 | 'CDAN': {'learning_rate': 0.001, 'src_cls_loss_wt': 6.803, 'domain_loss_wt': 4.726, 'cond_ent_wt': 1.307, }, 84 | 'DIRT': {'learning_rate': 0.005, 'src_cls_loss_wt': 9.183, 'domain_loss_wt': 7.411, 'cond_ent_wt': 2.564, 'vat_loss_wt': 3.583, }, 85 | } 86 | 87 | 88 | 89 | class HHAR_SA(): 90 | def __init__(self): 91 | super(HHAR_SA, self).__init__() 92 | self.train_params = { 93 | 'num_epochs': 80, 94 | 'batch_size': 32, 95 | 'weight_decay': 1e-4, 96 | 'learning_rate': 1e-2 97 | } 98 | self.alg_hparams = { 99 | 'DANN_T': {'learning_rate': 0.0005, 'src_cls_loss_wt': 0.9603, 'domain_loss_wt':0.9238}, 100 | 'JointADKD': {'learning_rate': 1e-2, 'temperature': 2, 'kd_loss_wt': 1,'n_classes':6}, 101 | 'RL_JointADKD': {'learning_rate': 1e-2, 'temperature': 1, 'kd_loss_wt': 1,'dc_loss_wt':1, 'n_classes': 6,'episode':5,'ddqn_lr':0.0001}, 102 | 'DAKD': {'learning_rate': 0.0005, 'dis_learning_rate':3e-2, 'temperature': 1}, 103 | 'KDDA': {'learning_rate': 1e-2, 'temperature': 1, "kd_loss_wt":1}, 104 | 'DANN': {'learning_rate': 0.0005, 'src_cls_loss_wt': 1.0, 'domain_loss_wt': 1.0}, 105 | 'KDSTDA': {'learning_rate': 1e-2, 'temperature': 2}, 106 | 'MCD': {'learning_rate': 1e-2}, 107 | 'MLD': {'learning_rate': 1e-2, 'tgt_loss_wt': 1}, 108 | 'REDA': {'learning_rate': 1e-2, 'temperature': 2.5}, 109 | 'MobileDA': {'learning_rate': 1e-2, 'temperature': 2}, 110 | 'AAD': {'learning_rate': 1e-2, 'temperature': 4, 'src_cls_loss_wt': 1, 'soft_loss_wt': 1, 'errG': 0.1}, 111 | 112 | 'AdvCDKD': {'learning_rate': 1e-2, 'temperature': 4, 'src_cls_loss_wt': 1, 'domain_loss_wt': 1,"soft_loss_wt": 0.1, 'errG': 0.1}, 113 | 'AdvCDKDv2': {'learning_rate': 1e-2, 'temperature': 4, 'src_cls_loss_wt': 1, 'domain_loss_wt': 1,'errG': 1}, 114 | 'CDKD': {'learning_rate': 1e-2, 'temperature': 4, 'src_cls_loss_wt': 1, 'domain_loss_wt': 0.1, "soft_loss_wt":1}, 115 | 'Deep_Coral': {'learning_rate': 0.0005, 'src_cls_loss_wt': 0.05931, 'coral_wt': 8.452}, 116 | 'DDC': {'learning_rate': 0.01, 'src_cls_loss_wt': 0.1593, 'domain_loss_wt': 0.2048}, 117 | 'HoMM': {'learning_rate':0.001, 'src_cls_loss_wt': 0.2429, 'domain_loss_wt': 0.9824}, 118 | 'CoDATS': {'learning_rate': 0.0005, 'src_cls_loss_wt': 0.5416, 
'domain_loss_wt': 0.5582}, 119 | 'DSAN': {'learning_rate': 0.005, 'src_cls_loss_wt':0.4133, 'domain_loss_wt': 0.16}, 120 | 'AdvSKM': {'learning_rate': 0.001, 'src_cls_loss_wt': 0.4637, 'domain_loss_wt': 0.1511}, 121 | 'MMDA': {'learning_rate': 0.001, 'src_cls_loss_wt': 0.9505, 'mmd_wt': 0.5476, 'cond_ent_wt': 0.5167, 'coral_wt': 0.5838, }, 122 | 'CDAN': {'learning_rate': 0.001, 'src_cls_loss_wt': 0.6636, 'domain_loss_wt': 0.1954, 'cond_ent_wt':0.0124}, 123 | 'DIRT': {'learning_rate': 0.001, 'src_cls_loss_wt': 0.9752, 'domain_loss_wt': 0.3892, 'cond_ent_wt': 0.09228, 'vat_loss_wt': 0.1947} 124 | } 125 | 126 | 127 | class FD(): 128 | def __init__(self): 129 | super(FD, self).__init__() 130 | self.train_params = { 131 | 'num_epochs': 80, 132 | 'batch_size': 32, 133 | 'weight_decay': 1e-4, 134 | 'learning_rate': 1e-2 135 | } 136 | self.alg_hparams = { 137 | 'DANN_T': {'learning_rate': 1e-2, 'src_cls_loss_wt': 1, 'domain_loss_wt': 1}, 138 | 'JointADKD': {'learning_rate': 1e-2, 'temperature': 2, 'kd_loss_wt': 1,'n_classes':3}, 139 | 'RL_JointADKD': {'learning_rate': 1e-2, 'temperature': 2, 'kd_loss_wt': 1, 'dc_loss_wt':10, 'n_classes': 3,'episode':5,'ddqn_lr':0.0001}, 140 | 141 | 'DAKD': {'learning_rate': 1e-2, 'temperature': 2, 'dis_learning_rate':1e-5}, 142 | 'KDDA': {'learning_rate': 1e-2, 'temperature': 1, "kd_loss_wt":1}, 143 | 'DANN': {'learning_rate': 0.0005, 'src_cls_loss_wt': 1, 'domain_loss_wt': 1}, 144 | 'KDSTDA': {'learning_rate': 1e-2, 'temperature': 2}, 145 | 'MCD': {'learning_rate': 1e-2}, 146 | 'MLD': {'learning_rate': 1e-2, 'tgt_loss_wt': 1}, 147 | 'REDA': {'learning_rate': 1e-2, 'temperature': 2.5}, 148 | 'MobileDA': {'learning_rate': 1e-2, 'temperature': 2}, 149 | 'AAD': {'learning_rate': 1e-2, 'temperature': 4, 'src_cls_loss_wt': 1, 'soft_loss_wt': 1, 'errG': 0.1}, 150 | 'AdvCDKD': {'learning_rate': 1e-2, 'temperature': 4, 'src_cls_loss_wt': 1, 'domain_loss_wt': 1,"soft_loss_wt": 0.1, 'errG': 0.1}, 151 | 'AdvCDKDv2': {'learning_rate': 1e-2, 'temperature': 4, 'src_cls_loss_wt': 1, 'domain_loss_wt': 1, 'errG': 1}, 152 | 'CDKD': {'learning_rate': 1e-2, 'temperature': 4, 'src_cls_loss_wt': 1, 'domain_loss_wt': 0.1,"soft_loss_wt": 1}, 153 | 'Deep_Coral': {'learning_rate': 0.0005, 'src_cls_loss_wt': 0.05931, 'coral_wt': 8.452}, 154 | 'DDC': {'learning_rate': 0.01, 'src_cls_loss_wt': 0.1593, 'domain_loss_wt': 0.2048}, 155 | 'HoMM': {'learning_rate':0.001, 'src_cls_loss_wt': 0.2429, 'domain_loss_wt': 0.9824}, 156 | 'CoDATS': {'learning_rate': 0.0005, 'src_cls_loss_wt': 0.5416, 'domain_loss_wt': 0.5582}, 157 | 'DSAN': {'learning_rate': 0.005, 'src_cls_loss_wt':0.4133, 'domain_loss_wt': 0.16}, 158 | 'AdvSKM': {'learning_rate': 0.001, 'src_cls_loss_wt': 0.4637, 'domain_loss_wt': 0.1511}, 159 | 'MMDA': {'learning_rate': 0.001, 'src_cls_loss_wt': 0.9505, 'mmd_wt': 0.5476, 'cond_ent_wt': 0.5167, 'coral_wt': 0.5838, }, 160 | 'CDAN': {'learning_rate': 0.001, 'src_cls_loss_wt': 0.5, 'domain_loss_wt': 0.1, 'cond_ent_wt':0.1}, 161 | 'DIRT': {'learning_rate': 0.001, 'src_cls_loss_wt': 1.0, 'domain_loss_wt': 0.5, 'cond_ent_wt': 0.1, 'vat_loss_wt': 0.1} 162 | } 163 | -------------------------------------------------------------------------------- /models/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class ConditionalEntropyLoss(torch.nn.Module): 8 | def __init__(self): 9 | super(ConditionalEntropyLoss, self).__init__() 10 | 11 | def 
forward(self, x): 12 | b = F.softmax(x, dim=1) * F.log_softmax(x, dim=1) 13 | b = b.sum(dim=1) 14 | return -1.0 * b.mean(dim=0) 15 | 16 | 17 | class VAT(nn.Module): 18 | def __init__(self, model, device): 19 | super(VAT, self).__init__() 20 | self.n_power = 1 21 | self.XI = 1e-6 22 | self.model = model 23 | self.epsilon = 3.5 24 | self.device = device 25 | 26 | def forward(self, X, logit): 27 | vat_loss = self.virtual_adversarial_loss(X, logit) 28 | return vat_loss 29 | 30 | def generate_virtual_adversarial_perturbation(self, x, logit): 31 | d = torch.randn_like(x, device=self.device) 32 | 33 | for _ in range(self.n_power): 34 | d = self.XI * self.get_normalized_vector(d).requires_grad_() 35 | logit_m = self.model(x + d) 36 | dist = self.kl_divergence_with_logit(logit, logit_m) 37 | grad = torch.autograd.grad(dist, [d])[0] 38 | d = grad.detach() 39 | 40 | return self.epsilon * self.get_normalized_vector(d) 41 | 42 | def kl_divergence_with_logit(self, q_logit, p_logit): 43 | q = F.softmax(q_logit, dim=1) 44 | qlogq = torch.mean(torch.sum(q * F.log_softmax(q_logit, dim=1), dim=1)) 45 | qlogp = torch.mean(torch.sum(q * F.log_softmax(p_logit, dim=1), dim=1)) 46 | return qlogq - qlogp 47 | 48 | def get_normalized_vector(self, d): 49 | return F.normalize(d.view(d.size(0), -1), p=2, dim=1).reshape(d.size()) 50 | 51 | def virtual_adversarial_loss(self, x, logit): 52 | r_vadv = self.generate_virtual_adversarial_perturbation(x, logit) 53 | logit_p = logit.detach() 54 | logit_m = self.model(x + r_vadv) 55 | loss = self.kl_divergence_with_logit(logit_p, logit_m) 56 | return loss 57 | 58 | 59 | class MMD_loss(nn.Module): 60 | def __init__(self, kernel_type='rbf', kernel_mul=2.0, kernel_num=5): 61 | super(MMD_loss, self).__init__() 62 | self.kernel_num = kernel_num 63 | self.kernel_mul = kernel_mul 64 | self.fix_sigma = None 65 | self.kernel_type = kernel_type 66 | 67 | def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None): 68 | n_samples = int(source.size()[0]) + int(target.size()[0]) 69 | total = torch.cat([source, target], dim=0) 70 | total0 = total.unsqueeze(0).expand( 71 | int(total.size(0)), int(total.size(0)), int(total.size(1))) 72 | total1 = total.unsqueeze(1).expand( 73 | int(total.size(0)), int(total.size(0)), int(total.size(1))) 74 | L2_distance = ((total0 - total1) ** 2).sum(2) 75 | if fix_sigma: 76 | bandwidth = fix_sigma 77 | else: 78 | bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples) 79 | bandwidth /= kernel_mul ** (kernel_num // 2) 80 | bandwidth_list = [bandwidth * (kernel_mul ** i) 81 | for i in range(kernel_num)] 82 | kernel_val = [torch.exp(-L2_distance / bandwidth_temp) 83 | for bandwidth_temp in bandwidth_list] 84 | return sum(kernel_val) 85 | 86 | def linear_mmd2(self, f_of_X, f_of_Y): 87 | loss = 0.0 88 | delta = f_of_X.float().mean(0) - f_of_Y.float().mean(0) 89 | loss = delta.dot(delta.T) 90 | return loss 91 | 92 | def forward(self, source, target): 93 | if self.kernel_type == 'linear': 94 | return self.linear_mmd2(source, target) 95 | elif self.kernel_type == 'rbf': 96 | batch_size = int(source.size()[0]) 97 | kernels = self.guassian_kernel( 98 | source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma) 99 | with torch.no_grad(): 100 | XX = torch.mean(kernels[:batch_size, :batch_size]) 101 | YY = torch.mean(kernels[batch_size:, batch_size:]) 102 | XY = torch.mean(kernels[:batch_size, batch_size:]) 103 | YX = torch.mean(kernels[batch_size:, :batch_size]) 104 | loss = torch.mean(XX + 
YY - XY - YX) 105 | torch.cuda.empty_cache() 106 | return loss 107 | 108 | 109 | class CORAL(nn.Module): 110 | def __init__(self): 111 | super(CORAL, self).__init__() 112 | 113 | def forward(self, source, target): 114 | d = source.size(1) 115 | 116 | # source covariance 117 | xm = torch.mean(source, 0, keepdim=True) - source 118 | xc = xm.t() @ xm 119 | 120 | # target covariance 121 | xmt = torch.mean(target, 0, keepdim=True) - target 122 | xct = xmt.t() @ xmt 123 | 124 | # frobenius norm between source and target 125 | loss = torch.mean(torch.mul((xc - xct), (xc - xct))) 126 | loss = loss / (4 * d * d) 127 | return loss 128 | 129 | 130 | ### FOR DCAN ####################### 131 | def EntropyLoss(input_): 132 | mask = input_.ge(0.0000001) 133 | mask_out = torch.masked_select(input_, mask) 134 | entropy = - (torch.sum(mask_out * torch.log(mask_out))) 135 | return entropy / float(input_.size(0)) 136 | 137 | 138 | def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None): 139 | n_samples = int(source.size()[0]) + int(target.size()[0]) 140 | total = torch.cat([source, target], dim=0) 141 | total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1))) 142 | total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1))) 143 | L2_distance = ((total0 - total1) ** 2).sum(2) 144 | if fix_sigma: 145 | bandwidth = fix_sigma 146 | else: 147 | bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples) 148 | bandwidth /= kernel_mul ** (kernel_num // 2) 149 | bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)] 150 | kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list] 151 | return sum(kernel_val) # /len(kernel_val) 152 | 153 | 154 | def MMD(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None): 155 | batch_size = int(source.size()[0]) 156 | kernels = guassian_kernel(source, target, 157 | kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma) 158 | loss = 0 159 | for i in range(batch_size): 160 | s1, s2 = i, (i + 1) % batch_size 161 | t1, t2 = s1 + batch_size, s2 + batch_size 162 | loss += kernels[s1, s2] + kernels[t1, t2] 163 | loss -= kernels[s1, t2] + kernels[s2, t1] 164 | return loss / float(batch_size) 165 | 166 | 167 | def MMD_reg(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None): 168 | batch_size_source = int(source.size()[0]) 169 | batch_size_target = int(target.size()[0]) 170 | kernels = guassian_kernel(source, target, 171 | kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma) 172 | loss = 0 173 | for i in range(batch_size_source): 174 | s1, s2 = i, (i + 1) % batch_size_source 175 | t1, t2 = s1 + batch_size_target, s2 + batch_size_target 176 | loss += kernels[s1, s2] + kernels[t1, t2] 177 | loss -= kernels[s1, t2] + kernels[s2, t1] 178 | return loss / float(batch_size_source + batch_size_target) 179 | 180 | 181 | ### FOR HoMM ####################### 182 | class HoMM_loss(nn.Module): 183 | def __init__(self): 184 | super(HoMM_loss, self).__init__() 185 | 186 | def forward(self, xs, xt): 187 | xs = xs - torch.mean(xs, axis=0) 188 | xt = xt - torch.mean(xt, axis=0) 189 | xs = torch.unsqueeze(xs, axis=-1) 190 | xs = torch.unsqueeze(xs, axis=-1) 191 | xt = torch.unsqueeze(xt, axis=-1) 192 | xt = torch.unsqueeze(xt, axis=-1) 193 | xs_1 = xs.permute(0, 2, 1, 3) 194 | xs_2 = xs.permute(0, 2, 3, 1) 195 | xt_1 = xt.permute(0, 2, 1, 3) 196 | xt_2 = xt.permute(0, 2, 3, 1) 197 | HR_Xs = xs * xs_1 * xs_2 
# dim: b*L*L*L 198 | HR_Xs = torch.mean(HR_Xs, axis=0) # dim: L*L*L 199 | HR_Xt = xt * xt_1 * xt_2 200 | HR_Xt = torch.mean(HR_Xt, axis=0) 201 | return torch.mean((HR_Xs - HR_Xt) ** 2) 202 | 203 | 204 | ### FOR DSAN ####################### 205 | class LMMD_loss(nn.Module): 206 | def __init__(self, device, class_num=3, kernel_type='rbf', kernel_mul=2.0, kernel_num=5, fix_sigma=None): 207 | super(LMMD_loss, self).__init__() 208 | self.class_num = class_num 209 | self.kernel_num = kernel_num 210 | self.kernel_mul = kernel_mul 211 | self.fix_sigma = fix_sigma 212 | self.kernel_type = kernel_type 213 | self.device = device 214 | 215 | def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None): 216 | n_samples = int(source.size()[0]) + int(target.size()[0]) 217 | total = torch.cat([source, target], dim=0) 218 | total0 = total.unsqueeze(0).expand( 219 | int(total.size(0)), int(total.size(0)), int(total.size(1))) 220 | total1 = total.unsqueeze(1).expand( 221 | int(total.size(0)), int(total.size(0)), int(total.size(1))) 222 | L2_distance = ((total0 - total1) ** 2).sum(2) 223 | if fix_sigma: 224 | bandwidth = fix_sigma 225 | else: 226 | bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples) 227 | bandwidth /= kernel_mul ** (kernel_num // 2) 228 | bandwidth_list = [bandwidth * (kernel_mul ** i) 229 | for i in range(kernel_num)] 230 | kernel_val = [torch.exp(-L2_distance / bandwidth_temp) 231 | for bandwidth_temp in bandwidth_list] 232 | return sum(kernel_val) 233 | 234 | def get_loss(self, source, target, s_label, t_label): 235 | batch_size = source.size()[0] 236 | weight_ss, weight_tt, weight_st = self.cal_weight( 237 | s_label, t_label, batch_size=batch_size, class_num=self.class_num) 238 | weight_ss = torch.from_numpy(weight_ss).to(self.device) 239 | weight_tt = torch.from_numpy(weight_tt).to(self.device) 240 | weight_st = torch.from_numpy(weight_st).to(self.device) 241 | 242 | kernels = self.guassian_kernel(source, target, 243 | kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma) 244 | loss = torch.Tensor([0]).to(self.device) 245 | if torch.sum(torch.isnan(sum(kernels))): 246 | return loss 247 | SS = kernels[:batch_size, :batch_size] 248 | TT = kernels[batch_size:, batch_size:] 249 | ST = kernels[:batch_size, batch_size:] 250 | 251 | loss += torch.sum(weight_ss * SS + weight_tt * TT - 2 * weight_st * ST) 252 | return loss 253 | 254 | def convert_to_onehot(self, sca_label, class_num=31): 255 | return np.eye(class_num)[sca_label] 256 | 257 | def cal_weight(self, s_label, t_label, batch_size=32, class_num=4): 258 | batch_size = s_label.size()[0] 259 | s_sca_label = s_label.cpu().data.numpy() 260 | s_vec_label = self.convert_to_onehot(s_sca_label, class_num=self.class_num) 261 | s_sum = np.sum(s_vec_label, axis=0).reshape(1, class_num) 262 | s_sum[s_sum == 0] = 100 263 | s_vec_label = s_vec_label / s_sum 264 | 265 | t_sca_label = t_label.cpu().data.max(1)[1].numpy() 266 | t_vec_label = t_label.cpu().data.numpy() 267 | t_sum = np.sum(t_vec_label, axis=0).reshape(1, class_num) 268 | t_sum[t_sum == 0] = 100 269 | t_vec_label = t_vec_label / t_sum 270 | 271 | index = list(set(s_sca_label) & set(t_sca_label)) 272 | mask_arr = np.zeros((batch_size, class_num)) 273 | mask_arr[:, index] = 1 274 | t_vec_label = t_vec_label * mask_arr 275 | s_vec_label = s_vec_label * mask_arr 276 | 277 | weight_ss = np.matmul(s_vec_label, s_vec_label.T) 278 | weight_tt = np.matmul(t_vec_label, t_vec_label.T) 279 | weight_st = np.matmul(s_vec_label, 
t_vec_label.T) 280 | 281 | length = len(index) 282 | if length != 0: 283 | weight_ss = weight_ss / length 284 | weight_tt = weight_tt / length 285 | weight_st = weight_st / length 286 | else: 287 | weight_ss = np.array([0]) 288 | weight_tt = np.array([0]) 289 | weight_st = np.array([0]) 290 | return weight_ss.astype('float32'), weight_tt.astype('float32'), weight_st.astype('float32') 291 | -------------------------------------------------------------------------------- /single_domain_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | import os 5 | import wandb 6 | import pandas as pd 7 | import numpy as np 8 | from dataloader.dataloader import data_generator, few_shot_data_generator 9 | from configs.data_model_configs import get_dataset_class 10 | from configs.hparams import get_hparams_class 11 | 12 | from configs.sweep_params import sweep_alg_hparams 13 | from utils import fix_randomness, copy_Files, starting_logs, save_checkpoint, _calc_metrics 14 | from utils import calc_dev_risk, calculate_risk 15 | import warnings 16 | 17 | import sklearn.exceptions 18 | 19 | warnings.filterwarnings("ignore", category=sklearn.exceptions.UndefinedMetricWarning) 20 | 21 | import collections 22 | from algorithms.algorithms import get_algorithm_class 23 | from models.models import get_backbone_class 24 | from utils import AverageMeter 25 | import argparse 26 | 27 | torch.backends.cudnn.benchmark = True # to fasten TCN 28 | 29 | class same_domain_Trainer(object): 30 | """ 31 | This class contain the main training functions for our pretrainer 32 | """ 33 | 34 | def __init__(self, args): 35 | self.da_method = args.da_method # Selected DA Method 36 | self.data_type = args.selected_dataset # Selected Dataset 37 | self.backbone = args.backbone 38 | self.device = torch.device(args.device) # device 39 | self.num_sweeps = args.num_sweeps 40 | 41 | # Exp Description 42 | self.run_description = args.da_method + '_' + args.backbone 43 | self.experiment_description = args.selected_dataset 44 | self.is_sweep = args.is_sweep 45 | 46 | # paths 47 | self.home_path = os.getcwd() 48 | self.save_dir = 'experiments_logs' 49 | self.data_path = os.path.join(args.data_path, self.data_type) 50 | self.create_save_dir() 51 | 52 | # Specify runs 53 | self.num_runs = args.num_runs 54 | 55 | # get dataset and base model configs 56 | self.dataset_configs, self.hparams_class = self.get_configs() 57 | 58 | # to fix dimension of features in classifier and discriminator networks. 
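        # NOTE: the two assignments that follow overwrite `final_out_channels` so that the
        # classifier input size matches the selected backbone: with "TCN" it becomes
        # `tcn_final_out_channles` and with "LSTM" it becomes `lstm_hid` from the dataset
        # config; any other backbone keeps the default value.
        # Illustrative effect (assuming `--backbone TCN` on the command line):
        #   self.dataset_configs.final_out_channels == self.dataset_configs.tcn_final_out_channles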
59 | self.dataset_configs.final_out_channels = self.dataset_configs.tcn_final_out_channles if args.backbone == "TCN" else self.dataset_configs.final_out_channels 60 | self.dataset_configs.final_out_channels = self.dataset_configs.lstm_hid if args.backbone == "LSTM" else self.dataset_configs.final_out_channels 61 | 62 | # Specify number of hparams 63 | self.default_hparams = {**self.hparams_class.train_params} 64 | 65 | def sweep(self): 66 | 67 | # sweep configurations 68 | sweep_runs_count = self.num_sweeps 69 | sweep_config = { 70 | 'method': "random", 71 | 'metric': {'name': f'src_risk', 'goal': 'minimize'}, 72 | 'name': self.da_method, 73 | 'parameters': {**sweep_alg_hparams[self.da_method]} 74 | } 75 | sweep_id = wandb.sweep(sweep_config, project='HHAR_SA', entity='iclr_rebuttal') 76 | wandb.agent(sweep_id, self.train, count=sweep_runs_count) # Training with sweep 77 | 78 | # resuming sweep 79 | # wandb.agent('8wkaibgr', self.train, count=25,project='HHAR_SA_Resnet', entity= 'iclr_rebuttal' ) 80 | 81 | def train(self): 82 | if self.is_sweep: 83 | wandb.init(config=self.default_hparams) 84 | run_name = f"sweep_{self.data_type}" 85 | else: 86 | run_name = f"{self.run_description}" 87 | wandb.init(config=self.default_hparams, mode="offline", name=run_name) 88 | 89 | self.hparams = wandb.config 90 | # Logging 91 | self.exp_log_dir = os.path.join(self.save_dir, self.experiment_description, run_name) 92 | os.makedirs(self.exp_log_dir, exist_ok=True) 93 | copy_Files(self.exp_log_dir) # save a copy of training files: 94 | 95 | scenarios = self.dataset_configs.scenarios # return the scenarios given a specific dataset. 96 | 97 | self.metrics = {'accuracy': [], 'f1_score': []} 98 | 99 | for i in scenarios: 100 | if self.da_method == 'Source_only': # training on source and testing on target 101 | src_id = i[0] 102 | trg_id = i[1] 103 | elif self.da_method == 'Target_only': # training on target and testing target 104 | src_id = i[1] 105 | trg_id = i[1] 106 | else: 107 | raise NotImplementedError("select the the base method") 108 | 109 | for run_id in range(self.num_runs): # specify number of consecutive runs 110 | # fixing random seed 111 | fix_randomness(run_id) 112 | 113 | # Logging 114 | self.logger, self.scenario_log_dir = starting_logs(self.data_type, self.da_method, self.exp_log_dir, 115 | src_id, trg_id, run_id) 116 | 117 | # Load data 118 | self.load_data(src_id, trg_id) 119 | 120 | # get algorithm 121 | algorithm_class = get_algorithm_class('Lower_Upper_bounds') 122 | 123 | backbone_fe = get_backbone_class(self.backbone) 124 | algorithm = algorithm_class(backbone_fe, self.dataset_configs, self.hparams, self.device) 125 | algorithm.to(self.device) 126 | 127 | # Average meters 128 | loss_avg_meters = collections.defaultdict(lambda: AverageMeter()) 129 | 130 | # training.. 
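                # Training loop sketch: each epoch zips the source and target train loaders so
                # one step draws one batch from each domain, but only the labelled source batch
                # is passed to `algorithm.update`, since this trainer only covers the
                # Source_only / Target_only baselines (for Target_only, `src_id` was set to the
                # target id above, so the "source" loader already serves target data).
                # Illustrative single step (the returned loss keys depend on the algorithm class):
                #   losses = algorithm.update(src_x, src_y)   # e.g. {'Src_cls_loss': 0.42}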
131 | for epoch in range(1, self.hparams["num_epochs"] + 1): 132 | joint_loaders = enumerate(zip(self.src_train_dl, self.trg_train_dl)) 133 | len_dataloader = min(len(self.src_train_dl), len(self.trg_train_dl)) 134 | algorithm.train() 135 | 136 | for step, ((src_x, src_y), (trg_x, _)) in joint_loaders: 137 | src_x, src_y, trg_x = src_x.float().to(self.device), src_y.long().to(self.device), \ 138 | trg_x.float().to(self.device) 139 | 140 | losses = algorithm.update(src_x, src_y) 141 | 142 | for key, val in losses.items(): 143 | loss_avg_meters[key].update(val, src_x.size(0)) 144 | 145 | # logging 146 | self.logger.debug(f'[Epoch : {epoch}/{self.hparams["num_epochs"]}]') 147 | for key, val in loss_avg_meters.items(): 148 | self.logger.debug(f'{key}\t: {val.avg:2.4f}') 149 | self.logger.debug(f'-------------------------------------') 150 | 151 | self.algorithm = algorithm 152 | save_checkpoint(self.home_path, self.algorithm, scenarios, self.dataset_configs, 153 | self.scenario_log_dir, self.hparams) 154 | 155 | self.evaluate() 156 | 157 | self.calc_results_per_run() 158 | 159 | # logging metrics 160 | self.calc_overall_results() 161 | average_metrics = {metric: np.mean(value) for (metric, value) in self.metrics.items()} 162 | wandb.log(average_metrics) 163 | wandb.log({'hparams': wandb.Table( 164 | dataframe=pd.DataFrame(dict(self.hparams).items(), columns=['parameter', 'value']), 165 | allow_mixed_types=True)}) 166 | 167 | def evaluate(self): 168 | feature_extractor = self.algorithm.feature_extractor.to(self.device) 169 | classifier = self.algorithm.classifier.to(self.device) 170 | 171 | feature_extractor.eval() 172 | classifier.eval() 173 | 174 | total_loss_ = [] 175 | 176 | self.trg_pred_labels = np.array([]) 177 | self.trg_true_labels = np.array([]) 178 | 179 | with torch.no_grad(): 180 | for data, labels in self.trg_test_dl: 181 | data = data.float().to(self.device) 182 | labels = labels.view((-1)).long().to(self.device) 183 | 184 | # forward pass 185 | features = feature_extractor(data) 186 | predictions = classifier(features) 187 | 188 | # compute loss 189 | 190 | loss = F.cross_entropy(predictions, labels) 191 | 192 | total_loss_.append(loss.item()) 193 | 194 | pred = predictions.detach().argmax(dim=1) # get the index of the max log-probability 195 | 196 | self.trg_pred_labels = np.append(self.trg_pred_labels, pred.cpu().numpy()) 197 | self.trg_true_labels = np.append(self.trg_true_labels, labels.data.cpu().numpy()) 198 | 199 | self.trg_loss = torch.tensor(total_loss_).mean() # average loss 200 | 201 | def get_configs(self): 202 | dataset_class = get_dataset_class(self.data_type) 203 | hparams_class = get_hparams_class(self.data_type) 204 | return dataset_class(), hparams_class() 205 | 206 | def load_data(self, src_id, trg_id): 207 | self.src_train_dl, self.src_test_dl = data_generator(self.data_path, src_id, self.dataset_configs, 208 | self.hparams) 209 | self.trg_train_dl, self.trg_test_dl = data_generator(self.data_path, trg_id, self.dataset_configs, 210 | self.hparams) 211 | self.few_shot_dl = few_shot_data_generator(self.trg_test_dl) 212 | 213 | def create_save_dir(self): 214 | if not os.path.exists(self.save_dir): 215 | os.mkdir(self.save_dir) 216 | 217 | def calc_results_per_run(self): 218 | ''' 219 | Calculates the acc, f1 and risk values for each cross-domain scenario 220 | ''' 221 | self.acc, self.f1 = _calc_metrics(self.trg_pred_labels, self.trg_true_labels, self.scenario_log_dir, 222 | self.home_path, 223 | self.dataset_configs.class_names) 224 | 225 | run_metrics = 
{'accuracy': self.acc, 'f1_score': self.f1} 226 | for (key, val) in run_metrics.items(): self.metrics[key].append(val) 227 | 228 | df = pd.DataFrame(columns=["acc", "f1"]) 229 | df.loc[0] = [self.acc, self.f1] 230 | scores_save_path = os.path.join(self.home_path, self.scenario_log_dir, "scores.xlsx") 231 | df.to_excel(scores_save_path, index=False) 232 | self.results_df = df 233 | 234 | def calc_overall_results(self): 235 | exp = self.exp_log_dir 236 | 237 | # for exp in experiments: 238 | results = pd.DataFrame( 239 | columns=["scenario", "acc", "f1"]) 240 | 241 | single_exp = os.listdir(exp) 242 | single_exp = [i for i in single_exp if "_to_" in i] 243 | 244 | src_ids = [single_exp[i].split("_")[0] for i in range(len(single_exp))] 245 | # num_runsuns = src_ids.count(src_ids[0]) 246 | num_runs = 3 247 | scenarios_ids = np.unique(["_".join(i.split("_")[:3]) for i in single_exp]) 248 | 249 | for scenario in single_exp: 250 | scenario_dir = os.path.join(exp, scenario) 251 | scores = pd.read_excel(os.path.join(scenario_dir, 'scores.xlsx')) 252 | results = results.append(scores) 253 | results.iloc[len(results) - 1, 0] = scenario 254 | 255 | # avg_results = results.groupby(np.arange(len(results)) // num_runs).mean() 256 | avg_results = results.groupby(np.arange(len(results)) // num_runs).agg(['mean', 'std']) 257 | 258 | avg_results.loc[len(avg_results)] = avg_results.mean() 259 | avg_results.insert(0, "scenario", list(scenarios_ids) + ['mean'], True) 260 | 261 | report_save_path_avg = os.path.join(exp, f"Average_results.xlsx") 262 | 263 | self.averages_results_df = avg_results 264 | avg_results.to_excel(report_save_path_avg) 265 | 266 | 267 | if __name__ == "__main__": 268 | 269 | parser = argparse.ArgumentParser() 270 | # ========= Select the DA methods ============ 271 | parser.add_argument('--da_method', default='Source_only', type=str, 272 | help='Source_only, Target_only') 273 | 274 | # ========= Select the DATASET ============== 275 | parser.add_argument('--selected_dataset', default='EEG', type=str, 276 | help='Dataset of choice: FD, EEG, HAR, HHAR_SA') 277 | 278 | # ========= Select the BACKBONE ============== 279 | parser.add_argument('--backbone', default='CNN_T', type=str, 280 | help='Backbone of choice: CNN - RESNET18 - RESNET18_REDUCED - TCN') 281 | 282 | # ========= Experiment settings =============== 283 | parser.add_argument('--data_path', default='./data', type=str, help='Path containing dataset') 284 | parser.add_argument('--num_runs', default=3, type=int, help='Number of consecutive run with different seeds') 285 | parser.add_argument('--device', default='cuda:0', type=str, help='cpu or cuda') 286 | parser.add_argument('--is_sweep', default=False, type=bool, help='singe run or sweep') 287 | parser.add_argument('--num_sweeps', default=20, type=str, help='Number of sweep runs') 288 | 289 | args = parser.parse_args() 290 | 291 | trainer = same_domain_Trainer(args) 292 | if args.is_sweep: 293 | trainer.sweep() 294 | else: 295 | trainer.train() 296 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | import os 5 | import wandb 6 | import pandas as pd 7 | import numpy as np 8 | from dataloader.dataloader import data_generator, few_shot_data_generator, generator_percentage_of_data 9 | from configs.data_model_configs import get_dataset_class 10 | from configs.hparams import get_hparams_class 11 | 12 | 
from configs.sweep_params import sweep_alg_hparams 13 | from utils import fix_randomness, copy_Files, starting_logs, save_checkpoint, _calc_metrics 14 | from utils import calc_dev_risk, calculate_risk 15 | import warnings 16 | 17 | import sklearn.exceptions 18 | warnings.filterwarnings("ignore", category=sklearn.exceptions.UndefinedMetricWarning) 19 | 20 | import collections 21 | from algorithms.algorithms import get_algorithm_class 22 | from models.models import get_backbone_class 23 | from utils import AverageMeter 24 | 25 | torch.backends.cudnn.benchmark = True # to fasten TCN 26 | 27 | class cross_domain_trainer(object): 28 | """ 29 | This class contain the main training functions for our AdAtime 30 | """ 31 | def __init__(self, args): 32 | self.da_method = args.da_method # Selected DA Method 33 | self.dataset = args.dataset # Selected Dataset 34 | self.backbone = args.backbone 35 | self.device = torch.device(args.device) # device 36 | self.num_sweeps = args.num_sweeps 37 | 38 | # Exp Description 39 | self.run_description = args.run_description 40 | self.experiment_description = args.experiment_description 41 | # sweep parameters 42 | self.is_sweep = args.is_sweep 43 | self.sweep_project_wandb = args.sweep_project_wandb 44 | self.wandb_entity = args.wandb_entity 45 | self.hp_search_strategy = args.hp_search_strategy 46 | self.metric_to_minimize = args.metric_to_minimize 47 | 48 | # paths 49 | self.home_path = os.getcwd() 50 | self.save_dir = args.save_dir 51 | self.data_path = os.path.join(args.data_path, self.dataset) 52 | self.create_save_dir() 53 | 54 | # Specify runs 55 | self.num_runs = args.num_runs 56 | 57 | # get dataset and base model configs 58 | self.dataset_configs, self.hparams_class = self.get_configs() 59 | 60 | # to fix dimension of features in classifier and discriminator networks. 61 | self.dataset_configs.final_out_channels = self.dataset_configs.tcn_final_out_channles if args.backbone == "TCN" else self.dataset_configs.final_out_channels 62 | 63 | # Specify number of hparams 64 | self.default_hparams = {**self.hparams_class.alg_hparams[self.da_method], 65 | **self.hparams_class.train_params} 66 | 67 | def sweep(self): 68 | # sweep configurations 69 | sweep_runs_count = self.num_sweeps 70 | sweep_config = { 71 | 'method': self.hp_search_strategy, 72 | 'metric': {'name': self.metric_to_minimize, 'goal': 'minimize'}, 73 | 'name': self.da_method, 74 | 'parameters': {**sweep_alg_hparams[self.da_method]} 75 | } 76 | sweep_id = wandb.sweep(sweep_config, project=self.sweep_project_wandb, entity=self.wandb_entity) 77 | 78 | wandb.agent(sweep_id, self.train, count=sweep_runs_count) # Training with sweep 79 | 80 | # resuming sweep 81 | # wandb.agent('8wkaibgr', self.train, count=25,project='HHAR_SA_Resnet', entity= 'iclr_rebuttal' ) 82 | 83 | def train(self): 84 | if self.is_sweep: 85 | wandb.init(config=self.default_hparams) 86 | run_name = f"sweep_{self.dataset}" 87 | else: 88 | run_name = f"{self.run_description}" 89 | wandb.init(config=self.default_hparams, mode="online", name=run_name) 90 | 91 | self.hparams = wandb.config 92 | # Logging 93 | self.exp_log_dir = os.path.join(self.save_dir, self.experiment_description, run_name) 94 | os.makedirs(self.exp_log_dir, exist_ok=True) 95 | copy_Files(self.exp_log_dir) # save a copy of training files: 96 | 97 | scenarios = self.dataset_configs.scenarios # return the scenarios given a specific dataset. 
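        # `scenarios` is the list of (source_id, target_id) pairs defined for the selected
        # dataset in configs/data_model_configs.py; the loop below trains and evaluates the
        # chosen DA method once per scenario and per random seed (self.num_runs).
        # Illustrative shape only (the actual ids are dataset-specific):
        #   scenarios = [("2", "11"), ("12", "5"), ...]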
98 | 99 | self.metrics = {'accuracy': [], 'f1_score': [], 'src_risk': [], 'few_shot_trg_risk': [], 100 | 'trg_risk': [], 'dev_risk': []} 101 | 102 | for i in scenarios: 103 | src_id = i[0] 104 | trg_id = i[1] 105 | 106 | for run_id in range(self.num_runs): # specify number of consecutive runs 107 | # fixing random seed 108 | fix_randomness(run_id) 109 | 110 | # Logging 111 | self.logger, self.scenario_log_dir = starting_logs(self.dataset, self.da_method, self.exp_log_dir, 112 | src_id, trg_id, run_id) 113 | 114 | # Load data 115 | self.load_data(src_id, trg_id) 116 | 117 | # get algorithm 118 | algorithm_class = get_algorithm_class(self.da_method) 119 | backbone_fe = get_backbone_class(self.backbone) 120 | 121 | algorithm = algorithm_class(backbone_fe, self.dataset_configs, self.hparams, self.device) 122 | algorithm.to(self.device) 123 | 124 | ######## Measure model complexity in terms of Flops and Parameters################# 125 | # from thop import profile 126 | # import torch 127 | # input = torch.randn(1, 9, 128).to(self.device) 128 | # flops, para = profile(algorithm.network, inputs=(input,)) 129 | # print("Model_t Flops ={}, Parameters={}".format(flops/1e6,para/1e6)) 130 | ######## End ################# 131 | 132 | # Average meters 133 | loss_avg_meters = collections.defaultdict(lambda: AverageMeter()) 134 | 135 | # training.. 136 | for epoch in range(1, self.hparams["num_epochs"] + 1): 137 | joint_loaders = enumerate(zip(self.src_train_dl, self.trg_train_dl)) 138 | len_dataloader = min(len(self.src_train_dl), len(self.trg_train_dl)) 139 | algorithm.train() 140 | 141 | for step, ((src_x, src_y), (trg_x, _)) in joint_loaders: 142 | src_x, src_y, trg_x = src_x.float().to(self.device), src_y.long().to(self.device), \ 143 | trg_x.float().to(self.device) 144 | 145 | if self.da_method == "DANN" or self.da_method == "DANN_T" or self.da_method == "KDSTDA" or self.da_method == "REDA"\ 146 | or self.da_method == "CoDATS" : 147 | losses = algorithm.update(src_x, src_y, trg_x, step, epoch, len_dataloader) 148 | else: 149 | losses = algorithm.update(src_x, src_y, trg_x) 150 | 151 | for key, val in losses.items(): 152 | loss_avg_meters[key].update(val, src_x.size(0)) 153 | 154 | # logging 155 | self.logger.debug(f'[Epoch : {epoch}/{self.hparams["num_epochs"]}]') 156 | for key, val in loss_avg_meters.items(): 157 | self.logger.debug(f'{key}\t: {val.avg:2.4f}') 158 | self.logger.debug(f'-------------------------------------') 159 | 160 | self.algorithm = algorithm 161 | save_checkpoint(self.home_path, self.algorithm, scenarios, self.dataset_configs, 162 | self.scenario_log_dir, self.hparams) 163 | if self.da_method == "DANN_T": 164 | # save discriminator 165 | save_path = os.path.join(self.home_path, self.scenario_log_dir, "discriminator_T.pt") 166 | torch.save(algorithm.domain_classifier.state_dict(), save_path) 167 | if self.da_method == "REDA": 168 | self.evaluate_reda() 169 | else: 170 | self.evaluate() 171 | self.calc_results_per_run() 172 | 173 | # logging metrics 174 | self.calc_overall_results() 175 | average_metrics = {metric: np.mean(value) for (metric, value) in self.metrics.items()} 176 | wandb.log(average_metrics) 177 | wandb.log({'hparams': wandb.Table( 178 | dataframe=pd.DataFrame(dict(self.hparams).items(), columns=['parameter', 'value']), 179 | allow_mixed_types=True)}) 180 | wandb.log({'avg_results': wandb.Table(dataframe=self.averages_results_df, allow_mixed_types=True)}) 181 | wandb.log({'std_results': wandb.Table(dataframe=self.std_results_df, allow_mixed_types=True)}) 182 | 183 
| def evaluate_reda(self): 184 | network = self.algorithm.network.to(self.device) 185 | network.eval() 186 | 187 | total_loss_ = [] 188 | 189 | self.trg_pred_labels = np.array([]) 190 | self.trg_true_labels = np.array([]) 191 | 192 | with torch.no_grad(): 193 | for data, labels in self.trg_test_dl: 194 | data = data.float().to(self.device) 195 | labels = labels.view((-1)).long().to(self.device) 196 | 197 | # forward pass 198 | _, _, predictions,_ = network(data) 199 | 200 | # compute loss 201 | loss = F.cross_entropy(predictions, labels) 202 | total_loss_.append(loss.item()) 203 | pred = predictions.detach().argmax(dim=1) # get the index of the max log-probability 204 | 205 | self.trg_pred_labels = np.append(self.trg_pred_labels, pred.cpu().numpy()) 206 | self.trg_true_labels = np.append(self.trg_true_labels, labels.data.cpu().numpy()) 207 | 208 | self.trg_loss = torch.tensor(total_loss_).mean() # average loss 209 | 210 | 211 | def evaluate(self): 212 | feature_extractor = self.algorithm.feature_extractor.to(self.device) 213 | classifier = self.algorithm.classifier.to(self.device) 214 | 215 | feature_extractor.eval() 216 | classifier.eval() 217 | 218 | total_loss_ = [] 219 | 220 | self.trg_pred_labels = np.array([]) 221 | self.trg_true_labels = np.array([]) 222 | 223 | with torch.no_grad(): 224 | for data, labels in self.trg_test_dl: 225 | data = data.float().to(self.device) 226 | labels = labels.view((-1)).long().to(self.device) 227 | 228 | # forward pass 229 | features = feature_extractor(data) 230 | predictions = classifier(features) 231 | 232 | # compute loss 233 | loss = F.cross_entropy(predictions, labels) 234 | total_loss_.append(loss.item()) 235 | pred = predictions.detach().argmax(dim=1) # get the index of the max log-probability 236 | 237 | self.trg_pred_labels = np.append(self.trg_pred_labels, pred.cpu().numpy()) 238 | self.trg_true_labels = np.append(self.trg_true_labels, labels.data.cpu().numpy()) 239 | 240 | self.trg_loss = torch.tensor(total_loss_).mean() # average loss 241 | 242 | def get_configs(self): 243 | dataset_class = get_dataset_class(self.dataset) 244 | hparams_class = get_hparams_class(self.dataset) 245 | return dataset_class(), hparams_class() 246 | 247 | def load_data(self, src_id, trg_id): 248 | self.src_train_dl, self.src_test_dl = data_generator(self.data_path, src_id, self.dataset_configs, 249 | self.hparams) 250 | self.trg_train_dl, self.trg_test_dl = data_generator(self.data_path, trg_id, self.dataset_configs, 251 | self.hparams) 252 | self.few_shot_dl = few_shot_data_generator(self.trg_test_dl) 253 | 254 | # self.src_train_dl = generator_percentage_of_data(self.src_train_dl_) 255 | # self.trg_train_dl = generator_percentage_of_data(self.trg_train_dl_) 256 | 257 | def create_save_dir(self): 258 | if not os.path.exists(self.save_dir): 259 | os.mkdir(self.save_dir) 260 | 261 | def calc_results_per_run(self): 262 | ''' 263 | Calculates the acc, f1 and risk values for each cross-domain scenario 264 | ''' 265 | 266 | self.acc, self.f1 = _calc_metrics(self.trg_pred_labels, self.trg_true_labels, self.scenario_log_dir, 267 | self.home_path, 268 | self.dataset_configs.class_names) 269 | if self.is_sweep: 270 | self.src_risk = calculate_risk(self.algorithm, self.src_test_dl, self.device) 271 | self.trg_risk = calculate_risk(self.algorithm, self.trg_test_dl, self.device) 272 | self.few_shot_trg_risk = calculate_risk(self.algorithm, self.few_shot_dl, self.device) 273 | self.dev_risk = calc_dev_risk(self.algorithm, self.src_train_dl, self.trg_train_dl, 
self.src_test_dl, 274 | self.dataset_configs, self.device) 275 | 276 | run_metrics = {'accuracy': self.acc, 277 | 'f1_score': self.f1, 278 | 'src_risk': self.src_risk, 279 | 'few_shot_trg_risk': self.few_shot_trg_risk, 280 | 'trg_risk': self.trg_risk, 281 | 'dev_risk': self.dev_risk} 282 | 283 | df = pd.DataFrame(columns=["acc", "f1", "src_risk", "few_shot_trg_risk", "trg_risk", "dev_risk"]) 284 | df.loc[0] = [self.acc, self.f1, self.src_risk, self.few_shot_trg_risk, self.trg_risk, 285 | self.dev_risk] 286 | else: 287 | run_metrics = {'accuracy': self.acc, 'f1_score': self.f1} 288 | df = pd.DataFrame(columns=["acc", "f1"]) 289 | df.loc[0] = [self.acc, self.f1] 290 | 291 | for (key, val) in run_metrics.items(): self.metrics[key].append(val) 292 | 293 | scores_save_path = os.path.join(self.home_path, self.scenario_log_dir, "scores.xlsx") 294 | df.to_excel(scores_save_path, index=False) 295 | self.results_df = df 296 | 297 | def calc_overall_results(self): 298 | exp = self.exp_log_dir 299 | 300 | # for exp in experiments: 301 | if self.is_sweep: 302 | results = pd.DataFrame( 303 | columns=["scenario", "acc", "f1", "src_risk", "few_shot_trg_risk", "trg_risk", "dev_risk"]) 304 | else: 305 | results = pd.DataFrame(columns=["scenario", "acc", "f1"]) 306 | 307 | scenarios_list = os.listdir(exp) 308 | scenarios_list = [i for i in scenarios_list if "_to_" in i] 309 | scenarios_list.sort() 310 | 311 | unique_scenarios_names = [f'{i}_to_{j}' for i, j in self.dataset_configs.scenarios] 312 | 313 | for scenario in scenarios_list: 314 | scenario_dir = os.path.join(exp, scenario) 315 | scores = pd.read_excel(os.path.join(scenario_dir, 'scores.xlsx')) 316 | scores.insert(0, 'scenario', '_'.join(scenario.split('_')[:-2])) 317 | results = pd.concat([results, scores]) 318 | 319 | avg_results = results.groupby('scenario').mean() 320 | std_results = results.groupby('scenario').std() 321 | 322 | avg_results.loc[len(avg_results)] = avg_results.mean() 323 | avg_results.insert(0, "scenario", list(unique_scenarios_names) + ['mean'], True) 324 | std_results.insert(0, "scenario", list(unique_scenarios_names), True) 325 | 326 | report_save_path_avg = os.path.join(exp, f"Average_results.xlsx") 327 | report_save_path_std = os.path.join(exp, f"std_results.xlsx") 328 | 329 | self.averages_results_df = avg_results 330 | self.std_results_df = std_results 331 | avg_results.to_excel(report_save_path_avg) 332 | std_results.to_excel(report_save_path_std) 333 | -------------------------------------------------------------------------------- /benchmark_Mobileda_and_AAD.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import warnings 4 | import sklearn.exceptions 5 | warnings.filterwarnings("ignore", category=sklearn.exceptions.UndefinedMetricWarning) 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | import os 11 | import wandb 12 | import pandas as pd 13 | import numpy as np 14 | from dataloader.dataloader import data_generator, few_shot_data_generator, generator_percentage_of_data 15 | from configs.data_model_configs import get_dataset_class 16 | from configs.hparams import get_hparams_class 17 | 18 | from configs.sweep_params import sweep_alg_hparams 19 | from utils import fix_randomness, copy_Files, starting_logs, save_checkpoint, _calc_metrics 20 | from utils import calc_dev_risk, calculate_risk 21 | import warnings 22 | 23 | import sklearn.exceptions 24 | warnings.filterwarnings("ignore", 
category=sklearn.exceptions.UndefinedMetricWarning) 25 | 26 | import collections 27 | from algorithms.algorithms import get_algorithm_class 28 | from models.models import get_backbone_class 29 | from utils import AverageMeter 30 | 31 | 32 | torch.backends.cudnn.benchmark = True # to fasten TCN 33 | 34 | class joint_uda_kd_trainer(object): 35 | """ 36 | This class contain the main training functions for our AdAtime 37 | """ 38 | def __init__(self, args): 39 | self.da_method = args.da_method # Selected DA Method 40 | self.dataset = args.dataset # Selected Dataset 41 | self.backbone = args.backbone 42 | self.device = torch.device(args.device) # device 43 | self.num_sweeps = args.num_sweeps 44 | 45 | # Exp Description 46 | self.run_description = args.run_description 47 | self.experiment_description = args.experiment_description 48 | # sweep parameters 49 | self.is_sweep = args.is_sweep 50 | self.sweep_project_wandb = args.sweep_project_wandb 51 | self.wandb_entity = args.wandb_entity 52 | self.hp_search_strategy = args.hp_search_strategy 53 | self.metric_to_minimize = args.metric_to_minimize 54 | 55 | # paths 56 | self.home_path = os.getcwd() 57 | self.save_dir = args.save_dir 58 | self.data_path = os.path.join(args.data_path, self.dataset) 59 | self.create_save_dir() 60 | 61 | # Specify runs 62 | self.num_runs = args.num_runs 63 | 64 | # get dataset and base model configs 65 | self.dataset_configs, self.hparams_class = self.get_configs() 66 | 67 | # to fix dimension of features in classifier and discriminator networks. 68 | self.dataset_configs.final_out_channels = self.dataset_configs.tcn_final_out_channles if args.backbone == "TCN" else self.dataset_configs.final_out_channels 69 | 70 | # Specify number of hparams 71 | self.default_hparams = {**self.hparams_class.alg_hparams[self.da_method], 72 | **self.hparams_class.train_params} 73 | 74 | def sweep(self): 75 | # sweep configurations 76 | sweep_runs_count = self.num_sweeps 77 | sweep_config = { 78 | 'method': self.hp_search_strategy, 79 | 'metric': {'name': self.metric_to_minimize, 'goal': 'minimize'}, 80 | 'name': self.da_method, 81 | 'parameters': {**sweep_alg_hparams[self.da_method]} 82 | } 83 | sweep_id = wandb.sweep(sweep_config, project=self.sweep_project_wandb, entity=self.wandb_entity) 84 | 85 | wandb.agent(sweep_id, self.train, count=sweep_runs_count) # Training with sweep 86 | 87 | # resuming sweep 88 | # wandb.agent('8wkaibgr', self.train, count=25,project='HHAR_SA_Resnet', entity= 'iclr_rebuttal' ) 89 | 90 | def train(self): 91 | if self.is_sweep: 92 | wandb.init(config=self.default_hparams) 93 | run_name = f"sweep_{self.dataset}" 94 | else: 95 | run_name = f"{self.run_description}" 96 | wandb.init(config=self.default_hparams, mode="online", name=run_name) 97 | 98 | self.hparams = wandb.config 99 | # Logging 100 | self.exp_log_dir = os.path.join(self.save_dir, self.experiment_description, run_name) 101 | os.makedirs(self.exp_log_dir, exist_ok=True) 102 | copy_Files(self.exp_log_dir) # save a copy of training files: 103 | 104 | scenarios = self.dataset_configs.scenarios # return the scenarios given a specific dataset. 
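        # Unlike trainer.py, this benchmark restores a pre-trained teacher before training
        # the student: for every scenario it loads '<src_id>_to_<trg_id>_checkpoint.pt'
        # from the 'DANN_CNN_teacher\Archive' folder under the save directory into
        # `algorithm.network_t`, and only then runs the per-epoch student updates.
        # The backslash makes that folder component Windows-specific; on Linux/macOS it
        # would have to be split into separate os.path.join arguments.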
105 | 106 | self.metrics = {'accuracy': [], 'f1_score': [], 'src_risk': [], 'few_shot_trg_risk': [], 107 | 'trg_risk': [], 'dev_risk': []} 108 | 109 | for i in scenarios: 110 | src_id = i[0] 111 | trg_id = i[1] 112 | 113 | for run_id in range(self.num_runs): # specify number of consecutive runs 114 | # fixing random seed 115 | fix_randomness(run_id) 116 | 117 | # Logging 118 | self.logger, self.scenario_log_dir = starting_logs(self.dataset, self.da_method, self.exp_log_dir, 119 | src_id, trg_id, run_id) 120 | 121 | # Load data 122 | self.load_data(src_id, trg_id) 123 | 124 | # get student algorithm 125 | algorithm_class = get_algorithm_class(self.da_method) 126 | backbone_fe = get_backbone_class(self.backbone) 127 | 128 | algorithm = algorithm_class(backbone_fe, self.dataset_configs, self.hparams, self.device) 129 | 130 | best_teacher = src_id+'_to_'+trg_id+'_checkpoint.pt' 131 | model_t_name = os.path.join(self.save_dir,self.dataset,'DANN_CNN_teacher\Archive',best_teacher) 132 | 133 | checkpoint = torch.load(model_t_name) 134 | algorithm.network_t.load_state_dict(checkpoint["network_dict"]) 135 | 136 | algorithm.to(self.device) 137 | 138 | ######## Measure model complexity in terms of Flops and Parameters################# 139 | # from thop import profile 140 | # import torch 141 | # input = torch.randn(1, 9, 128).to(self.device) 142 | # flops, para = profile(algorithm.network, inputs=(input,)) 143 | # print("Model_t Flops ={}, Parameters={}".format(flops/1e6,para/1e6)) 144 | ######## End ################# 145 | 146 | # Average meters 147 | loss_avg_meters = collections.defaultdict(lambda: AverageMeter()) 148 | 149 | # training.. 150 | for epoch in range(1, self.hparams["num_epochs"] + 1): 151 | joint_loaders = enumerate(zip(self.src_train_dl, self.trg_train_dl)) 152 | len_dataloader = min(len(self.src_train_dl), len(self.trg_train_dl)) 153 | algorithm.train() 154 | 155 | for step, ((src_x, src_y), (trg_x, _)) in joint_loaders: 156 | src_x, src_y, trg_x = src_x.float().to(self.device), src_y.long().to(self.device), \ 157 | trg_x.float().to(self.device) 158 | losses = algorithm.update(src_x, src_y, trg_x, step, epoch, len_dataloader) 159 | for key, val in losses.items(): 160 | loss_avg_meters[key].update(val, src_x.size(0)) 161 | 162 | # logging 163 | self.logger.debug(f'[Epoch : {epoch}/{self.hparams["num_epochs"]}]') 164 | for key, val in loss_avg_meters.items(): 165 | self.logger.debug(f'{key}\t: {val.avg:2.4f}') 166 | self.logger.debug(f'-------------------------------------') 167 | 168 | self.algorithm = algorithm 169 | save_checkpoint(self.home_path, self.algorithm, scenarios, self.dataset_configs, 170 | self.scenario_log_dir, self.hparams) 171 | 172 | self.evaluate() 173 | 174 | self.calc_results_per_run() 175 | 176 | # logging metrics 177 | self.calc_overall_results() 178 | average_metrics = {metric: np.mean(value) for (metric, value) in self.metrics.items()} 179 | wandb.log(average_metrics) 180 | wandb.log({'hparams': wandb.Table( 181 | dataframe=pd.DataFrame(dict(self.hparams).items(), columns=['parameter', 'value']), 182 | allow_mixed_types=True)}) 183 | wandb.log({'avg_results': wandb.Table(dataframe=self.averages_results_df, allow_mixed_types=True)}) 184 | wandb.log({'std_results': wandb.Table(dataframe=self.std_results_df, allow_mixed_types=True)}) 185 | 186 | def evaluate(self): 187 | feature_extractor = self.algorithm.feature_extractor.to(self.device) 188 | classifier = self.algorithm.classifier.to(self.device) 189 | 190 | feature_extractor.eval() 191 | classifier.eval() 192 
| 193 | total_loss_ = [] 194 | 195 | self.trg_pred_labels = np.array([]) 196 | self.trg_true_labels = np.array([]) 197 | 198 | with torch.no_grad(): 199 | for data, labels in self.trg_test_dl: 200 | data = data.float().to(self.device) 201 | labels = labels.view((-1)).long().to(self.device) 202 | 203 | # forward pass 204 | features = feature_extractor(data) 205 | predictions = classifier(features) 206 | 207 | # compute loss 208 | loss = F.cross_entropy(predictions, labels) 209 | total_loss_.append(loss.item()) 210 | pred = predictions.detach().argmax(dim=1) # get the index of the max log-probability 211 | 212 | self.trg_pred_labels = np.append(self.trg_pred_labels, pred.cpu().numpy()) 213 | self.trg_true_labels = np.append(self.trg_true_labels, labels.data.cpu().numpy()) 214 | 215 | self.trg_loss = torch.tensor(total_loss_).mean() # average loss 216 | 217 | def get_configs(self): 218 | dataset_class = get_dataset_class(self.dataset) 219 | hparams_class = get_hparams_class(self.dataset) 220 | return dataset_class(), hparams_class() 221 | 222 | def load_data(self, src_id, trg_id): 223 | self.src_train_dl, self.src_test_dl = data_generator(self.data_path, src_id, self.dataset_configs, 224 | self.hparams) 225 | self.trg_train_dl, self.trg_test_dl = data_generator(self.data_path, trg_id, self.dataset_configs, 226 | self.hparams) 227 | self.few_shot_dl = few_shot_data_generator(self.trg_test_dl) 228 | 229 | # self.src_train_dl = generator_percentage_of_data(self.src_train_dl_) 230 | # self.trg_train_dl = generator_percentage_of_data(self.trg_train_dl_) 231 | 232 | def create_save_dir(self): 233 | if not os.path.exists(self.save_dir): 234 | os.mkdir(self.save_dir) 235 | 236 | def calc_results_per_run(self): 237 | ''' 238 | Calculates the acc, f1 and risk values for each cross-domain scenario 239 | ''' 240 | 241 | self.acc, self.f1 = _calc_metrics(self.trg_pred_labels, self.trg_true_labels, self.scenario_log_dir, 242 | self.home_path, 243 | self.dataset_configs.class_names) 244 | if self.is_sweep: 245 | self.src_risk = calculate_risk(self.algorithm, self.src_test_dl, self.device) 246 | self.trg_risk = calculate_risk(self.algorithm, self.trg_test_dl, self.device) 247 | self.few_shot_trg_risk = calculate_risk(self.algorithm, self.few_shot_dl, self.device) 248 | self.dev_risk = calc_dev_risk(self.algorithm, self.src_train_dl, self.trg_train_dl, self.src_test_dl, 249 | self.dataset_configs, self.device) 250 | 251 | run_metrics = {'accuracy': self.acc, 252 | 'f1_score': self.f1, 253 | 'src_risk': self.src_risk, 254 | 'few_shot_trg_risk': self.few_shot_trg_risk, 255 | 'trg_risk': self.trg_risk, 256 | 'dev_risk': self.dev_risk} 257 | 258 | df = pd.DataFrame(columns=["acc", "f1", "src_risk", "few_shot_trg_risk", "trg_risk", "dev_risk"]) 259 | df.loc[0] = [self.acc, self.f1, self.src_risk, self.few_shot_trg_risk, self.trg_risk, 260 | self.dev_risk] 261 | else: 262 | run_metrics = {'accuracy': self.acc, 'f1_score': self.f1} 263 | df = pd.DataFrame(columns=["acc", "f1"]) 264 | df.loc[0] = [self.acc, self.f1] 265 | 266 | for (key, val) in run_metrics.items(): self.metrics[key].append(val) 267 | 268 | scores_save_path = os.path.join(self.home_path, self.scenario_log_dir, "scores.xlsx") 269 | df.to_excel(scores_save_path, index=False) 270 | self.results_df = df 271 | 272 | def calc_overall_results(self): 273 | exp = self.exp_log_dir 274 | 275 | # for exp in experiments: 276 | if self.is_sweep: 277 | results = pd.DataFrame( 278 | columns=["scenario", "acc", "f1", "src_risk", "few_shot_trg_risk", "trg_risk", 
"dev_risk"]) 279 | else: 280 | results = pd.DataFrame(columns=["scenario", "acc", "f1"]) 281 | 282 | scenarios_list = os.listdir(exp) 283 | scenarios_list = [i for i in scenarios_list if "_to_" in i] 284 | scenarios_list.sort() 285 | 286 | unique_scenarios_names = [f'{i}_to_{j}' for i, j in self.dataset_configs.scenarios] 287 | 288 | for scenario in scenarios_list: 289 | scenario_dir = os.path.join(exp, scenario) 290 | scores = pd.read_excel(os.path.join(scenario_dir, 'scores.xlsx')) 291 | scores.insert(0, 'scenario', '_'.join(scenario.split('_')[:-2])) 292 | results = pd.concat([results, scores]) 293 | 294 | avg_results = results.groupby('scenario').mean() 295 | std_results = results.groupby('scenario').std() 296 | 297 | avg_results.loc[len(avg_results)] = avg_results.mean() 298 | avg_results.insert(0, "scenario", list(unique_scenarios_names) + ['mean'], True) 299 | std_results.insert(0, "scenario", list(unique_scenarios_names), True) 300 | 301 | report_save_path_avg = os.path.join(exp, f"Average_results.xlsx") 302 | report_save_path_std = os.path.join(exp, f"std_results.xlsx") 303 | 304 | self.averages_results_df = avg_results 305 | self.std_results_df = std_results 306 | avg_results.to_excel(report_save_path_avg) 307 | std_results.to_excel(report_save_path_std) 308 | 309 | 310 | parser = argparse.ArgumentParser() 311 | 312 | # ======== Experiments Name ================ 313 | parser.add_argument('--save_dir', default='experiments_logs', type=str, help='Directory containing all experiments') 314 | parser.add_argument('--experiment_description', default='EEG', type=str, help='Name of your experiment (EEG, HAR, HHAR_SA, FD') 315 | parser.add_argument('--run_description', default='AAD', type=str, help='name of your runs, ') 316 | 317 | # ========= Select the DA methods ============ 318 | parser.add_argument('--da_method', default='AAD', type=str, help='MobileDA, AAD') 319 | 320 | # ========= Select the DATASET ============== 321 | parser.add_argument('--data_path', default=r'./data', type=str, help='Path containing dataset') 322 | parser.add_argument('--dataset', default='EEG', type=str, help='Dataset of choice: (FD - EEG - HAR - HHAR_SA)') 323 | 324 | # ========= Select the BACKBONE ============== 325 | parser.add_argument('--backbone', default='CNN', type=str, help='Backbone of choice: (CNN - RESNET18 - TCN - RESNET34 -RESNET1D_WANG)') 326 | 327 | # ========= Experiment settings =============== 328 | parser.add_argument('--num_runs', default=3, type=int, help='Number of consecutive run with different seeds') 329 | parser.add_argument('--device', default='cuda:0', type=str, help='cpu or cuda') 330 | 331 | # ======== sweep settings ===================== 332 | parser.add_argument('--is_sweep', default=False, type=bool, help='singe run or sweep') 333 | parser.add_argument('--num_sweeps', default=20, type=str, help='Number of sweep runs') 334 | 335 | # We run sweeps using wandb plateform, so next parameters are for wandb. 336 | parser.add_argument('--sweep_project_wandb', default='TEST_SOMETHING', type=str, help='Project name in Wandb') 337 | parser.add_argument('--wandb_entity', type=str, help='Entity name in Wandb (can be left blank if there is a default entity)') 338 | parser.add_argument('--hp_search_strategy', default="random", type=str, help='The way of selecting hyper-parameters (random-grid-bayes). 
in wandb see:https://docs.wandb.ai/guides/sweeps/configuration') 339 | parser.add_argument('--metric_to_minimize', default="src_risk", type=str, help='select one of: (src_risk - trg_risk - few_shot_trg_risk - dev_risk)') 340 | 341 | args = parser.parse_args() 342 | 343 | 344 | if __name__ == "__main__": 345 | trainer = joint_uda_kd_trainer(args) 346 | 347 | if args.is_sweep: 348 | trainer.sweep() 349 | else: 350 | trainer.train() 351 | -------------------------------------------------------------------------------- /benchmark_Multi_Level_Distillation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import warnings 4 | import sklearn.exceptions 5 | warnings.filterwarnings("ignore", category=sklearn.exceptions.UndefinedMetricWarning) 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | import os 11 | import wandb 12 | import pandas as pd 13 | import numpy as np 14 | from dataloader.dataloader import data_generator, few_shot_data_generator, generator_percentage_of_data 15 | from configs.data_model_configs import get_dataset_class 16 | from configs.hparams import get_hparams_class 17 | 18 | from configs.sweep_params import sweep_alg_hparams 19 | from utils import fix_randomness, copy_Files, starting_logs, save_checkpoint, _calc_metrics 20 | from utils import calc_dev_risk, calculate_risk 21 | import warnings 22 | 23 | import sklearn.exceptions 24 | warnings.filterwarnings("ignore", category=sklearn.exceptions.UndefinedMetricWarning) 25 | 26 | import collections 27 | from algorithms.algorithms import get_algorithm_class 28 | from models.models import get_backbone_class 29 | from utils import AverageMeter 30 | 31 | 32 | torch.backends.cudnn.benchmark = True # to fasten TCN 33 | 34 | class mld_kd_trainer(object): 35 | """ 36 | This class contain the main training functions for our AdAtime 37 | """ 38 | def __init__(self, args): 39 | self.da_method = args.da_method # Selected DA Method 40 | self.dataset = args.dataset # Selected Dataset 41 | self.backbone = args.backbone 42 | self.device = torch.device(args.device) # device 43 | self.num_sweeps = args.num_sweeps 44 | 45 | # Exp Description 46 | self.run_description = args.run_description 47 | self.experiment_description = args.experiment_description 48 | # sweep parameters 49 | self.is_sweep = args.is_sweep 50 | self.sweep_project_wandb = args.sweep_project_wandb 51 | self.wandb_entity = args.wandb_entity 52 | self.hp_search_strategy = args.hp_search_strategy 53 | self.metric_to_minimize = args.metric_to_minimize 54 | 55 | # paths 56 | self.home_path = os.getcwd() 57 | self.save_dir = args.save_dir 58 | self.data_path = os.path.join(args.data_path, self.dataset) 59 | self.create_save_dir() 60 | 61 | # Specify runs 62 | self.num_runs = args.num_runs 63 | 64 | # get dataset and base model configs 65 | self.dataset_configs, self.hparams_class = self.get_configs() 66 | 67 | # to fix dimension of features in classifier and discriminator networks. 
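        # As in the other trainers, the backbone-dependent override of `final_out_channels`
        # comes first; `default_hparams` then merges the method-specific entries from
        # `alg_hparams[self.da_method]` with the shared `train_params`.
        # Illustrative result only (keys live in configs/hparams.py; of these, just
        # 'num_epochs' is read directly by this script):
        #   self.default_hparams = {'num_epochs': 40, 'learning_rate': 1e-3, ...}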
68 | self.dataset_configs.final_out_channels = self.dataset_configs.tcn_final_out_channles if args.backbone == "TCN" else self.dataset_configs.final_out_channels 69 | 70 | # Specify number of hparams 71 | self.default_hparams = {**self.hparams_class.alg_hparams[self.da_method], 72 | **self.hparams_class.train_params} 73 | 74 | def sweep(self): 75 | # sweep configurations 76 | sweep_runs_count = self.num_sweeps 77 | sweep_config = { 78 | 'method': self.hp_search_strategy, 79 | 'metric': {'name': self.metric_to_minimize, 'goal': 'minimize'}, 80 | 'name': self.da_method, 81 | 'parameters': {**sweep_alg_hparams[self.da_method]} 82 | } 83 | sweep_id = wandb.sweep(sweep_config, project=self.sweep_project_wandb, entity=self.wandb_entity) 84 | 85 | wandb.agent(sweep_id, self.train, count=sweep_runs_count) # Training with sweep 86 | 87 | # resuming sweep 88 | # wandb.agent('8wkaibgr', self.train, count=25,project='HHAR_SA_Resnet', entity= 'iclr_rebuttal' ) 89 | 90 | def train(self): 91 | if self.is_sweep: 92 | wandb.init(config=self.default_hparams) 93 | run_name = f"sweep_{self.dataset}" 94 | else: 95 | run_name = f"{self.run_description}" 96 | wandb.init(config=self.default_hparams, mode="online", name=run_name) 97 | 98 | self.hparams = wandb.config 99 | # Logging 100 | self.exp_log_dir = os.path.join(self.save_dir, self.experiment_description, run_name) 101 | os.makedirs(self.exp_log_dir, exist_ok=True) 102 | copy_Files(self.exp_log_dir) # save a copy of training files: 103 | 104 | scenarios = self.dataset_configs.scenarios # return the scenarios given a specific dataset. 105 | 106 | self.metrics = {'accuracy': [], 'f1_score': [], 'src_risk': [], 'few_shot_trg_risk': [], 107 | 'trg_risk': [], 'dev_risk': []} 108 | 109 | for i in scenarios: 110 | src_id = i[0] 111 | trg_id = i[1] 112 | 113 | for run_id in range(self.num_runs): # specify number of consecutive runs 114 | # fixing random seed 115 | fix_randomness(run_id) 116 | 117 | # Logging 118 | self.logger, self.scenario_log_dir = starting_logs(self.dataset, self.da_method, self.exp_log_dir, 119 | src_id, trg_id, run_id) 120 | 121 | # Load data 122 | self.load_data(src_id, trg_id) 123 | 124 | # get student algorithm 125 | algorithm_class = get_algorithm_class(self.da_method) 126 | backbone_fe = get_backbone_class(self.backbone) 127 | 128 | algorithm = algorithm_class(backbone_fe, self.dataset_configs, self.hparams, self.device) 129 | 130 | # Load Pre-trained Teacher model 131 | best_teacher = src_id+'_to_'+trg_id+'_checkpoint.pt' 132 | model_t_name = os.path.join(self.save_dir,self.dataset,'DANN_CNN_teacher',best_teacher) 133 | checkpoint = torch.load(model_t_name) 134 | algorithm.network_t.load_state_dict(checkpoint["network_dict"]) 135 | 136 | algorithm.to(self.device) 137 | 138 | # Average meters 139 | loss_avg_meters = collections.defaultdict(lambda: AverageMeter()) 140 | 141 | # Pre-train student with src + tgt 142 | for epoch in range(1, self.hparams["num_epochs"] + 1): 143 | joint_loaders = enumerate(zip(self.src_train_dl, self.trg_train_dl)) 144 | len_dataloader = min(len(self.src_train_dl), len(self.trg_train_dl)) 145 | algorithm.network.train() 146 | 147 | for step, ((src_x, src_y), (trg_x, _)) in joint_loaders: 148 | src_x, src_y, trg_x = src_x.float().to(self.device), src_y.long().to(self.device), \ 149 | trg_x.float().to(self.device) 150 | 151 | losses = algorithm.update_s_both(src_x, src_y, trg_x) 152 | 153 | for key, val in losses.items(): 154 | loss_avg_meters[key].update(val, src_x.size(0)) 155 | 156 | # logging 157 | 
self.logger.debug(f'[Epoch : {epoch}/{self.hparams["num_epochs"]}]') 158 | for key, val in loss_avg_meters.items(): 159 | self.logger.debug(f'{key}\t: {val.avg:2.4f}') 160 | self.logger.debug(f'-------------------------------------') 161 | 162 | # Perform distillation 163 | for epoch in range(1, self.hparams["num_epochs"] + 1): 164 | joint_loaders = enumerate(zip(self.src_train_dl, self.trg_train_dl)) 165 | len_dataloader = min(len(self.src_train_dl), len(self.trg_train_dl)) 166 | algorithm.network.train() 167 | algorithm.network_t.eval() 168 | 169 | for step, ((src_x, src_y), (trg_x, _)) in joint_loaders: 170 | src_x, src_y, trg_x = src_x.float().to(self.device), src_y.long().to(self.device), \ 171 | trg_x.float().to(self.device) 172 | 173 | losses = algorithm.update_s_tgt(src_x, src_y, trg_x) 174 | 175 | for key, val in losses.items(): 176 | loss_avg_meters[key].update(val, src_x.size(0)) 177 | 178 | # logging 179 | self.logger.debug(f'[Epoch : {epoch}/{self.hparams["num_epochs"]}]') 180 | for key, val in loss_avg_meters.items(): 181 | self.logger.debug(f'{key}\t: {val.avg:2.4f}') 182 | self.logger.debug(f'-------------------------------------') 183 | 184 | 185 | self.algorithm = algorithm 186 | save_checkpoint(self.home_path, self.algorithm, scenarios, self.dataset_configs, 187 | self.scenario_log_dir, self.hparams) 188 | 189 | self.evaluate() 190 | self.calc_results_per_run() 191 | 192 | # logging metrics 193 | self.calc_overall_results() 194 | average_metrics = {metric: np.mean(value) for (metric, value) in self.metrics.items()} 195 | wandb.log(average_metrics) 196 | wandb.log({'hparams': wandb.Table( 197 | dataframe=pd.DataFrame(dict(self.hparams).items(), columns=['parameter', 'value']), 198 | allow_mixed_types=True)}) 199 | wandb.log({'avg_results': wandb.Table(dataframe=self.averages_results_df, allow_mixed_types=True)}) 200 | wandb.log({'std_results': wandb.Table(dataframe=self.std_results_df, allow_mixed_types=True)}) 201 | 202 | def evaluate(self): 203 | feature_extractor = self.algorithm.feature_extractor.to(self.device) 204 | classifier = self.algorithm.classifier.to(self.device) 205 | 206 | feature_extractor.eval() 207 | classifier.eval() 208 | 209 | total_loss_ = [] 210 | 211 | self.trg_pred_labels = np.array([]) 212 | self.trg_true_labels = np.array([]) 213 | 214 | with torch.no_grad(): 215 | for data, labels in self.trg_test_dl: 216 | data = data.float().to(self.device) 217 | labels = labels.view((-1)).long().to(self.device) 218 | 219 | # forward pass 220 | features = feature_extractor(data) 221 | predictions = classifier(features) 222 | 223 | # compute loss 224 | loss = F.cross_entropy(predictions, labels) 225 | total_loss_.append(loss.item()) 226 | pred = predictions.detach().argmax(dim=1) # get the index of the max log-probability 227 | 228 | self.trg_pred_labels = np.append(self.trg_pred_labels, pred.cpu().numpy()) 229 | self.trg_true_labels = np.append(self.trg_true_labels, labels.data.cpu().numpy()) 230 | 231 | self.trg_loss = torch.tensor(total_loss_).mean() # average loss 232 | 233 | def get_configs(self): 234 | dataset_class = get_dataset_class(self.dataset) 235 | hparams_class = get_hparams_class(self.dataset) 236 | return dataset_class(), hparams_class() 237 | 238 | def load_data(self, src_id, trg_id): 239 | self.src_train_dl, self.src_test_dl = data_generator(self.data_path, src_id, self.dataset_configs, 240 | self.hparams) 241 | self.trg_train_dl, self.trg_test_dl = data_generator(self.data_path, trg_id, self.dataset_configs, 242 | self.hparams) 243 | 
self.few_shot_dl = few_shot_data_generator(self.trg_test_dl) 244 | 245 | # self.src_train_dl = generator_percentage_of_data(self.src_train_dl_) 246 | # self.trg_train_dl = generator_percentage_of_data(self.trg_train_dl_) 247 | 248 | def create_save_dir(self): 249 | if not os.path.exists(self.save_dir): 250 | os.mkdir(self.save_dir) 251 | 252 | def calc_results_per_run(self): 253 | ''' 254 | Calculates the acc, f1 and risk values for each cross-domain scenario 255 | ''' 256 | 257 | self.acc, self.f1 = _calc_metrics(self.trg_pred_labels, self.trg_true_labels, self.scenario_log_dir, 258 | self.home_path, 259 | self.dataset_configs.class_names) 260 | if self.is_sweep: 261 | self.src_risk = calculate_risk(self.algorithm, self.src_test_dl, self.device) 262 | self.trg_risk = calculate_risk(self.algorithm, self.trg_test_dl, self.device) 263 | self.few_shot_trg_risk = calculate_risk(self.algorithm, self.few_shot_dl, self.device) 264 | self.dev_risk = calc_dev_risk(self.algorithm, self.src_train_dl, self.trg_train_dl, self.src_test_dl, 265 | self.dataset_configs, self.device) 266 | 267 | run_metrics = {'accuracy': self.acc, 268 | 'f1_score': self.f1, 269 | 'src_risk': self.src_risk, 270 | 'few_shot_trg_risk': self.few_shot_trg_risk, 271 | 'trg_risk': self.trg_risk, 272 | 'dev_risk': self.dev_risk} 273 | 274 | df = pd.DataFrame(columns=["acc", "f1", "src_risk", "few_shot_trg_risk", "trg_risk", "dev_risk"]) 275 | df.loc[0] = [self.acc, self.f1, self.src_risk, self.few_shot_trg_risk, self.trg_risk, 276 | self.dev_risk] 277 | else: 278 | run_metrics = {'accuracy': self.acc, 'f1_score': self.f1} 279 | df = pd.DataFrame(columns=["acc", "f1"]) 280 | df.loc[0] = [self.acc, self.f1] 281 | 282 | for (key, val) in run_metrics.items(): self.metrics[key].append(val) 283 | 284 | scores_save_path = os.path.join(self.home_path, self.scenario_log_dir, "scores.xlsx") 285 | df.to_excel(scores_save_path, index=False) 286 | self.results_df = df 287 | 288 | def calc_overall_results(self): 289 | exp = self.exp_log_dir 290 | 291 | # for exp in experiments: 292 | if self.is_sweep: 293 | results = pd.DataFrame( 294 | columns=["scenario", "acc", "f1", "src_risk", "few_shot_trg_risk", "trg_risk", "dev_risk"]) 295 | else: 296 | results = pd.DataFrame(columns=["scenario", "acc", "f1"]) 297 | 298 | scenarios_list = os.listdir(exp) 299 | scenarios_list = [i for i in scenarios_list if "_to_" in i] 300 | scenarios_list.sort() 301 | 302 | unique_scenarios_names = [f'{i}_to_{j}' for i, j in self.dataset_configs.scenarios] 303 | 304 | for scenario in scenarios_list: 305 | scenario_dir = os.path.join(exp, scenario) 306 | scores = pd.read_excel(os.path.join(scenario_dir, 'scores.xlsx')) 307 | scores.insert(0, 'scenario', '_'.join(scenario.split('_')[:-2])) 308 | results = pd.concat([results, scores]) 309 | 310 | avg_results = results.groupby('scenario').mean() 311 | std_results = results.groupby('scenario').std() 312 | 313 | avg_results.loc[len(avg_results)] = avg_results.mean() 314 | avg_results.insert(0, "scenario", list(unique_scenarios_names) + ['mean'], True) 315 | std_results.insert(0, "scenario", list(unique_scenarios_names), True) 316 | 317 | report_save_path_avg = os.path.join(exp, f"Average_results.xlsx") 318 | report_save_path_std = os.path.join(exp, f"std_results.xlsx") 319 | 320 | self.averages_results_df = avg_results 321 | self.std_results_df = std_results 322 | avg_results.to_excel(report_save_path_avg) 323 | std_results.to_excel(report_save_path_std) 324 | 325 | 326 | parser = argparse.ArgumentParser() 327 | 328 | # 
======== Experiments Name ================ 329 | parser.add_argument('--save_dir', default='experiments_logs_additional', type=str, help='Directory containing all experiments') 330 | parser.add_argument('--experiment_description', default='FD', type=str, help='Name of your experiment (HAR, FD, HHAR_SA, EEG)') 331 | parser.add_argument('--run_description', default='MLD', type=str, help='name of your runs') 332 | 333 | # ========= Select the DA methods ============ 334 | parser.add_argument('--da_method', default='MLD', type=str, help='MLD') 335 | 336 | # ========= Select the DATASET ============== 337 | parser.add_argument('--data_path', default=r'./data', type=str, help='Path containing dataset') 338 | parser.add_argument('--dataset', default='FD', type=str, help='Dataset of choice: (HAR - FD - HHAR_SA - EEG)') 339 | 340 | # ========= Select the BACKBONE ============== 341 | parser.add_argument('--backbone', default='CNN', type=str, help='Backbone of choice: (CNN - RESNET18 - TCN - RESNET34 - RESNET1D_WANG)') 342 | 343 | # ========= Experiment settings =============== 344 | parser.add_argument('--num_runs', default=3, type=int, help='Number of consecutive runs with different seeds') 345 | parser.add_argument('--device', default='cuda:0', type=str, help='cpu or cuda') 346 | 347 | # ======== sweep settings ===================== 348 | parser.add_argument('--is_sweep', default=False, type=bool, help='single run or sweep') 349 | parser.add_argument('--num_sweeps', default=20, type=int, help='Number of sweep runs') 350 | 351 | # We run sweeps using the wandb platform, so the next parameters are for wandb. 352 | parser.add_argument('--sweep_project_wandb', default='TEST_SOMETHING', type=str, help='Project name in Wandb') 353 | parser.add_argument('--wandb_entity', type=str, help='Entity name in Wandb (can be left blank if there is a default entity)') 354 | parser.add_argument('--hp_search_strategy', default="random", type=str, help='The way of selecting hyper-parameters (random-grid-bayes). 
in wandb see:https://docs.wandb.ai/guides/sweeps/configuration') 355 | parser.add_argument('--metric_to_minimize', default="src_risk", type=str, help='select one of: (src_risk - trg_risk - few_shot_trg_risk - dev_risk)') 356 | 357 | args = parser.parse_args() 358 | 359 | 360 | if __name__ == "__main__": 361 | trainer = mld_kd_trainer(args) 362 | 363 | if args.is_sweep: 364 | trainer.sweep() 365 | else: 366 | trainer.train() 367 | -------------------------------------------------------------------------------- /proposed_RCD_KD.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import warnings 4 | import sklearn.exceptions 5 | warnings.filterwarnings("ignore", category=sklearn.exceptions.UndefinedMetricWarning) 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | import time 10 | import os 11 | import wandb 12 | import pandas as pd 13 | import numpy as np 14 | from dataloader.dataloader import data_generator, few_shot_data_generator, generator_percentage_of_data 15 | from configs.data_model_configs import get_dataset_class 16 | from configs.hparams import get_hparams_class 17 | 18 | from configs.sweep_params import sweep_alg_hparams 19 | from utils import fix_randomness, copy_Files, starting_logs, save_checkpoint, _calc_metrics 20 | from utils import calc_dev_risk, calculate_risk 21 | import warnings 22 | 23 | import sklearn.exceptions 24 | warnings.filterwarnings("ignore", category=sklearn.exceptions.UndefinedMetricWarning) 25 | 26 | import collections 27 | from algorithms.algorithms import get_algorithm_class 28 | from models.models import get_backbone_class 29 | from utils import AverageMeter 30 | from torch.utils.data import Dataset 31 | 32 | torch.backends.cudnn.benchmark = True # to fasten TCN 33 | 34 | 35 | class MyDataset(Dataset): 36 | def __init__(self, x): 37 | super(MyDataset, self).__init__() 38 | self.len = x.shape[0] 39 | if torch.cuda.is_available(): 40 | device = 'cuda' 41 | else: 42 | device = 'cpu' 43 | self.x_data = torch.as_tensor(x, device=device, dtype=torch.float) 44 | 45 | def __getitem__(self, index): 46 | return self.x_data[index] 47 | 48 | def __len__(self): 49 | return self.len 50 | 51 | class adv_cross_domain_kd_trainer(object): 52 | """ 53 | This class contain the main training functions for our AdAtime 54 | """ 55 | def __init__(self, args): 56 | self.da_method = args.da_method # Selected DA Method 57 | self.dataset = args.dataset # Selected Dataset 58 | self.backbone = args.backbone 59 | self.device = torch.device(args.device) # device 60 | self.num_sweeps = args.num_sweeps 61 | 62 | # Exp Description 63 | self.run_description = args.run_description 64 | self.experiment_description = args.experiment_description 65 | # sweep parameters 66 | self.is_sweep = args.is_sweep 67 | self.sweep_project_wandb = args.sweep_project_wandb 68 | self.wandb_entity = args.wandb_entity 69 | self.hp_search_strategy = args.hp_search_strategy 70 | self.metric_to_minimize = args.metric_to_minimize 71 | 72 | # paths 73 | self.home_path = os.getcwd() 74 | self.save_dir = args.save_dir 75 | self.data_path = os.path.join(args.data_path, self.dataset) 76 | self.create_save_dir() 77 | 78 | # Specify runs 79 | self.num_runs = args.num_runs 80 | 81 | # get dataset and base model configs 82 | self.dataset_configs, self.hparams_class = self.get_configs() 83 | 84 | # to fix dimension of features in classifier and discriminator networks. 
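# The TCN backbone exposes its feature width through a separate config field
# (tcn_final_out_channles), so it is copied into final_out_channels here; the
# classifier and discriminator heads are then sized as
# features_len * final_out_channels. With illustrative numbers only, assuming
# features_len = 1 and final_out_channels = 128, the classifier head would be
# nn.Linear(128, num_classes).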
85 | self.dataset_configs.final_out_channels = self.dataset_configs.tcn_final_out_channles if args.backbone == "TCN" else self.dataset_configs.final_out_channels 86 | 87 | # Specify number of hparams 88 | self.default_hparams = {**self.hparams_class.alg_hparams[self.da_method], 89 | **self.hparams_class.train_params} 90 | 91 | def sweep(self): 92 | # sweep configurations 93 | sweep_runs_count = self.num_sweeps 94 | sweep_config = { 95 | 'method': self.hp_search_strategy, 96 | 'metric': {'name': self.metric_to_minimize, 'goal': 'minimize'}, 97 | 'name': self.da_method, 98 | 'parameters': {**sweep_alg_hparams[self.da_method]} 99 | } 100 | sweep_id = wandb.sweep(sweep_config, project=self.sweep_project_wandb, entity=self.wandb_entity) 101 | 102 | wandb.agent(sweep_id, self.train, count=sweep_runs_count) # Training with sweep 103 | 104 | # resuming sweep 105 | # wandb.agent('8wkaibgr', self.train, count=25,project='HHAR_SA_Resnet', entity= 'iclr_rebuttal' ) 106 | 107 | def train(self): 108 | if self.is_sweep: 109 | wandb.init(config=self.default_hparams) 110 | run_name = f"sweep_{self.dataset}" 111 | else: 112 | run_name = f"{self.run_description}" 113 | wandb.init(config=self.default_hparams, mode="online", name=run_name) 114 | 115 | self.hparams = wandb.config 116 | # Logging 117 | self.exp_log_dir = os.path.join(self.save_dir, self.experiment_description, run_name) 118 | os.makedirs(self.exp_log_dir, exist_ok=True) 119 | copy_Files(self.exp_log_dir) # save a copy of training files: 120 | 121 | scenarios = self.dataset_configs.scenarios # return the scenarios given a specific dataset. 122 | 123 | self.metrics = {'accuracy': [], 'f1_score': [], 'src_risk': [], 'few_shot_trg_risk': [], 124 | 'trg_risk': [], 'dev_risk': []} 125 | 126 | for i in scenarios: 127 | src_id = i[0] 128 | trg_id = i[1] 129 | 130 | for run_id in range(self.num_runs): # specify number of consecutive runs 131 | # fixing random seed 132 | fix_randomness(run_id) 133 | 134 | # Logging 135 | self.logger, self.scenario_log_dir = starting_logs(self.dataset, self.da_method, self.exp_log_dir, 136 | src_id, trg_id, run_id) 137 | 138 | # Load data 139 | self.load_data(src_id, trg_id) 140 | 141 | # get student algorithm 142 | algorithm_class = get_algorithm_class(self.da_method) 143 | backbone_fe = get_backbone_class(self.backbone) 144 | 145 | algorithm = algorithm_class(backbone_fe, self.dataset_configs, self.hparams, self.device) 146 | 147 | # Load Pre-trained Teacher model 148 | best_teacher = src_id+'_to_'+trg_id+'_checkpoint.pt' 149 | model_t_name = os.path.join(self.save_dir,self.dataset,'DANN_CNN_teacher',best_teacher) 150 | checkpoint = torch.load(model_t_name) 151 | algorithm.network_t.load_state_dict(checkpoint["network_dict"]) 152 | 153 | # Load pre-trained teacher discriminator 154 | discriminator_name = src_id+'_to_'+trg_id+'_discriminator_T.pt' 155 | discriminator_path = os.path.join(self.save_dir, self.dataset, 'DANN_CNN_teacher', discriminator_name) 156 | dis_checkpoint = torch.load(discriminator_path) 157 | algorithm.domain_classifier.load_state_dict(dis_checkpoint) 158 | 159 | algorithm.to(self.device) 160 | 161 | # Average meters 162 | loss_avg_meters = collections.defaultdict(lambda: AverageMeter()) 163 | 164 | # training.. 
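# loss_avg_meters maps each loss name to a running average over the epoch.
# AverageMeter comes from utils.py; a minimal sketch of the interface assumed
# by the loop below (update(val, n) accumulates, .avg is read for logging):
#
#   class AverageMeter:
#       def __init__(self):
#           self.sum, self.count, self.avg = 0.0, 0, 0.0
#       def update(self, val, n=1):
#           self.sum += val * n
#           self.count += n
#           self.avg = self.sum / self.count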
165 | global_step = 0 166 | for epoch in range(1, self.hparams["num_epochs"] + 1): 167 | 168 | # since = time.time() 169 | 170 | joint_loaders = enumerate(zip(self.src_train_dl, self.trg_train_dl)) 171 | len_dataloader = min(len(self.src_train_dl), len(self.trg_train_dl)) 172 | algorithm.train() 173 | 174 | for step, ((src_x, src_y), (trg_x, _)) in joint_loaders: 175 | src_x, src_y, trg_x = src_x.float().to(self.device), src_y.long().to(self.device), \ 176 | trg_x.float().to(self.device) 177 | losses = algorithm.update_mcdo_new_2(src_x, src_y, trg_x, global_step, step, epoch, len_dataloader) 178 | # losses = algorithm.update_trans(src_x, src_y, trg_x) 179 | # losses = algorithm.update_mcdo_trans(src_x, src_y, trg_x) 180 | for key, val in losses.items(): 181 | loss_avg_meters[key].update(val, src_x.size(0)) 182 | 183 | global_step+=1 184 | 185 | # logging 186 | self.logger.debug(f'[Epoch : {epoch}/{self.hparams["num_epochs"]}]') 187 | for key, val in loss_avg_meters.items(): 188 | self.logger.debug(f'{key}\t: {val.avg:2.4f}') 189 | self.logger.debug(f'-------------------------------------') 190 | 191 | # elapsed_time = time.time() - since 192 | # print("Training completed in {:.4f}s".format(elapsed_time)) 193 | 194 | self.algorithm = algorithm 195 | save_checkpoint(self.home_path, self.algorithm, scenarios, self.dataset_configs, 196 | self.scenario_log_dir, self.hparams) 197 | 198 | self.evaluate() 199 | self.calc_results_per_run() 200 | 201 | # logging metrics 202 | self.calc_overall_results() 203 | average_metrics = {metric: np.mean(value) for (metric, value) in self.metrics.items()} 204 | wandb.log(average_metrics) 205 | wandb.log({'hparams': wandb.Table( 206 | dataframe=pd.DataFrame(dict(self.hparams).items(), columns=['parameter', 'value']), 207 | allow_mixed_types=True)}) 208 | wandb.log({'avg_results': wandb.Table(dataframe=self.averages_results_df, allow_mixed_types=True)}) 209 | wandb.log({'std_results': wandb.Table(dataframe=self.std_results_df, allow_mixed_types=True)}) 210 | 211 | def mcd_score(self, predictions, clusters): 212 | scores = 0 213 | num_class = clusters.shape[1] 214 | for i in range(num_class-1): 215 | for j in range(i+1, num_class): 216 | scores +=torch.abs(torch.nn.functional.cosine_similarity(predictions, clusters[i])- 217 | torch.nn.functional.cosine_similarity(predictions, clusters[j])) 218 | scores = scores /(num_class*(num_class-1)/2) 219 | return scores 220 | 221 | 222 | def evaluate(self): 223 | feature_extractor = self.algorithm.feature_extractor.to(self.device) 224 | classifier = self.algorithm.classifier.to(self.device) 225 | 226 | feature_extractor.eval() 227 | classifier.eval() 228 | 229 | total_loss_ = [] 230 | 231 | self.trg_pred_labels = np.array([]) 232 | self.trg_true_labels = np.array([]) 233 | 234 | with torch.no_grad(): 235 | for data, labels in self.trg_test_dl: 236 | data = data.float().to(self.device) 237 | labels = labels.view((-1)).long().to(self.device) 238 | 239 | # forward pass 240 | features = feature_extractor(data) 241 | predictions = classifier(features) 242 | 243 | # compute loss 244 | loss = F.cross_entropy(predictions, labels) 245 | total_loss_.append(loss.item()) 246 | pred = predictions.detach().argmax(dim=1) # get the index of the max log-probability 247 | 248 | self.trg_pred_labels = np.append(self.trg_pred_labels, pred.cpu().numpy()) 249 | self.trg_true_labels = np.append(self.trg_true_labels, labels.data.cpu().numpy()) 250 | 251 | self.trg_loss = torch.tensor(total_loss_).mean() # average loss 252 | 253 | 254 | def 
get_configs(self): 255 | dataset_class = get_dataset_class(self.dataset) 256 | hparams_class = get_hparams_class(self.dataset) 257 | return dataset_class(), hparams_class() 258 | 259 | def load_data(self, src_id, trg_id): 260 | self.src_train_dl, self.src_test_dl = data_generator(self.data_path, src_id, self.dataset_configs, 261 | self.hparams) 262 | self.trg_train_dl, self.trg_test_dl = data_generator(self.data_path, trg_id, self.dataset_configs, 263 | self.hparams) 264 | self.few_shot_dl = few_shot_data_generator(self.trg_test_dl) 265 | 266 | # self.src_train_dl = generator_percentage_of_data(self.src_train_dl_) 267 | # self.trg_train_dl = generator_percentage_of_data(self.trg_train_dl_) 268 | 269 | def create_save_dir(self): 270 | if not os.path.exists(self.save_dir): 271 | os.mkdir(self.save_dir) 272 | 273 | def calc_results_per_run(self): 274 | ''' 275 | Calculates the acc, f1 and risk values for each cross-domain scenario 276 | ''' 277 | 278 | self.acc, self.f1 = _calc_metrics(self.trg_pred_labels, self.trg_true_labels, self.scenario_log_dir, 279 | self.home_path, 280 | self.dataset_configs.class_names) 281 | if self.is_sweep: 282 | self.src_risk = calculate_risk(self.algorithm, self.src_test_dl, self.device) 283 | self.trg_risk = calculate_risk(self.algorithm, self.trg_test_dl, self.device) 284 | self.few_shot_trg_risk = calculate_risk(self.algorithm, self.few_shot_dl, self.device) 285 | self.dev_risk = calc_dev_risk(self.algorithm, self.src_train_dl, self.trg_train_dl, self.src_test_dl, 286 | self.dataset_configs, self.device) 287 | 288 | run_metrics = {'accuracy': self.acc, 289 | 'f1_score': self.f1, 290 | 'src_risk': self.src_risk, 291 | 'few_shot_trg_risk': self.few_shot_trg_risk, 292 | 'trg_risk': self.trg_risk, 293 | 'dev_risk': self.dev_risk} 294 | 295 | df = pd.DataFrame(columns=["acc", "f1", "src_risk", "few_shot_trg_risk", "trg_risk", "dev_risk"]) 296 | df.loc[0] = [self.acc, self.f1, self.src_risk, self.few_shot_trg_risk, self.trg_risk, 297 | self.dev_risk] 298 | else: 299 | run_metrics = {'accuracy': self.acc, 'f1_score': self.f1} 300 | df = pd.DataFrame(columns=["acc", "f1"]) 301 | df.loc[0] = [self.acc, self.f1] 302 | 303 | for (key, val) in run_metrics.items(): self.metrics[key].append(val) 304 | 305 | scores_save_path = os.path.join(self.home_path, self.scenario_log_dir, "scores.xlsx") 306 | df.to_excel(scores_save_path, index=False) 307 | self.results_df = df 308 | 309 | def calc_overall_results(self): 310 | exp = self.exp_log_dir 311 | 312 | # for exp in experiments: 313 | if self.is_sweep: 314 | results = pd.DataFrame( 315 | columns=["scenario", "acc", "f1", "src_risk", "few_shot_trg_risk", "trg_risk", "dev_risk"]) 316 | else: 317 | results = pd.DataFrame(columns=["scenario", "acc", "f1"]) 318 | 319 | scenarios_list = os.listdir(exp) 320 | scenarios_list = [i for i in scenarios_list if "_to_" in i] 321 | scenarios_list.sort() 322 | 323 | unique_scenarios_names = [f'{i}_to_{j}' for i, j in self.dataset_configs.scenarios] 324 | 325 | for scenario in scenarios_list: 326 | scenario_dir = os.path.join(exp, scenario) 327 | scores = pd.read_excel(os.path.join(scenario_dir, 'scores.xlsx')) 328 | scores.insert(0, 'scenario', '_'.join(scenario.split('_')[:-2])) 329 | results = pd.concat([results, scores]) 330 | 331 | avg_results = results.groupby('scenario').mean() 332 | std_results = results.groupby('scenario').std() 333 | 334 | avg_results.loc[len(avg_results)] = avg_results.mean() 335 | avg_results.insert(0, "scenario", list(unique_scenarios_names) + ['mean'], True) 336 
| std_results.insert(0, "scenario", list(unique_scenarios_names), True) 337 | 338 | report_save_path_avg = os.path.join(exp, f"Average_results.xlsx") 339 | report_save_path_std = os.path.join(exp, f"std_results.xlsx") 340 | 341 | self.averages_results_df = avg_results 342 | self.std_results_df = std_results 343 | avg_results.to_excel(report_save_path_avg) 344 | std_results.to_excel(report_save_path_std) 345 | 346 | 347 | parser = argparse.ArgumentParser() 348 | 349 | # ======== Experiments Name ================ 350 | parser.add_argument('--save_dir', default='experiments_logs', type=str, help='Directory containing all experiments') 351 | parser.add_argument('--experiment_description', default='HAR', type=str, help='Name of your experiment (HAR, FD, HHAR_SA, EEG)') 352 | parser.add_argument('--run_description', default='RL_JointADKD_new_2', type=str, help='name of your runs') 353 | 354 | # ========= Select the DA methods ============ 355 | parser.add_argument('--da_method', default='RL_JointADKD', type=str, help='JointADKD, RL_JointADKD') 356 | 357 | # ========= Select the DATASET ============== 358 | parser.add_argument('--data_path', default=r'./data', type=str, help='Path containing dataset') 359 | parser.add_argument('--dataset', default='HAR', type=str, help='Dataset of choice: (HAR - FD - HHAR_SA - EEG)') 360 | 361 | # ========= Select the BACKBONE ============== 362 | parser.add_argument('--backbone', default='CNN', type=str, help='Backbone of choice: (CNN - RESNET18 - TCN - RESNET34 - RESNET1D_WANG)') 363 | 364 | # ========= Experiment settings =============== 365 | parser.add_argument('--num_runs', default=3, type=int, help='Number of consecutive runs with different seeds') 366 | parser.add_argument('--device', default='cuda:0', type=str, help='cpu or cuda') 367 | 368 | # ======== sweep settings ===================== 369 | parser.add_argument('--is_sweep', default=False, type=bool, help='single run or sweep') 370 | parser.add_argument('--num_sweeps', default=20, type=int, help='Number of sweep runs') 371 | 372 | # We run sweeps using the wandb platform, so the next parameters are for wandb. 373 | parser.add_argument('--sweep_project_wandb', default='TEST_SOMETHING', type=str, help='Project name in Wandb') 374 | parser.add_argument('--wandb_entity', type=str, help='Entity name in Wandb (can be left blank if there is a default entity)') 375 | parser.add_argument('--hp_search_strategy', default="random", type=str, help='The way of selecting hyper-parameters (random-grid-bayes). 
in wandb see:https://docs.wandb.ai/guides/sweeps/configuration') 376 | parser.add_argument('--metric_to_minimize', default="src_risk", type=str, help='select one of: (src_risk - trg_risk - few_shot_trg_risk - dev_risk)') 377 | 378 | args = parser.parse_args() 379 | 380 | 381 | if __name__ == "__main__": 382 | trainer = adv_cross_domain_kd_trainer(args) 383 | 384 | if args.is_sweep: 385 | trainer.sweep() 386 | else: 387 | trainer.train() 388 | -------------------------------------------------------------------------------- /benchmark_Max_Cluser_Difference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import warnings 4 | import sklearn.exceptions 5 | warnings.filterwarnings("ignore", category=sklearn.exceptions.UndefinedMetricWarning) 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | import os 11 | import wandb 12 | import pandas as pd 13 | import numpy as np 14 | from dataloader.dataloader import data_generator, few_shot_data_generator, generator_percentage_of_data 15 | from configs.data_model_configs import get_dataset_class 16 | from configs.hparams import get_hparams_class 17 | 18 | from configs.sweep_params import sweep_alg_hparams 19 | from utils import fix_randomness, copy_Files, starting_logs, save_checkpoint, _calc_metrics 20 | from utils import calc_dev_risk, calculate_risk 21 | import warnings 22 | 23 | import sklearn.exceptions 24 | warnings.filterwarnings("ignore", category=sklearn.exceptions.UndefinedMetricWarning) 25 | 26 | import collections 27 | from algorithms.algorithms import get_algorithm_class 28 | from models.models import get_backbone_class 29 | from utils import AverageMeter 30 | from torch.utils.data import Dataset 31 | 32 | torch.backends.cudnn.benchmark = True # to fasten TCN 33 | 34 | class MyDataset(Dataset): 35 | def __init__(self, x): 36 | super(MyDataset, self).__init__() 37 | self.len = x.shape[0] 38 | if torch.cuda.is_available(): 39 | device = 'cuda' 40 | else: 41 | device = 'cpu' 42 | self.x_data = torch.as_tensor(x, device=device, dtype=torch.float) 43 | 44 | def __getitem__(self, index): 45 | return self.x_data[index] 46 | 47 | def __len__(self): 48 | return self.len 49 | 50 | class mcd_kd_trainer(object): 51 | """ 52 | This class contain the main training functions for our AdAtime 53 | """ 54 | def __init__(self, args): 55 | self.da_method = args.da_method # Selected DA Method 56 | self.dataset = args.dataset # Selected Dataset 57 | self.backbone = args.backbone 58 | self.device = torch.device(args.device) # device 59 | self.num_sweeps = args.num_sweeps 60 | 61 | # Exp Description 62 | self.run_description = args.run_description 63 | self.experiment_description = args.experiment_description 64 | # sweep parameters 65 | self.is_sweep = args.is_sweep 66 | self.sweep_project_wandb = args.sweep_project_wandb 67 | self.wandb_entity = args.wandb_entity 68 | self.hp_search_strategy = args.hp_search_strategy 69 | self.metric_to_minimize = args.metric_to_minimize 70 | 71 | # paths 72 | self.home_path = os.getcwd() 73 | self.save_dir = args.save_dir 74 | self.data_path = os.path.join(args.data_path, self.dataset) 75 | self.create_save_dir() 76 | 77 | # Specify runs 78 | self.num_runs = args.num_runs 79 | 80 | # get dataset and base model configs 81 | self.dataset_configs, self.hparams_class = self.get_configs() 82 | 83 | # to fix dimension of features in classifier and discriminator networks. 
84 | self.dataset_configs.final_out_channels = self.dataset_configs.tcn_final_out_channles if args.backbone == "TCN" else self.dataset_configs.final_out_channels 85 | 86 | # Specify number of hparams 87 | self.default_hparams = {**self.hparams_class.alg_hparams[self.da_method], 88 | **self.hparams_class.train_params} 89 | 90 | def sweep(self): 91 | # sweep configurations 92 | sweep_runs_count = self.num_sweeps 93 | sweep_config = { 94 | 'method': self.hp_search_strategy, 95 | 'metric': {'name': self.metric_to_minimize, 'goal': 'minimize'}, 96 | 'name': self.da_method, 97 | 'parameters': {**sweep_alg_hparams[self.da_method]} 98 | } 99 | sweep_id = wandb.sweep(sweep_config, project=self.sweep_project_wandb, entity=self.wandb_entity) 100 | 101 | wandb.agent(sweep_id, self.train, count=sweep_runs_count) # Training with sweep 102 | 103 | # resuming sweep 104 | # wandb.agent('8wkaibgr', self.train, count=25,project='HHAR_SA_Resnet', entity= 'iclr_rebuttal' ) 105 | 106 | def train(self): 107 | if self.is_sweep: 108 | wandb.init(config=self.default_hparams) 109 | run_name = f"sweep_{self.dataset}" 110 | else: 111 | run_name = f"{self.run_description}" 112 | wandb.init(config=self.default_hparams, mode="online", name=run_name) 113 | 114 | self.hparams = wandb.config 115 | # Logging 116 | self.exp_log_dir = os.path.join(self.save_dir, self.experiment_description, run_name) 117 | os.makedirs(self.exp_log_dir, exist_ok=True) 118 | copy_Files(self.exp_log_dir) # save a copy of training files: 119 | 120 | scenarios = self.dataset_configs.scenarios # return the scenarios given a specific dataset. 121 | 122 | self.metrics = {'accuracy': [], 'f1_score': [], 'src_risk': [], 'few_shot_trg_risk': [], 123 | 'trg_risk': [], 'dev_risk': []} 124 | 125 | for i in scenarios: 126 | src_id = i[0] 127 | trg_id = i[1] 128 | 129 | for run_id in range(self.num_runs): # specify number of consecutive runs 130 | # fixing random seed 131 | fix_randomness(run_id) 132 | 133 | # Logging 134 | self.logger, self.scenario_log_dir = starting_logs(self.dataset, self.da_method, self.exp_log_dir, 135 | src_id, trg_id, run_id) 136 | 137 | # Load data 138 | self.load_data(src_id, trg_id) 139 | 140 | # get student algorithm 141 | algorithm_class = get_algorithm_class(self.da_method) 142 | backbone_fe = get_backbone_class(self.backbone) 143 | 144 | algorithm = algorithm_class(backbone_fe, self.dataset_configs, self.hparams, self.device) 145 | algorithm.to(self.device) 146 | 147 | # Average meters 148 | loss_avg_meters = collections.defaultdict(lambda: AverageMeter()) 149 | 150 | # Pre-trained Teacher model on Source Only 151 | for epoch in range(1, self.hparams["num_epochs"] + 1): 152 | joint_loaders = enumerate(zip(self.src_train_dl, self.trg_train_dl)) 153 | len_dataloader = min(len(self.src_train_dl), len(self.trg_train_dl)) 154 | algorithm.network_t.train() 155 | 156 | for step, ((src_x, src_y), (trg_x, _)) in joint_loaders: 157 | src_x, src_y = src_x.float().to(self.device), src_y.long().to(self.device) 158 | losses = algorithm.update_t(src_x, src_y) 159 | 160 | for key, val in losses.items(): 161 | loss_avg_meters[key].update(val, src_x.size(0)) 162 | # logging 163 | self.logger.debug(f'[Epoch : {epoch}/{self.hparams["num_epochs"]}]') 164 | for key, val in loss_avg_meters.items(): 165 | self.logger.debug(f'{key}\t: {val.avg:2.4f}') 166 | self.logger.debug(f'-------------------------------------') 167 | 168 | self.trg_pseudo_labels = np.array([]) 169 | self.trg_predictions = np.array([]) 170 | # Calculate the clusters on target data 
based on teacher's predictions 171 | algorithm.network_t.eval() 172 | with torch.no_grad(): 173 | for data, labels in self.trg_train_dl: 174 | data = data.float().to(self.device) 175 | 176 | # forward pass 177 | predictions = algorithm.network_t(data) 178 | predictions = torch.nn.functional.softmax(predictions, dim=1) 179 | pred = predictions.detach().argmax(dim=1) # get the index of the max probability 180 | 181 | self.trg_pseudo_labels = np.append(self.trg_pseudo_labels, pred.cpu().numpy()) 182 | self.trg_predictions = np.append(self.trg_predictions, predictions.data.cpu().numpy()) 183 | 184 | self.clusters = np.array([]) 185 | trg_predictions = self.trg_predictions.reshape((-1,self.dataset_configs.num_classes)) 186 | 187 | for i in range(self.dataset_configs.num_classes): 188 | cluster_samples = np.take(trg_predictions, np.where(self.trg_pseudo_labels == i)[0],axis=0) 189 | cluster = np.mean(cluster_samples, axis=0) 190 | self.clusters = np.append(self.clusters, cluster) 191 | 192 | clusters = self.clusters.reshape((-1,self.dataset_configs.num_classes)) 193 | if np.isnan(np.min(self.clusters)): 194 | nan_index = np.argwhere(np.isnan(clusters))[0][0] 195 | for i in range(self.dataset_configs.num_classes): 196 | clusters[nan_index][i]=0 197 | clusters[nan_index][nan_index] = 1 198 | clusters = torch.from_numpy(clusters).to(self.device) 199 | 200 | # Calculate the MCD score for each sample 201 | self.mcd_score_list = np.array([]) 202 | self.tgt_data = [] 203 | for data, labels in self.trg_train_dl: 204 | data = data.float().to(self.device) 205 | # forward pass 206 | predictions = algorithm.network_t(data) 207 | predictions = torch.nn.functional.softmax(predictions, dim=1) 208 | mcd = self.mcd_score(predictions, clusters) 209 | self.mcd_score_list = np.append(self.mcd_score_list, mcd.cpu().detach().numpy()) 210 | self.tgt_data.extend(data.tolist()) 211 | tgt_data = np.asarray(self.tgt_data) 212 | # select the samples with the highest MCD scores 213 | ratio = 0.9 # keep the top 90% of target samples 214 | num_selected = int(ratio * tgt_data.shape[0]) 215 | selected_tgt_data = tgt_data[np.argpartition(self.mcd_score_list, -num_selected)[-num_selected:]] 216 | 217 | X_train = torch.from_numpy(selected_tgt_data) 218 | train_ds = MyDataset(X_train) 219 | self.optimal_tgt_data = torch.utils.data.DataLoader(dataset=train_ds, batch_size=self.hparams["batch_size"], 220 | shuffle=True, drop_last=True, num_workers=0) 221 | 222 | 223 | # Student training.. 
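# The selection above keeps the num_selected target samples with the largest
# MCD scores. Illustrative example (made-up values): with
#   mcd_score_list = [0.1, 0.7, 0.3, 0.9] and num_selected = 2,
#   np.argpartition(mcd_score_list, -2)[-2:] returns the indices of the two
#   largest scores (1 and 3, in arbitrary order).
# In the loop below, the source loader is zipped in only to bound the number
# of steps per epoch; its batches are discarded and the student is updated
# from the selected target batches and the teacher's clusters.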
224 | for epoch in range(1, self.hparams["num_epochs"] + 1): 225 | joint_loaders = enumerate(zip(self.src_train_dl, self.optimal_tgt_data)) 226 | len_dataloader = min(len(self.src_train_dl), len(self.optimal_tgt_data)) 227 | algorithm.network.train() 228 | 229 | for step, ((_, _), (trg_x)) in joint_loaders: 230 | trg_x = trg_x.float().to(self.device) 231 | losses = algorithm.update_s(trg_x, clusters) 232 | 233 | for key, val in losses.items(): 234 | loss_avg_meters[key].update(val, src_x.size(0)) 235 | 236 | # logging 237 | self.logger.debug(f'[Epoch : {epoch}/{self.hparams["num_epochs"]}]') 238 | for key, val in loss_avg_meters.items(): 239 | self.logger.debug(f'{key}\t: {val.avg:2.4f}') 240 | self.logger.debug(f'-------------------------------------') 241 | 242 | self.algorithm = algorithm 243 | save_checkpoint(self.home_path, self.algorithm, scenarios, self.dataset_configs, 244 | self.scenario_log_dir, self.hparams) 245 | 246 | self.evaluate() 247 | self.calc_results_per_run() 248 | 249 | # logging metrics 250 | self.calc_overall_results() 251 | average_metrics = {metric: np.mean(value) for (metric, value) in self.metrics.items()} 252 | wandb.log(average_metrics) 253 | wandb.log({'hparams': wandb.Table( 254 | dataframe=pd.DataFrame(dict(self.hparams).items(), columns=['parameter', 'value']), 255 | allow_mixed_types=True)}) 256 | wandb.log({'avg_results': wandb.Table(dataframe=self.averages_results_df, allow_mixed_types=True)}) 257 | wandb.log({'std_results': wandb.Table(dataframe=self.std_results_df, allow_mixed_types=True)}) 258 | 259 | def mcd_score(self, predictions, clusters): 260 | scores = 0 261 | num_class = clusters.shape[1] 262 | for i in range(num_class-1): 263 | for j in range(i+1, num_class): 264 | scores +=torch.abs(torch.nn.functional.cosine_similarity(predictions, clusters[i])- 265 | torch.nn.functional.cosine_similarity(predictions, clusters[j])) 266 | scores = scores /(num_class*(num_class-1)/2) 267 | return scores 268 | 269 | def evaluate(self): 270 | feature_extractor = self.algorithm.feature_extractor.to(self.device) 271 | classifier = self.algorithm.classifier.to(self.device) 272 | 273 | feature_extractor.eval() 274 | classifier.eval() 275 | 276 | total_loss_ = [] 277 | 278 | self.trg_pred_labels = np.array([]) 279 | self.trg_true_labels = np.array([]) 280 | 281 | with torch.no_grad(): 282 | for data, labels in self.trg_test_dl: 283 | data = data.float().to(self.device) 284 | labels = labels.view((-1)).long().to(self.device) 285 | 286 | # forward pass 287 | features = feature_extractor(data) 288 | predictions = classifier(features) 289 | 290 | # compute loss 291 | loss = F.cross_entropy(predictions, labels) 292 | total_loss_.append(loss.item()) 293 | pred = predictions.detach().argmax(dim=1) # get the index of the max log-probability 294 | 295 | self.trg_pred_labels = np.append(self.trg_pred_labels, pred.cpu().numpy()) 296 | self.trg_true_labels = np.append(self.trg_true_labels, labels.data.cpu().numpy()) 297 | 298 | self.trg_loss = torch.tensor(total_loss_).mean() # average loss 299 | 300 | def get_configs(self): 301 | dataset_class = get_dataset_class(self.dataset) 302 | hparams_class = get_hparams_class(self.dataset) 303 | return dataset_class(), hparams_class() 304 | 305 | def load_data(self, src_id, trg_id): 306 | self.src_train_dl, self.src_test_dl = data_generator(self.data_path, src_id, self.dataset_configs, 307 | self.hparams) 308 | self.trg_train_dl, self.trg_test_dl = data_generator(self.data_path, trg_id, self.dataset_configs, 309 | self.hparams) 310 | 
self.few_shot_dl = few_shot_data_generator(self.trg_test_dl) 311 | 312 | # self.src_train_dl = generator_percentage_of_data(self.src_train_dl_) 313 | # self.trg_train_dl = generator_percentage_of_data(self.trg_train_dl_) 314 | 315 | def create_save_dir(self): 316 | if not os.path.exists(self.save_dir): 317 | os.mkdir(self.save_dir) 318 | 319 | def calc_results_per_run(self): 320 | ''' 321 | Calculates the acc, f1 and risk values for each cross-domain scenario 322 | ''' 323 | 324 | self.acc, self.f1 = _calc_metrics(self.trg_pred_labels, self.trg_true_labels, self.scenario_log_dir, 325 | self.home_path, 326 | self.dataset_configs.class_names) 327 | if self.is_sweep: 328 | self.src_risk = calculate_risk(self.algorithm, self.src_test_dl, self.device) 329 | self.trg_risk = calculate_risk(self.algorithm, self.trg_test_dl, self.device) 330 | self.few_shot_trg_risk = calculate_risk(self.algorithm, self.few_shot_dl, self.device) 331 | self.dev_risk = calc_dev_risk(self.algorithm, self.src_train_dl, self.trg_train_dl, self.src_test_dl, 332 | self.dataset_configs, self.device) 333 | 334 | run_metrics = {'accuracy': self.acc, 335 | 'f1_score': self.f1, 336 | 'src_risk': self.src_risk, 337 | 'few_shot_trg_risk': self.few_shot_trg_risk, 338 | 'trg_risk': self.trg_risk, 339 | 'dev_risk': self.dev_risk} 340 | 341 | df = pd.DataFrame(columns=["acc", "f1", "src_risk", "few_shot_trg_risk", "trg_risk", "dev_risk"]) 342 | df.loc[0] = [self.acc, self.f1, self.src_risk, self.few_shot_trg_risk, self.trg_risk, 343 | self.dev_risk] 344 | else: 345 | run_metrics = {'accuracy': self.acc, 'f1_score': self.f1} 346 | df = pd.DataFrame(columns=["acc", "f1"]) 347 | df.loc[0] = [self.acc, self.f1] 348 | 349 | for (key, val) in run_metrics.items(): self.metrics[key].append(val) 350 | 351 | scores_save_path = os.path.join(self.home_path, self.scenario_log_dir, "scores.xlsx") 352 | df.to_excel(scores_save_path, index=False) 353 | self.results_df = df 354 | 355 | def calc_overall_results(self): 356 | exp = self.exp_log_dir 357 | 358 | # for exp in experiments: 359 | if self.is_sweep: 360 | results = pd.DataFrame( 361 | columns=["scenario", "acc", "f1", "src_risk", "few_shot_trg_risk", "trg_risk", "dev_risk"]) 362 | else: 363 | results = pd.DataFrame(columns=["scenario", "acc", "f1"]) 364 | 365 | scenarios_list = os.listdir(exp) 366 | scenarios_list = [i for i in scenarios_list if "_to_" in i] 367 | scenarios_list.sort() 368 | 369 | unique_scenarios_names = [f'{i}_to_{j}' for i, j in self.dataset_configs.scenarios] 370 | 371 | for scenario in scenarios_list: 372 | scenario_dir = os.path.join(exp, scenario) 373 | scores = pd.read_excel(os.path.join(scenario_dir, 'scores.xlsx')) 374 | scores.insert(0, 'scenario', '_'.join(scenario.split('_')[:-2])) 375 | results = pd.concat([results, scores]) 376 | 377 | avg_results = results.groupby('scenario').mean() 378 | std_results = results.groupby('scenario').std() 379 | 380 | avg_results.loc[len(avg_results)] = avg_results.mean() 381 | avg_results.insert(0, "scenario", list(unique_scenarios_names) + ['mean'], True) 382 | std_results.insert(0, "scenario", list(unique_scenarios_names), True) 383 | 384 | report_save_path_avg = os.path.join(exp, f"Average_results.xlsx") 385 | report_save_path_std = os.path.join(exp, f"std_results.xlsx") 386 | 387 | self.averages_results_df = avg_results 388 | self.std_results_df = std_results 389 | avg_results.to_excel(report_save_path_avg) 390 | std_results.to_excel(report_save_path_std) 391 | 392 | 393 | parser = argparse.ArgumentParser() 394 | 395 | # 
======== Experiments Name ================ 396 | parser.add_argument('--save_dir', default='experiments_logs_additional', type=str, help='Directory containing all experiments') 397 | parser.add_argument('--experiment_description', default='FD', type=str, help='Name of your experiment (HAR, FD, HHAR_SA, EEG)') 398 | parser.add_argument('--run_description', default='MCD', type=str, help='name of your runs') 399 | # ========= Select the DA methods ============ 400 | parser.add_argument('--da_method', default='MCD', type=str, help='MCD') 401 | 402 | # ========= Select the DATASET ============== 403 | parser.add_argument('--data_path', default=r'./data', type=str, help='Path containing dataset') 404 | parser.add_argument('--dataset', default='FD', type=str, help='Dataset of choice: (HAR - FD - HHAR_SA - EEG)') 405 | 406 | # ========= Select the BACKBONE ============== 407 | parser.add_argument('--backbone', default='CNN', type=str, help='Backbone of choice: (CNN - RESNET18 - TCN - RESNET34 - RESNET1D_WANG)') 408 | 409 | # ========= Experiment settings =============== 410 | parser.add_argument('--num_runs', default=3, type=int, help='Number of consecutive runs with different seeds') 411 | parser.add_argument('--device', default='cuda:0', type=str, help='cpu or cuda') 412 | 413 | # ======== sweep settings ===================== 414 | parser.add_argument('--is_sweep', default=False, type=bool, help='single run or sweep') 415 | parser.add_argument('--num_sweeps', default=20, type=int, help='Number of sweep runs') 416 | 417 | # We run sweeps using the wandb platform, so the next parameters are for wandb. 418 | parser.add_argument('--sweep_project_wandb', default='TEST_SOMETHING', type=str, help='Project name in Wandb') 419 | parser.add_argument('--wandb_entity', type=str, help='Entity name in Wandb (can be left blank if there is a default entity)') 420 | parser.add_argument('--hp_search_strategy', default="random", type=str, help='The way of selecting hyper-parameters (random-grid-bayes). 
in wandb see:https://docs.wandb.ai/guides/sweeps/configuration') 421 | parser.add_argument('--metric_to_minimize', default="src_risk", type=str, help='select one of: (src_risk - trg_risk - few_shot_trg_risk - dev_risk)') 422 | 423 | args = parser.parse_args() 424 | 425 | 426 | if __name__ == "__main__": 427 | trainer = mcd_kd_trainer(args) 428 | 429 | if args.is_sweep: 430 | trainer.sweep() 431 | else: 432 | trainer.train() 433 | -------------------------------------------------------------------------------- /models/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import math 4 | from torch.autograd import Function 5 | from torch.nn.utils import weight_norm 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | import numpy as np 9 | import random 10 | 11 | 12 | # from utils import weights_init 13 | 14 | def get_backbone_class(backbone_name): 15 | """Return the algorithm class with the given name.""" 16 | if backbone_name not in globals(): 17 | raise NotImplementedError("Algorithm not found: {}".format(backbone_name)) 18 | return globals()[backbone_name] 19 | 20 | 21 | ################################################## 22 | ########## BACKBONE NETWORKS ################### 23 | ################################################## 24 | 25 | ########## CNN ############################# 26 | class CNN(nn.Module): 27 | def __init__(self, configs): 28 | super(CNN, self).__init__() 29 | 30 | self.conv_block1 = nn.Sequential( 31 | nn.Conv1d(configs.input_channels, configs.mid_channels, kernel_size=configs.kernel_size, 32 | stride=configs.stride, bias=False, padding=(configs.kernel_size // 2)), 33 | nn.BatchNorm1d(configs.mid_channels), 34 | nn.ReLU(), 35 | nn.MaxPool1d(kernel_size=2, stride=2, padding=1), 36 | nn.Dropout(configs.dropout) 37 | ) 38 | 39 | self.conv_block2 = nn.Sequential( 40 | nn.Conv1d(configs.mid_channels, configs.mid_channels * 2, kernel_size=8, stride=1, bias=False, padding=4), 41 | nn.BatchNorm1d(configs.mid_channels * 2), 42 | nn.ReLU(), 43 | nn.MaxPool1d(kernel_size=2, stride=2, padding=1), 44 | nn.Dropout(configs.dropout) 45 | ) 46 | 47 | self.conv_block3 = nn.Sequential( 48 | nn.Conv1d(configs.mid_channels * 2, configs.final_out_channels, kernel_size=8, stride=1, bias=False, 49 | padding=4), 50 | nn.BatchNorm1d(configs.final_out_channels), 51 | nn.ReLU(), 52 | nn.MaxPool1d(kernel_size=2, stride=2, padding=1), 53 | nn.Dropout(configs.dropout) 54 | ) 55 | 56 | self.adaptive_pool = nn.AdaptiveAvgPool1d(configs.features_len) 57 | 58 | # weights_init(self.conv_block1) 59 | # weights_init(self.conv_block2) 60 | # weights_init(self.conv_block3) 61 | 62 | def forward(self, x_in): 63 | x = self.conv_block1(x_in) 64 | x = self.conv_block2(x) 65 | x = self.conv_block3(x) 66 | x = self.adaptive_pool(x) 67 | x_flat = x.reshape(x.shape[0], -1) 68 | return x_flat 69 | 70 | class classifier(nn.Module): 71 | def __init__(self, configs): 72 | super(classifier, self).__init__() 73 | 74 | model_output_dim = configs.features_len 75 | self.logits = nn.Linear(model_output_dim * configs.final_out_channels, configs.num_classes) 76 | 77 | def forward(self, x): 78 | predictions = self.logits(x) 79 | return predictions 80 | 81 | 82 | class CNN_T(nn.Module): 83 | def __init__(self, configs): 84 | super(CNN_T, self).__init__() 85 | 86 | self.conv_block1 = nn.Sequential( 87 | nn.Conv1d(configs.input_channels, configs.mid_channels_t, kernel_size=configs.kernel_size, 88 | stride=configs.stride, 
bias=False, padding=(configs.kernel_size // 2)), 89 | nn.BatchNorm1d(configs.mid_channels_t), 90 | nn.ReLU(), 91 | nn.MaxPool1d(kernel_size=2, stride=2, padding=1), 92 | nn.Dropout(configs.dropout) 93 | ) 94 | 95 | # self.conv_block1_backup = nn.Sequential( 96 | # nn.Conv1d(configs.input_channels, configs.final_out_channels_t, kernel_size=configs.kernel_size, 97 | # stride=configs.stride, bias=False, padding=(configs.kernel_size // 2)), 98 | # nn.BatchNorm1d(configs.final_out_channels_t), 99 | # nn.ReLU(), 100 | # nn.MaxPool1d(kernel_size=2, stride=2, padding=1), 101 | # nn.Dropout(configs.dropout) 102 | # ) 103 | 104 | self.conv_block2 = nn.Sequential( 105 | nn.Conv1d(configs.mid_channels_t, configs.mid_channels_t * 2, kernel_size=8, stride=1, bias=False, padding=4), 106 | nn.BatchNorm1d(configs.mid_channels_t * 2), 107 | nn.ReLU(), 108 | nn.MaxPool1d(kernel_size=2, stride=2, padding=1), 109 | nn.Dropout(configs.dropout) # 110 | ) 111 | 112 | # self.conv_block2_rep = nn.Sequential( 113 | # nn.Conv1d(configs.mid_channels_t*2, configs.mid_channels_t * 2, kernel_size=8, stride=1, bias=False, padding=4), 114 | # nn.BatchNorm1d(configs.mid_channels_t * 2), 115 | # nn.ReLU(), 116 | # nn.MaxPool1d(kernel_size=2, stride=2, padding=1) 117 | # ) 118 | 119 | self.conv_block3 = nn.Sequential( 120 | nn.Conv1d(configs.mid_channels_t * 2, configs.final_out_channels_t, kernel_size=8, stride=1, bias=False, 121 | padding=4), 122 | nn.BatchNorm1d(configs.final_out_channels_t), 123 | nn.ReLU(), 124 | nn.MaxPool1d(kernel_size=2, stride=2, padding=1), 125 | nn.Dropout(configs.dropout) # 126 | ) 127 | 128 | self.adaptive_pool = nn.AdaptiveAvgPool1d(configs.features_len) 129 | 130 | def forward(self, x_in): 131 | x = self.conv_block1(x_in) 132 | # x = self.conv_block1_backup(x_in) 133 | 134 | x = self.conv_block2(x) 135 | 136 | # x = self.conv_block2_rep(x) 137 | # x = self.conv_block2_rep(x) 138 | 139 | # x = self.conv_block2_rep(x) 140 | # x = self.conv_block2_rep(x) 141 | # 142 | # 143 | # x = self.conv_block2_rep(x) 144 | # x = self.conv_block2_rep(x) 145 | 146 | x = self.conv_block3(x) 147 | 148 | x = self.adaptive_pool(x) 149 | x_flat = x.reshape(x.shape[0], -1) 150 | return x_flat 151 | 152 | 153 | class classifier_T(nn.Module): 154 | def __init__(self, configs): 155 | super(classifier_T, self).__init__() 156 | 157 | model_output_dim = configs.features_len 158 | self.logits = nn.Linear(model_output_dim * configs.final_out_channels_t, configs.num_classes) 159 | 160 | def forward(self, x): 161 | predictions = self.logits(x) 162 | return predictions 163 | 164 | 165 | class CNN_mul_exit(nn.Module): 166 | def __init__(self, configs): 167 | super(CNN_mul_exit, self).__init__() 168 | 169 | model_output_dim = configs.features_len 170 | 171 | self.conv_block1 = nn.Sequential( 172 | nn.Conv1d(configs.input_channels, configs.mid_channels, kernel_size=configs.kernel_size, 173 | stride=configs.stride, bias=False, padding=(configs.kernel_size // 2)), 174 | nn.BatchNorm1d(configs.mid_channels), 175 | nn.ReLU(), 176 | nn.MaxPool1d(kernel_size=2, stride=2, padding=1), 177 | nn.Dropout(configs.dropout) 178 | ) 179 | 180 | self.conv_block2 = nn.Sequential( 181 | nn.Conv1d(configs.mid_channels, configs.mid_channels * 2, kernel_size=8, stride=1, bias=False, padding=4), 182 | nn.BatchNorm1d(configs.mid_channels * 2), 183 | nn.ReLU(), 184 | nn.MaxPool1d(kernel_size=2, stride=2, padding=1) 185 | ) 186 | 187 | self.conv_block3 = nn.Sequential( 188 | nn.Conv1d(configs.mid_channels * 2, configs.final_out_channels, kernel_size=8, 
stride=1, bias=False, 189 | padding=4), 190 | nn.BatchNorm1d(configs.final_out_channels), 191 | nn.ReLU(), 192 | nn.MaxPool1d(kernel_size=2, stride=2, padding=1), 193 | ) 194 | 195 | self.adaptive_pool = nn.AdaptiveAvgPool1d(configs.features_len) 196 | 197 | self.logits_low = nn.Linear(model_output_dim * configs.mid_channels, configs.num_classes) 198 | self.logits_middle = nn.Linear(model_output_dim * configs.mid_channels * 2, configs.num_classes) 199 | self.logits_top = nn.Linear(model_output_dim * configs.final_out_channels, configs.num_classes) 200 | 201 | def forward(self, x_in): 202 | x = self.conv_block1(x_in) 203 | exit_l = self.adaptive_pool(x).squeeze() # Low level exit 204 | exit_l = exit_l.reshape(exit_l.shape[0], -1) 205 | exit_l_logits = self.logits_low(exit_l) 206 | 207 | x = self.conv_block2(x) 208 | exit_m = self.adaptive_pool(x) # Middle level exit 209 | exit_m = exit_m.reshape(exit_m.shape[0], -1) 210 | exit_m_logits = self.logits_middle(exit_m) 211 | 212 | 213 | x = self.conv_block3(x) # 214 | x = self.adaptive_pool(x) 215 | x_flat = x.reshape(x.shape[0], -1) 216 | exit_top_logits = self.logits_top(x_flat) 217 | 218 | return exit_l_logits, exit_m_logits, exit_top_logits, x_flat 219 | 220 | 221 | ########## TCN ############################# 222 | torch.backends.cudnn.benchmark = True # might be required to fasten TCN 223 | 224 | 225 | class Chomp1d(nn.Module): 226 | def __init__(self, chomp_size): 227 | super(Chomp1d, self).__init__() 228 | self.chomp_size = chomp_size 229 | 230 | def forward(self, x): 231 | return x[:, :, :-self.chomp_size].contiguous() 232 | 233 | 234 | class TCN(nn.Module): 235 | def __init__(self, configs): 236 | super(TCN, self).__init__() 237 | 238 | in_channels0 = configs.input_channels 239 | out_channels0 = configs.tcn_layers[1] 240 | kernel_size = configs.tcn_kernel_size 241 | stride = 1 242 | dilation0 = 1 243 | padding0 = (kernel_size - 1) * dilation0 244 | 245 | self.net0 = nn.Sequential( 246 | weight_norm(nn.Conv1d(in_channels0, out_channels0, kernel_size, stride=stride, padding=padding0, 247 | dilation=dilation0)), 248 | nn.ReLU(), 249 | weight_norm(nn.Conv1d(out_channels0, out_channels0, kernel_size, stride=stride, padding=padding0, 250 | dilation=dilation0)), 251 | nn.ReLU(), 252 | ) 253 | 254 | self.downsample0 = nn.Conv1d(in_channels0, out_channels0, 1) if in_channels0 != out_channels0 else None 255 | self.relu = nn.ReLU() 256 | 257 | in_channels1 = configs.tcn_layers[0] 258 | out_channels1 = configs.tcn_layers[1] 259 | dilation1 = 2 260 | padding1 = (kernel_size - 1) * dilation1 261 | self.net1 = nn.Sequential( 262 | nn.Conv1d(in_channels0, out_channels1, kernel_size, stride=stride, padding=padding1, dilation=dilation1), 263 | nn.ReLU(), 264 | nn.Conv1d(out_channels1, out_channels1, kernel_size, stride=stride, padding=padding1, dilation=dilation1), 265 | nn.ReLU(), 266 | ) 267 | self.downsample1 = nn.Conv1d(out_channels1, out_channels1, 1) if in_channels1 != out_channels1 else None 268 | 269 | self.conv_block1 = nn.Sequential( 270 | nn.Conv1d(in_channels0, out_channels0, kernel_size=kernel_size, stride=stride, bias=False, padding=padding0, 271 | dilation=dilation0), 272 | Chomp1d(padding0), 273 | nn.BatchNorm1d(out_channels0), 274 | nn.ReLU(), 275 | 276 | nn.Conv1d(out_channels0, out_channels0, kernel_size=kernel_size, stride=stride, bias=False, 277 | padding=padding0, dilation=dilation0), 278 | Chomp1d(padding0), 279 | nn.BatchNorm1d(out_channels0), 280 | nn.ReLU(), 281 | ) 282 | 283 | self.conv_block2 = nn.Sequential( 284 | 
nn.Conv1d(out_channels0, out_channels1, kernel_size=kernel_size, stride=stride, bias=False, 285 | padding=padding1, dilation=dilation1), 286 | Chomp1d(padding1), 287 | nn.BatchNorm1d(out_channels1), 288 | nn.ReLU(), 289 | 290 | nn.Conv1d(out_channels1, out_channels1, kernel_size=kernel_size, stride=stride, bias=False, 291 | padding=padding1, dilation=dilation1), 292 | Chomp1d(padding1), 293 | nn.BatchNorm1d(out_channels1), 294 | nn.ReLU(), 295 | ) 296 | 297 | def forward(self, inputs): 298 | """Inputs have to have dimension (N, C_in, L_in)""" 299 | x0 = self.conv_block1(inputs) 300 | res0 = inputs if self.downsample0 is None else self.downsample0(inputs) 301 | out_0 = self.relu(x0 + res0) 302 | 303 | x1 = self.conv_block2(out_0) 304 | res1 = out_0 if self.downsample1 is None else self.downsample1(out_0) 305 | out_1 = self.relu(x1 + res1) 306 | 307 | out = out_1[:, :, -1] 308 | return out 309 | 310 | 311 | ######## RESNET ############################################## 312 | 313 | class RESNET18(nn.Module): 314 | def __init__(self, configs): 315 | layers = [2, 2, 2, 2] 316 | # block = BasicBlock 317 | block = BasicBlock1d 318 | 319 | self.inplanes = configs.input_channels 320 | super(RESNET18, self).__init__() 321 | self.layer1 = self._make_layer(block, configs.mid_channels, layers[0], stride=configs.stride) 322 | self.layer2 = self._make_layer(block, configs.mid_channels * 2, layers[1], stride=1) 323 | self.layer3 = self._make_layer(block, configs.final_out_channels, layers[2], stride=1) 324 | self.layer4 = self._make_layer(block, configs.final_out_channels, layers[3], stride=1) 325 | 326 | self.avgpool = nn.MaxPool1d(kernel_size=2, stride=2, padding=1) 327 | self.adaptive_pool = nn.AdaptiveAvgPool1d(configs.features_len) 328 | 329 | def _make_layer(self, block, planes, blocks, stride=1): 330 | downsample = None 331 | if stride != 1 or self.inplanes != planes * block.expansion: 332 | downsample = nn.Sequential( 333 | nn.Conv1d(self.inplanes, planes * block.expansion, 334 | kernel_size=1, stride=stride, bias=False), 335 | nn.BatchNorm1d(planes * block.expansion), 336 | ) 337 | layers = [] 338 | layers.append(block(self.inplanes, planes, stride, downsample)) 339 | self.inplanes = planes * block.expansion 340 | for i in range(1, blocks): 341 | layers.append(block(self.inplanes, planes)) 342 | return nn.Sequential(*layers) 343 | 344 | def forward(self, x): 345 | x = self.layer1(x) 346 | x = self.layer2(x) 347 | x = self.layer3(x) 348 | x = self.layer4(x) 349 | 350 | x = self.adaptive_pool(x) 351 | 352 | x_flat = x.reshape(x.shape[0], -1) 353 | return x_flat 354 | 355 | 356 | class RESNET34(nn.Module): 357 | def __init__(self, configs): 358 | layers = [3, 4, 6, 3] 359 | block = BasicBlock1d 360 | 361 | self.inplanes = configs.input_channels 362 | super(RESNET34, self).__init__() 363 | self.layer1 = self._make_layer(block, configs.mid_channels, layers[0], stride=configs.stride) 364 | self.layer2 = self._make_layer(block, configs.mid_channels * 2, layers[1], stride=1) 365 | self.layer3 = self._make_layer(block, configs.final_out_channels, layers[2], stride=1) 366 | self.layer4 = self._make_layer(block, configs.final_out_channels, layers[3], stride=1) 367 | 368 | self.avgpool = nn.MaxPool1d(kernel_size=2, stride=2, padding=1) 369 | self.adaptive_pool = nn.AdaptiveAvgPool1d(configs.features_len) 370 | 371 | def _make_layer(self, block, planes, blocks, stride=1): 372 | downsample = None 373 | if stride != 1 or self.inplanes != planes * block.expansion: 374 | downsample = nn.Sequential( 375 | 
nn.Conv1d(self.inplanes, planes * block.expansion, 376 | kernel_size=1, stride=stride, bias=False), 377 | nn.BatchNorm1d(planes * block.expansion), 378 | ) 379 | layers = [] 380 | layers.append(block(self.inplanes, planes, stride, downsample)) 381 | self.inplanes = planes * block.expansion 382 | for i in range(1, blocks): 383 | layers.append(block(self.inplanes, planes)) 384 | return nn.Sequential(*layers) 385 | 386 | def forward(self, x): 387 | x = self.layer1(x) 388 | x = self.layer2(x) 389 | x = self.layer3(x) 390 | x = self.layer4(x) 391 | 392 | x = self.adaptive_pool(x) 393 | 394 | x_flat = x.reshape(x.shape[0], -1) 395 | return x_flat 396 | 397 | 398 | class BasicBlock(nn.Module): 399 | expansion = 1 400 | 401 | def __init__(self, inplanes, planes, stride=1, downsample=None): 402 | super(BasicBlock, self).__init__() 403 | self.conv1 = nn.Conv1d(inplanes, planes, kernel_size=1, stride=stride, 404 | bias=False) 405 | self.bn1 = nn.BatchNorm1d(planes) 406 | 407 | self.downsample = downsample 408 | self.stride = stride 409 | self.relu = nn.ReLU(inplace=True) 410 | 411 | def forward(self, x): 412 | residual = x 413 | 414 | out = self.conv1(x) 415 | out = self.bn1(out) 416 | out = self.relu(out) 417 | 418 | if self.downsample is not None: 419 | residual = self.downsample(x) 420 | 421 | out = out + residual 422 | out = self.relu(out) 423 | 424 | return out 425 | 426 | 427 | def conv(in_planes, out_planes, stride=1, kernel_size=3): 428 | "convolution with padding" 429 | return nn.Conv1d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, 430 | padding=(kernel_size-1)//2, bias=False) 431 | 432 | 433 | class BasicBlock1d(nn.Module): 434 | expansion = 1 435 | def __init__(self, inplanes, planes, stride=1, downsample=None): 436 | super().__init__() 437 | 438 | # if(isinstance(kernel_size,int)): kernel_size = [kernel_size,kernel_size//2+1] 439 | 440 | self.conv1 = conv(inplanes, planes, stride=stride, kernel_size=1) 441 | self.bn1 = nn.BatchNorm1d(planes) 442 | self.relu = nn.ReLU(inplace=True) 443 | self.conv2 = conv(planes, planes,kernel_size=1) 444 | self.bn2 = nn.BatchNorm1d(planes) 445 | self.downsample = downsample 446 | self.stride = stride 447 | 448 | def forward(self, x): 449 | residual = x 450 | 451 | out = self.conv1(x) 452 | out = self.bn1(out) 453 | out = self.relu(out) 454 | 455 | out = self.conv2(out) 456 | out = self.bn2(out) 457 | 458 | if self.downsample is not None: 459 | residual = self.downsample(x) 460 | 461 | out = out + residual 462 | out = self.relu(out) 463 | 464 | return out 465 | 466 | 467 | class BasicBlock1d_wang(nn.Module): 468 | expansion = 1 469 | def __init__(self, inplanes, planes, stride=1, downsample=None,kernel_size=[5,3]): 470 | super().__init__() 471 | 472 | # if(isinstance(kernel_size,int)): kernel_size = [kernel_size,kernel_size//2+1] 473 | 474 | self.conv1 = conv(inplanes, planes, stride=stride, kernel_size=kernel_size[0]) 475 | self.bn1 = nn.BatchNorm1d(planes) 476 | self.relu = nn.ReLU(inplace=True) 477 | self.conv2 = conv(planes, planes,kernel_size=kernel_size[1]) 478 | self.bn2 = nn.BatchNorm1d(planes) 479 | self.downsample = downsample 480 | self.stride = stride 481 | 482 | def forward(self, x): 483 | residual = x 484 | 485 | out = self.conv1(x) 486 | out = self.bn1(out) 487 | out = self.relu(out) 488 | 489 | out = self.conv2(out) 490 | out = self.bn2(out) 491 | 492 | if self.downsample is not None: 493 | residual = self.downsample(x) 494 | 495 | out = out + residual 496 | out = self.relu(out) 497 | 498 | return out 499 | 500 | 501 | 
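# Shape contract for the CNN and ResNet backbones in this file (illustrative,
# using the config fields referenced above): the input is
# (batch, input_channels, seq_len); after the convolutional/residual stages the
# adaptive average pool shrinks the time axis to features_len, and the flattened
# output has size final_out_channels * features_len, which is what the
# classifier and Discriminator heads expect. For example (hypothetical sizes):
#   x = torch.randn(32, configs.input_channels, seq_len)
#   feats = RESNET1D_WANG(configs)(x)  # -> (32, configs.final_out_channels * configs.features_len)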
class RESNET1D_WANG(nn.Module): 502 | def __init__(self, configs): 503 | layers = [1,1,1] 504 | block = BasicBlock1d_wang 505 | 506 | self.input_channels = configs.input_channels 507 | self.inplanes = configs.mid_channels 508 | super(RESNET1D_WANG, self).__init__() 509 | 510 | self.stem = nn.Sequential( 511 | nn.Conv1d(self.input_channels, configs.mid_channels, kernel_size=7, stride=1, padding=3,bias=False), 512 | nn.BatchNorm1d(configs.mid_channels), 513 | nn.ReLU(inplace=True) 514 | ) 515 | 516 | self.layer2 = self._make_layer(block, configs.mid_channels, layers[0], stride=configs.stride) 517 | self.layer3 = self._make_layer(block, configs.mid_channels * 2, layers[1], stride=1) 518 | self.layer4 = self._make_layer(block, configs.final_out_channels, layers[2], stride=1) 519 | 520 | self.avgpool = nn.MaxPool1d(kernel_size=2, stride=2, padding=1) 521 | self.adaptive_pool = nn.AdaptiveAvgPool1d(configs.features_len) 522 | 523 | def _make_layer(self, block, planes, blocks, stride=1): 524 | downsample = None 525 | if stride != 1 or self.inplanes != planes * block.expansion: 526 | downsample = nn.Sequential( 527 | nn.Conv1d(self.inplanes, planes * block.expansion, 528 | kernel_size=1, stride=stride, bias=False), 529 | nn.BatchNorm1d(planes * block.expansion), 530 | ) 531 | layers = [] 532 | layers.append(block(self.inplanes, planes, stride, downsample)) 533 | self.inplanes = planes * block.expansion 534 | for i in range(1, blocks): 535 | layers.append(block(self.inplanes, planes)) 536 | return nn.Sequential(*layers) 537 | 538 | def forward(self, x): 539 | x = self.stem(x) 540 | x = self.layer2(x) 541 | x = self.layer3(x) 542 | x = self.layer4(x) 543 | 544 | x = self.adaptive_pool(x) 545 | 546 | x_flat = x.reshape(x.shape[0], -1) 547 | return x_flat 548 | 549 | ################################################## 550 | ########## OTHER NETWORKS ###################### 551 | ################################################## 552 | 553 | class codats_classifier(nn.Module): 554 | def __init__(self, configs): 555 | super(codats_classifier, self).__init__() 556 | model_output_dim = configs.features_len 557 | self.hidden_dim = configs.hidden_dim 558 | self.logits = nn.Sequential( 559 | nn.Linear(model_output_dim * configs.final_out_channels, self.hidden_dim), 560 | nn.ReLU(), 561 | nn.Linear(self.hidden_dim, self.hidden_dim), 562 | nn.ReLU(), 563 | nn.Linear(self.hidden_dim, configs.num_classes)) 564 | 565 | def forward(self, x_in): 566 | predictions = self.logits(x_in) 567 | return predictions 568 | 569 | 570 | class Discriminator(nn.Module): 571 | """Discriminator model for source domain.""" 572 | 573 | def __init__(self, configs): 574 | """Init discriminator.""" 575 | super(Discriminator, self).__init__() 576 | 577 | self.layer = nn.Sequential( 578 | nn.Linear(configs.features_len * configs.final_out_channels, configs.disc_hid_dim), 579 | nn.ReLU(), 580 | nn.Linear(configs.disc_hid_dim, configs.disc_hid_dim), 581 | nn.ReLU(), 582 | nn.Linear(configs.disc_hid_dim, 2) 583 | # nn.LogSoftmax(dim=1) 584 | ) 585 | 586 | def forward(self, input): 587 | """Forward the discriminator.""" 588 | out = self.layer(input) 589 | return out 590 | 591 | class Discriminator_t(nn.Module): 592 | """Discriminator model for source domain.""" 593 | 594 | def __init__(self,configs): 595 | """Init discriminator.""" 596 | super(Discriminator_t, self).__init__() 597 | 598 | self.layer = nn.Sequential( 599 | nn.Linear(128, configs.disc_hid_dim), 600 | nn.ReLU(), 601 | nn.Linear(configs.disc_hid_dim, configs.disc_hid_dim), 602 | 
602 |             nn.ReLU(),
603 |             nn.Linear(configs.disc_hid_dim, 2)
604 |             # nn.LogSoftmax(dim=1)
605 |         )
606 |
607 |     def forward(self, input):
608 |         """Forward the discriminator."""
609 |         out = self.layer(input)
610 |         return out
611 |
612 | class Discriminator_fea(nn.Module):
613 |     """Feature-level discriminator on teacher-sized features with a single sigmoid output."""
614 |
615 |     def __init__(self, configs):
616 |         """Init discriminator."""
617 |         super(Discriminator_fea, self).__init__()
618 |
619 |         self.layer = nn.Sequential(
620 |             nn.Linear(configs.features_len * configs.final_out_channels_t, configs.hidden_dim),
621 |             nn.ReLU(),
622 |             nn.Linear(configs.hidden_dim, configs.hidden_dim),
623 |             nn.ReLU(),
624 |             nn.Linear(configs.hidden_dim, 1),
625 |             nn.Sigmoid()
626 |         )
627 |
628 |     def forward(self, input):
629 |         """Forward the discriminator."""
630 |         out = self.layer(input)
631 |         return out
632 |
633 | class Discriminator_s(nn.Module):
634 |     """Discriminator over student features with a log-softmax output."""
635 |
636 |     def __init__(self, configs):
637 |         """Init discriminator."""
638 |         super(Discriminator_s, self).__init__()
639 |
640 |         self.layer = nn.Sequential(
641 |             nn.Linear(configs.features_len * configs.final_out_channels, configs.hidden_dim),
642 |             nn.ReLU(),
643 |             nn.Linear(configs.hidden_dim, configs.hidden_dim),
644 |             nn.ReLU(),
645 |             nn.Linear(configs.hidden_dim, 2),
646 |             nn.LogSoftmax(dim=1)
647 |         )
648 |
649 |     def forward(self, input):
650 |         """Forward the discriminator."""
651 |         out = self.layer(input)
652 |         return out
653 |
654 | class Adapter(nn.Module):
655 |     """Maps the student feature dimension to the teacher feature dimension."""
656 |
657 |     def __init__(self, configs):
658 |         """Init adapter."""
659 |         super(Adapter, self).__init__()
660 |         self.layer = nn.Linear(configs.final_out_channels, configs.final_out_channels_t)
661 |
662 |     def forward(self, input):
663 |         """Forward the adapter."""
664 |         out = self.layer(input)
665 |         return out
666 |
667 |
668 | #### Codes required by DANN ##############
669 | class ReverseLayerF(Function):
670 |     @staticmethod
671 |     def forward(ctx, x, alpha):
672 |         ctx.alpha = alpha
673 |         return x.view_as(x)
674 |
675 |     @staticmethod
676 |     def backward(ctx, grad_output):
677 |         output = grad_output.neg() * ctx.alpha
678 |         return output, None
679 |
680 |
681 | #### Codes required by CDAN ##############
682 | class RandomLayer(nn.Module):
683 |     def __init__(self, input_dim_list=[], output_dim=1024):
684 |         super(RandomLayer, self).__init__()
685 |         self.input_num = len(input_dim_list)
686 |         self.output_dim = output_dim
687 |         self.random_matrix = [torch.randn(input_dim_list[i], output_dim) for i in range(self.input_num)]
688 |
689 |     def forward(self, input_list):
690 |         return_list = [torch.mm(input_list[i], self.random_matrix[i]) for i in range(self.input_num)]
691 |         return_tensor = return_list[0] / math.pow(float(self.output_dim), 1.0 / len(return_list))
692 |         for single in return_list[1:]:
693 |             return_tensor = torch.mul(return_tensor, single)
694 |         return return_tensor
695 |
696 |     def cuda(self):
697 |         super(RandomLayer, self).cuda()
698 |         self.random_matrix = [val.cuda() for val in self.random_matrix]
699 |
700 |
701 | class Discriminator_CDAN(nn.Module):
702 |     """Discriminator model for CDAN."""
703 |
704 |     def __init__(self, configs):
705 |         """Init discriminator."""
706 |         super(Discriminator_CDAN, self).__init__()
707 |
708 |         self.restored = False
709 |
710 |         self.layer = nn.Sequential(
711 |             nn.Linear(configs.features_len * configs.final_out_channels_t * configs.num_classes, configs.disc_hid_dim),
712 |             nn.ReLU(),
713 |             nn.Linear(configs.disc_hid_dim, configs.disc_hid_dim),
714 |             nn.ReLU(),
715 |             nn.Linear(configs.disc_hid_dim, 2)
716 |             # nn.LogSoftmax(dim=1)
717 |         )
718 |
719 |     def forward(self, input):
720 |         """Forward the discriminator."""
721 |         out = self.layer(input)
722 |         return out
723 |
724 |
725 | #### Codes required by AdvSKM ##############
726 | class Cosine_act(nn.Module):
727 |     def __init__(self):
728 |         super(Cosine_act, self).__init__()
729 |
730 |     def forward(self, input):
731 |         return torch.cos(input)
732 |
733 |
734 | cos_act = Cosine_act()
735 |
736 | class AdvSKM_Disc(nn.Module):
737 |     """Two-branch (cosine and ReLU) deep kernel discriminator for AdvSKM."""
738 |
739 |     def __init__(self, configs):
740 |         """Init discriminator."""
741 |         super(AdvSKM_Disc, self).__init__()
742 |
743 |         self.input_dim = configs.features_len * configs.final_out_channels
744 |         self.hid_dim = configs.DSKN_disc_hid
745 |         self.branch_1 = nn.Sequential(
746 |             nn.Linear(self.input_dim, self.hid_dim),
747 |             nn.Linear(self.hid_dim, self.hid_dim),
748 |             nn.BatchNorm1d(self.hid_dim),
749 |             cos_act,
750 |             nn.Linear(self.hid_dim, self.hid_dim // 2),
751 |             nn.Linear(self.hid_dim // 2, self.hid_dim // 2),
752 |             nn.BatchNorm1d(self.hid_dim // 2),
753 |             cos_act
754 |         )
755 |         self.branch_2 = nn.Sequential(
756 |             nn.Linear(configs.features_len * configs.final_out_channels, configs.disc_hid_dim),
757 |             nn.Linear(configs.disc_hid_dim, configs.disc_hid_dim),
758 |             nn.BatchNorm1d(configs.disc_hid_dim),
759 |             nn.ReLU(),
760 |             nn.Linear(configs.disc_hid_dim, configs.disc_hid_dim // 2),
761 |             nn.Linear(configs.disc_hid_dim // 2, configs.disc_hid_dim // 2),
762 |             nn.BatchNorm1d(configs.disc_hid_dim // 2),
763 |             nn.ReLU())
764 |
765 |     def forward(self, input):
766 |         """Forward the discriminator."""
767 |         out_cos = self.branch_1(input)
768 |         out_rel = self.branch_2(input)
769 |         total_out = torch.cat((out_cos, out_rel), dim=1)
770 |         return total_out
771 |
772 | ######### DDQN ###########
773 | class NoisyLinear(nn.Module):
774 |     def __init__(self, in_features, out_features, std_init=0.4):
775 |         super(NoisyLinear, self).__init__()
776 |
777 |         self.in_features = in_features
778 |         self.out_features = out_features
779 |         self.std_init = std_init
780 |
781 |         self.weight_mu = nn.Parameter(torch.FloatTensor(out_features, in_features))
782 |         self.weight_sigma = nn.Parameter(torch.FloatTensor(out_features, in_features))
783 |         self.register_buffer('weight_epsilon', torch.FloatTensor(out_features, in_features))
784 |
785 |         self.bias_mu = nn.Parameter(torch.FloatTensor(out_features))
786 |         self.bias_sigma = nn.Parameter(torch.FloatTensor(out_features))
787 |         self.register_buffer('bias_epsilon', torch.FloatTensor(out_features))
788 |
789 |         self.reset_parameters()
790 |         self.reset_noise()
791 |
792 |     def forward(self, x):
793 |         if self.training:
794 |             weight = self.weight_mu + self.weight_sigma.mul(Variable(self.weight_epsilon))
795 |             bias = self.bias_mu + self.bias_sigma.mul(Variable(self.bias_epsilon))
796 |         else:
797 |             weight = self.weight_mu
798 |             bias = self.bias_mu
799 |
800 |         return F.linear(x, weight, bias)
801 |
802 |     def reset_parameters(self):
803 |         mu_range = 1 / np.sqrt(self.weight_mu.size(1))
804 |
805 |         self.weight_mu.data.uniform_(-mu_range, mu_range)
806 |         self.weight_sigma.data.fill_(self.std_init / np.sqrt(self.weight_sigma.size(1)))
807 |
808 |         self.bias_mu.data.uniform_(-mu_range, mu_range)
809 |         self.bias_sigma.data.fill_(self.std_init / np.sqrt(self.bias_sigma.size(0)))
810 |
811 |     def reset_noise(self):
812 |         epsilon_in = self._scale_noise(self.in_features)
813 |         epsilon_out = self._scale_noise(self.out_features)
814 |
815 |         self.weight_epsilon.copy_(epsilon_out.ger(epsilon_in))
816 |         self.bias_epsilon.copy_(self._scale_noise(self.out_features))
817 |
818 |     def _scale_noise(self, size):
819 |         x = torch.randn(size)
820 |         x = x.sign().mul(x.abs().sqrt())
821 |         return x
822 |
823 |
824 | class Qnet(nn.Module):
825 |     def __init__(self):
826 |         super(Qnet, self).__init__()
827 |
828 |         self.linear1 = nn.Linear(32, 1024)
829 |
830 |         self.noisy_value1 = NoisyLinear(1024, 1024)
831 |         self.noisy_value2 = NoisyLinear(1024, 1)
832 |
833 |         self.noisy_advantage1 = NoisyLinear(1024, 1024)
834 |         self.noisy_advantage2 = NoisyLinear(1024, 2)
835 |
836 |     def forward(self, x):
837 |         x = F.relu(self.linear1(x))
838 |         value = F.relu(self.noisy_value1(x))
839 |         value = self.noisy_value2(value)
840 |         advantage = F.relu(self.noisy_advantage1(x))
841 |         advantage = self.noisy_advantage2(advantage)
842 |         return value + advantage - advantage.mean()  # dim = [batch_size, 2]
843 |
844 |     def sample_action(self, obs, epsilon):
845 |         out = self.forward(obs)
846 |         # if the Q-values of the two actions are equal, break the tie randomly
847 |         if out[0, 0] == out[0, 1]:
848 |             return np.array([random.randrange(2)]), out, 2
849 |         else:
850 |             # returns an ndarray of shape (batch_size,)
851 |             return (torch.argmax(out, dim=1)).cpu().detach().numpy(), out, 0
852 |
853 |     def reset_noise(self):
854 |         self.noisy_value1.reset_noise()
855 |         self.noisy_value2.reset_noise()
856 |         self.noisy_advantage1.reset_noise()
857 |         self.noisy_advantage2.reset_noise()
858 |
--------------------------------------------------------------------------------
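For orientation, here is a minimal, hypothetical driver for the dueling Qnet defined at the end of models.py; it is not part of the repository, and the models.models import path is an assumption. forward() implements the dueling aggregation Q(s, a) = V(s) + A(s, a) - mean_a A(s, a) over two actions, and sample_action ignores its epsilon argument: exploration comes from the NoisyLinear layers, whose factorised Gaussian noise is resampled with reset_noise().

import torch
from models.models import Qnet  # assumed import path

qnet = Qnet()
obs = torch.randn(1, 32)  # a 32-dimensional state vector, matching linear1

action, q_values, tie_flag = qnet.sample_action(obs, epsilon=0.0)
print(action, q_values.shape, tie_flag)  # e.g. [0] torch.Size([1, 2]) 0

qnet.reset_noise()  # resample the noisy-layer weights between training steps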