├── HMGAN_framework.png
├── README.md
├── get_data.py
├── main.py
├── sliding_window.py
├── model_DeepSense.py
├── args_space.py
├── model_HMGAN.py
├── metrics.py
├── preprocess.py
└── solver_HMGAN.py

/HMGAN_framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rxannro/HMGAN/HEAD/HMGAN_framework.png

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# HMGAN

This is the official repository for our paper: HMGAN: A Hierarchical Multi-Modal Generative Adversarial Network Model for Wearable Human Activity Recognition

![framework](HMGAN_framework.png)

## Dependencies

* python 3.8
* torch == 1.10.0 (with suitable CUDA and cuDNN versions)
* numpy, torchmetrics, scipy, pandas, scikit-learn

## Datasets

| Dataset | Download Link |
| -- | -- |
| UTD-MHAD | https://personal.utdallas.edu/~kehtar/UTD-MHAD.html |
| UCI-HAR | https://archive.ics.uci.edu/ml/datasets/human+activity+recognition+using+smartphones |
| OPPORTUNITY | https://archive.ics.uci.edu/ml/datasets/opportunity+activity+recognition |

## Quick Start

Data preprocessing is included in main.py. Download the datasets and run HMGAN as follows. This reports the performance of each split in 5-fold cross-validation, as well as their average. Note that the data directory is used as a plain string prefix, so it should end with a trailing slash.
```
python main.py --data_dir [/path/to/data/] --dataset [UTD_MHAD_arm, UCI_HAR, or OPPORTUNITY]
```

--------------------------------------------------------------------------------
/get_data.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split

def get_data(args):
    train_ratio = 0.8
    if 'UTD_MHAD' in args.dataset:
        path = args.data_dir + 'UTD_MHAD/' + args.dataset[-3:] + '/processed_data'
    else:
        path = args.data_dir + args.dataset + '/processed_data'

    x_all = np.load(path+'/features.npy')
    y_all = np.load(path+'/labels.npy')
    train_idx = np.load(path+'/fold{}_train_idx.npy'.format(args.test_fold))
    test_idx = np.load(path+'/fold{}_test_idx.npy'.format(args.test_fold))

    x_train, x_test = x_all[train_idx], x_all[test_idx]
    y_train, y_test = y_all[train_idx], y_all[test_idx]

    x_train, x_valid, y_train, y_valid = train_test_split(x_train, y_train, train_size = train_ratio, random_state = 0)

    if 'multiply' in args.aug_type:
        x_train, y_train = augment_to_multiply(args, x_train, y_train)

    train_dataset = TensorDataset(torch.from_numpy(x_train.astype(np.float32)), torch.from_numpy(y_train))
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=0)

    valid_dataset = TensorDataset(torch.from_numpy(x_valid.astype(np.float32)), torch.from_numpy(y_valid))
    valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=0)

    test_dataset = TensorDataset(torch.from_numpy(x_test.astype(np.float32)), torch.from_numpy(y_test))
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=0)

    return train_loader, valid_loader, test_loader
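# A minimal usage sketch (not part of the original file; it assumes preprocessing in
# main.py has already written features.npy, labels.npy and the fold index files):
#
#     args = args_space.get_args()    # e.g. --data_dir ./data/ --dataset UCI_HAR
#     train_loader, valid_loader, test_loader = get_data(args)
#     x, y = next(iter(train_loader)) # x: [batch_size, seq_len, N_channels]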

def window_slice(x, reduce_ratio=0.9):
    # https://halshs.archives-ouvertes.fr/halshs-01357973/document
    target_len = np.ceil(reduce_ratio*x.shape[1]).astype(int)
    if target_len >= x.shape[1]:
        return x
    starts = np.random.randint(low=0, high=x.shape[1]-target_len, size=(x.shape[0])).astype(int)
    ends = (target_len + starts).astype(int)

    ret = np.zeros_like(x)
    for i, pat in enumerate(x):
        for dim in range(x.shape[2]):
            ret[i,:,dim] = np.interp(np.linspace(0, target_len, num=x.shape[1]), np.arange(target_len), pat[starts[i]:ends[i],dim]).T
    return ret

def augment_to_multiply(args, data, labels): # instead of changing the DA function, we could just sample train data ahead according to labels
    np.random.seed(args.seed)
    data_to_aug = np.repeat(data, args.N_aug, axis=0)
    labels_to_aug = np.repeat(labels, args.N_aug, axis=0)
    generated_data = window_slice(data_to_aug)
    return np.concatenate([data, generated_data]), np.concatenate([labels, labels_to_aug])

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import args_space
import solver_HMGAN
import preprocess
import numpy as np
import warnings
import os
import time
warnings.filterwarnings("ignore")

def main(args):
    if args.cuda != -1:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.cuda)

    acc, f1 = [np.empty([args.N_folds], dtype=float) for _ in range(2)]
    p_score, d_score, tstr_score = [np.empty([args.N_folds], dtype=float) for _ in range(3)]
    starttime = time.time()
    for test_fold in range(args.N_folds):
        args.test_fold = test_fold

        if 'HMGAN' in args.model_type:
            mysolver = solver_HMGAN.DASolver_HMGAN(args)
        print('\n=== ' + args.dataset + '_' + args.model_type + '_fold' + str(args.test_fold) + ' ===')

        test_acc, test_f1 = mysolver.train()

        acc[test_fold] = test_acc
        f1[test_fold] = test_f1

        test_p_score, test_d_score, _, test_tstr_score = mysolver.eval_gen_data(training=True)
        p_score[test_fold] = test_p_score
        d_score[test_fold] = test_d_score
        tstr_score[test_fold] = test_tstr_score

    endtime = time.time()

    print('\n=== ' + args.dataset + '_' + args.model_type + ' ===')
    print('Duration: ', round(endtime - starttime, 2), 'secs')
    print(args)
    print("FINAL VALUE: \nacc: ", np.around(acc, 3), "\nf1: ", np.around(f1, 3))
    print("FINAL AVERAGE: \nacc: ", np.around(np.mean(acc), 3), "\nf1: ", np.around(np.mean(f1), 3))
    print("FINAL STD: \nacc: ", np.around(np.std(acc), 3), "\nf1: ", np.around(np.std(f1), 3))
    print("OTHER FINAL VALUE: \np_score: ", np.around(p_score, 3), "\nd_score: ", np.around(d_score, 3), "\ntstr_score: ", np.around(tstr_score, 3))
    print("OTHER FINAL AVERAGE: \np_score: ", np.around(np.mean(p_score), 3), "\nd_score: ", np.around(np.mean(d_score), 3), "\ntstr_score: ", np.around(np.mean(tstr_score), 3))
    print("OTHER FINAL STD: \np_score: ", np.around(np.std(p_score), 3), "\nd_score: ", np.around(np.std(d_score), 3), "\ntstr_score: ", np.around(np.std(tstr_score), 3))

    return np.mean(acc), np.mean(f1)

if __name__ == '__main__':
    args = args_space.get_args()

    # dataset-specific parameters
    if args.dataset == 'UTD_MHAD_arm':
        args.batch_size = 64
        args.w_mg = 0.9
        args.w_mod = [0.5,0.5]
        preprocess.preprocess_UTD_MHAD(args.window_UM, args.stride_UM, args.data_dir)
    elif args.dataset == 'OPPORTUNITY':
        args.batch_size = 128
        args.w_mg = 0.5
        args.w_mod = [1/3, 1/3, 1/3]
        preprocess.preprocess_OPPORTUNITY(args.window_O, args.stride_O, args.data_dir)
    elif args.dataset == 'UCI_HAR':
        args.batch_size = 128
        args.w_mg = 0.9
        args.w_mod = [0.5,0.5]
        preprocess.preprocess_UCIHAR(args.data_dir)

    # general parameters
    args.N_epochs_GAN = 150
    args.N_epochs_ALL = 100
    args.N_epochs_DA = 30
    args.N_epochs_C = 100
    args.p_drop = 0
    args.lr_G = 0.0007
    args.lr_D = 0.0001
    args.lr_C = 0.001
    args.N_steps_D = 5
    args.latent_dim = 100
    args.w_gp = 10
    args.w_gc = 1.2
    args.N_aug = 1

    args.aug_type = 'multiply'
    args.to_save = True

    args.model_type = 'HMGAN'
    main(args)

--------------------------------------------------------------------------------
/sliding_window.py:
--------------------------------------------------------------------------------
# from http://www.johnvinyard.com/blog/?p=268

import numpy as np
from numpy.lib.stride_tricks import as_strided as ast

def norm_shape(shape):
    '''
    Normalize numpy array shapes so they're always expressed as a tuple,
    even for one-dimensional shapes.

    Parameters
        shape - an int, or a tuple of ints

    Returns
        a shape tuple
    '''
    try:
        i = int(shape)
        return (i,)
    except TypeError:
        # shape was not a number
        pass

    try:
        t = tuple(shape)
        return t
    except TypeError:
        # shape was not iterable
        pass

    raise TypeError('shape must be an int, or a tuple of ints')

def sliding_window(a,ws,ss = None,flatten = True):
    '''
    Return a sliding window over a in any number of dimensions

    Parameters:
        a  - an n-dimensional numpy array
        ws - an int (a is 1D) or tuple (a is 2D or greater) representing the size
             of each dimension of the window
        ss - an int (a is 1D) or tuple (a is 2D or greater) representing the
             amount to slide the window in each dimension. If not specified, it
             defaults to ws.
        flatten - if True, all slices are flattened, otherwise, there is an
             extra dimension for each dimension of the input.

    Returns
        an array containing each n-dimensional window from a
    '''

    if ss is None:
        # ss was not provided. the windows will not overlap in any direction.
        ss = ws
    ws = norm_shape(ws)
    ss = norm_shape(ss)

    # convert ws, ss, and a.shape to numpy arrays so that we can do math in every
    # dimension at once.
    ws = np.array(ws)
    ss = np.array(ss)
    shape = np.array(a.shape)

    # ensure that ws, ss, and a.shape all have the same number of dimensions
    ls = [len(shape),len(ws),len(ss)]
    if 1 != len(set(ls)):
        raise ValueError(
            'a.shape, ws and ss must all have the same length. They were %s' % str(ls))

    # ensure that ws is smaller than a in every dimension
    if np.any(ws > shape):
        raise ValueError(
            'ws cannot be larger than a in any dimension. a.shape was %s and ws was %s' % (str(a.shape),str(ws)))

    # how many slices will there be in each dimension?
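    # A worked example (a sketch, using the UTD-MHAD settings from args_space.py):
    # a.shape = (1500, 6), ws = (100, 6), ss = (50, 1) gives
    # newshape = ((1500-100)//50 + 1, (6-6)//1 + 1) = (29, 1); appending the window
    # shape below yields (29, 1, 100, 6), which flatten then squeezes to (29, 100, 6).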
    newshape = norm_shape(((shape - ws) // ss) + 1)
    # the shape of the strided array will be the number of slices in each dimension
    # plus the shape of the window (tuple addition)
    newshape += norm_shape(ws)
    # the strides tuple will be the array's strides multiplied by step size, plus
    # the array's strides (tuple addition)
    newstrides = norm_shape(np.array(a.strides) * ss) + a.strides
    strided = ast(a,shape = newshape,strides = newstrides)
    if not flatten:
        return strided

    # Collapse strided so that it has one more dimension than the window. i.e.,
    # the new array is a flat list of slices.
    meat = len(ws) if ws.shape else 0
    firstdim = (np.prod(newshape[:-meat]),) if ws.shape else () # np.prod (np.product was removed from NumPy)
    dim = firstdim + (newshape[-meat:])
    # remove any dimensions with size 1
    dim = filter(lambda i : i != 1,dim)
    dim = norm_shape(dim)
    return strided.reshape(dim)

--------------------------------------------------------------------------------
/model_DeepSense.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from itertools import repeat

class ActivityClassifier_DPS(nn.Module):
    def __init__(self, N_modalities, N_classes, N_intervals, len_intervals, p_drop):
        super(ActivityClassifier_DPS, self).__init__()
        self.N_modalities = N_modalities
        self.N_intervals = N_intervals
        self.len_intervals = len_intervals
        K = 2
        C = 64
        self.C = C

        self.mod_conv = nn.ModuleList([nn.Sequential(
            nn.Conv2d(1, C, (1, K*3*2), stride=(1, 2*3), bias=False),
            nn.BatchNorm2d(C),
            nn.ReLU(),
            SpatialDropout(drop=p_drop),
            nn.Conv2d(C, C, (1, K), bias=False),
            nn.BatchNorm2d(C),
            nn.ReLU(),
            SpatialDropout(drop=p_drop),
            nn.Conv2d(C, C, (1, K), bias=False),
            nn.BatchNorm2d(C),
            nn.ReLU(),
            SpatialDropout(drop=p_drop),
        ) for _ in range(N_modalities)])

        self.fuse_conv = nn.Sequential(
            nn.Conv3d(C, C, (1, N_modalities, 2), padding='same', bias=False),
            nn.BatchNorm3d(C),
            nn.ReLU(),
            SpatialDropout(drop=p_drop),
            nn.Conv3d(C, C, (1, N_modalities, 2), padding='same', bias=False),
            nn.BatchNorm3d(C),
            nn.ReLU(),
            SpatialDropout(drop=p_drop),
            nn.Conv3d(C, C, (1, N_modalities, 2), padding='same', bias=False),
            nn.BatchNorm3d(C),
            nn.ReLU(),
            SpatialDropout(drop=p_drop)
        )

        self.GRU = nn.GRU(input_size=C * N_modalities * (len_intervals - 3 * (K - 1)), hidden_size=120, num_layers=2, dropout=p_drop)
        self.dropout = nn.Dropout(p=p_drop)

        self.out = nn.Sequential(
            nn.Linear(120, N_classes)
        )

    def forward(self, x): # a list of [batch_size, channel_nums, seq_len] for each modality
        x = my_fft_torch(x, self.N_intervals, self.len_intervals)
        batch_size = x[0].size()[0]
        N_intervals = x[0].size()[2]
        x = [self.mod_conv[i](x[i]) for i in range(self.N_modalities)]
        feature_len = x[0].size()[3]
        x = [x_mod.reshape(batch_size, self.C, N_intervals, 1, feature_len) for x_mod in x]
        x = torch.cat(x, dim=3)
        x = self.fuse_conv(x)
        x = x.permute(2, 0, 1, 3, 4).reshape(N_intervals, batch_size, -1)
        x, _ = self.GRU(x)
        x = torch.mean(x, dim=0, keepdim=False)
        x = self.out(self.dropout(x))
        return x

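# Shape walkthrough (a sketch, using the UTD-MHAD settings seq_len=100, N_intervals=10,
# len_intervals=10, so each modality enters as [batch_size, 3, 100]): my_fft_torch
# below reshapes each modality to [batch_size, 10, 10, 3], runs an FFT per interval,
# stacks the real and imaginary parts, and returns [batch_size, 1, 10, 60]
# (10 intervals x 10*3*2 features), which mod_conv then convolves along the last axis.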
def my_fft_torch(tensor_list, N_intervals, len_intervals):
    fft_tensor_list = []
    batch_size = tensor_list[0].shape[0]
    for tensor in tensor_list: # [batch_size, num_channels, seq_len]
        fft_tensor = tensor.permute(0,2,1).reshape(batch_size, N_intervals, len_intervals, 3)
        fft_tensor = torch.fft.fft(fft_tensor, dim=2)
        fft_tensor = torch.cat([fft_tensor.real, fft_tensor.imag], 3) # [batch_size, N_intervals, interval_length, 3*2] last dimension: real(xyz), imag(xyz)
        fft_tensor = fft_tensor.reshape(batch_size, N_intervals, -1).unsqueeze(1) # [batch_size, 1, N_intervals, interval_length*3*2] last dim: real(xyz), imag(xyz) at t0, t1, ...

        fft_tensor_list.append(fft_tensor)
    return fft_tensor_list

class SpatialDropout(nn.Module):
    def __init__(self, drop=0.5, noise_shape=None):
        super(SpatialDropout, self).__init__()
        self.drop = drop
        self.noise_shape = noise_shape

    def forward(self, inputs):
        """
        inputs: tensor [batch_size, num_channels, ...]
        noise_shape: [batch_size, ...] same dimension as inputs, dropout along the dimensions of value 1
        """
        outputs = inputs.clone()
        if self.noise_shape is None:
            self.noise_shape = (inputs.shape[0], inputs.shape[1], *repeat(1, inputs.dim()-2)) # default: dropout on channel dimension, along all other dimensions

        if not self.training or self.drop == 0:
            return inputs
        else:
            noises = self._make_noises(inputs)
            if self.drop == 1:
                noises.fill_(0.0)
            else:
                noises.bernoulli_(1 - self.drop).div_(1 - self.drop)
            noises = noises.expand_as(inputs)
            outputs.mul_(noises)
            return outputs

    def _make_noises(self, inputs):
        return inputs.new_empty(self.noise_shape) # new_empty replaces the deprecated new().resize_()
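# A minimal SpatialDropout sketch (not in the original file): with the default
# noise_shape, a Bernoulli mask of shape [batch_size, num_channels, 1, ...] is drawn
# and broadcast over the remaining axes, so a dropped channel is zeroed everywhere:
#
#     drop = SpatialDropout(drop=0.5).train()
#     x = torch.randn(8, 64, 10, 20)
#     y = drop(x) # whole 10x20 feature maps are zeroed per (sample, channel)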

--------------------------------------------------------------------------------
/args_space.py:
--------------------------------------------------------------------------------
import argparse

def get_args():
    parser = argparse.ArgumentParser(description='PyTorch Implementation')
    #parameters w.r.t. datasets
    parser.add_argument('--window_U', type=int, default=128, help='the number of readings in a time window in UCI HAR, this dataset is already partitioned into 2.56-second windows')
    parser.add_argument('--overlap_U', type=float, default=0.5, help='the overlap ratio between time windows in UCI HAR')
    parser.add_argument('--N_classes_U', type=int, default=6, help='the number of activity classes in UCI HAR')
    parser.add_argument('--N_channels_U', type=int, default=6, help='the total number of channels in UCI HAR')
    parser.add_argument('--N_modalities_U', type=int, default=2, help='the number of modalities in total in UCI HAR')
    parser.add_argument('--N_users_U', type=int, default=None, help='the number of users in UCI HAR')
    parser.add_argument('--N_intervals_U', type=int, default=8, help='the number of intervals in each window in UCI HAR')

    parser.add_argument('--window_UM', type=int, default=100, help='the number of readings in a time window in UTD-MHAD, this dataset contains 861 data sequences of around 3 seconds')
    parser.add_argument('--stride_UM', type=int, default=50, help='the number of readings to slide between time windows in UTD-MHAD')
    parser.add_argument('--N_modalities_UM', type=int, default=2, help='the number of sensor modalities in UTD-MHAD')
    parser.add_argument('--N_classes_UM_arm', type=int, default=21, help='the number of activity classes in UTD-MHAD')
    parser.add_argument('--N_channels_UM', type=int, default=6, help='the total number of channels in UTD-MHAD')
    parser.add_argument('--N_users_UM', type=int, default=8, help='the number of users in UTD-MHAD')
    parser.add_argument('--N_intervals_UM', type=int, default=10, help='the number of intervals in a window for UTD-MHAD')

    parser.add_argument('--window_O', type=int, default=60, help='the number of readings in a time window in OPPORTUNITY')
    parser.add_argument('--stride_O', type=int, default=30, help='the number of readings to slide between time windows in OPPORTUNITY')
    parser.add_argument('--N_modalities_O', type=int, default=3, help='the number of sensor modalities in OPPORTUNITY')
    parser.add_argument('--N_classes_O', type=int, default=17, help='the number of activity classes in OPPORTUNITY')
    parser.add_argument('--N_channels_O', type=int, default=9, help='the total number of channels in OPPORTUNITY')
    parser.add_argument('--N_users_O', type=int, default=4, help='the number of users in OPPORTUNITY')
    parser.add_argument('--N_intervals_O', type=int, default=10, help='the number of intervals in a window for OPPORTUNITY')

    #parameters w.r.t. general model settings
    parser.add_argument('--N_aug', type=int, default=1, help='the number of generated copies added per original sample (the ratio of generated to original data)')
    parser.add_argument('--lr_G', type=float, default=1e-4, help='learning rate for Generator')
    parser.add_argument('--lr_D', type=float, default=1e-4, help='learning rate for Discriminator')
    parser.add_argument('--lr_C', type=float, default=1e-3, help='learning rate for Classifier')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument('--N_epochs_GAN', type=int, default=100, help='the number of epochs for stage 1')
    parser.add_argument('--N_epochs_ALL', type=int, default=200, help='the number of epochs for stage 2')
    parser.add_argument('--N_epochs_C', type=int, default=100, help='the number of epochs for classifier training')
    parser.add_argument('--N_epochs_DA', type=int, default=0, help='the number of epochs to start using generated data for augmentation')
    parser.add_argument('--batch_size', type=int, default=128, help='mini-batch size')

    #parameters w.r.t. model structures and training for HMGAN
    parser.add_argument('--latent_dim', type=int, default=100)
    parser.add_argument('--N_channels_per_mod', type=int, default=3, help='the number of channels for each modality')
    parser.add_argument('--p_drop', type=float, default=0.05, help='the probability of dropping out')
    parser.add_argument('--weight_decay', type=float, default=0, help='the coefficient of weight decay (L2 penalty)')
    parser.add_argument('--aug_type', type=str, default='', help='how to augment training data')
    parser.add_argument('--w_mg', type=float, default=0.3)
    parser.add_argument('--w_mod', type=float, nargs='+', default=[0.5,0.5]) # type=list would split a CLI string into characters
    parser.add_argument('--w_gc', type=float, default=1)
    parser.add_argument('--w_gp', type=float, default=10)
    parser.add_argument('--N_steps_D', type=int, default=5)

    #parameters w.r.t. model structures and training for evaluation metrics
    parser.add_argument('--lr_GAN', type=float, default=1e-4)
    parser.add_argument('--lr_pred', type=float, default=1e-3)
    parser.add_argument('--N_epochs_pred', type=int, default=200)
    parser.add_argument('--N_epochs_disc', type=int, default=100)

    #parameters w.r.t. experiment setups
    parser.add_argument('--dataset', type=str)
    parser.add_argument('--model_type', type=str, help='the model name')
    parser.add_argument('--N_folds', type=int, default=5, help='the number of folds')
    parser.add_argument('--test_fold', type=int, default=0, help='which fold to test on')
    parser.add_argument('--cuda', type=int, default=-1, help='the cuda device to run on')
    parser.add_argument('--to_save', type=bool, default=False, help='whether to save the model')
    parser.add_argument('--data_dir', type=str)

    args = parser.parse_args()

    return args
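# Note (a hedged sketch, not in the original file): the length of --w_mod must match
# the number of modalities of the chosen dataset; main.py sets this per dataset, e.g.
#
#     args.w_mod = [0.5, 0.5]        # UTD_MHAD_arm / UCI_HAR (2 modalities)
#     args.w_mod = [1/3, 1/3, 1/3]   # OPPORTUNITY (3 modalities)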

--------------------------------------------------------------------------------
/model_HMGAN.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

class G_conv(nn.Module):
    def __init__(self, in_dim, seq_len, N_modalities, N_channels_per_mod):
        super().__init__()

        self.N_modalities = N_modalities
        if seq_len == 60:
            self.start = 22
        elif seq_len == 100:
            self.start = 32
        elif seq_len == 128:
            self.start = 39

        self.shared_fc = nn.Sequential(
            nn.Linear(in_dim, self.start * 1 * 32, bias=False),
            nn.BatchNorm1d(self.start * 1 * 32, momentum=0.05, affine=True),
            nn.LeakyReLU(0.2),
        )

        self.shared_conv = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=(3, 1), stride=(1, 1), bias=False),
            nn.BatchNorm2d(32, momentum=0.05, affine=True),
            nn.LeakyReLU(0.2),
            nn.Conv2d(32, 32, kernel_size=(3, 1), stride=(1, 1), bias=False),
            nn.BatchNorm2d(32, momentum=0.05, affine=True),
            nn.LeakyReLU(0.2),
            nn.UpsamplingNearest2d(scale_factor=(2, 1)),
        )

        self.mod_conv = nn.ModuleList([nn.Sequential(
            nn.Conv2d(32, 16, kernel_size=(3, 1), stride=(1, 1), bias=False),
            nn.BatchNorm2d(16, momentum=0.05, affine=True),
            nn.LeakyReLU(0.2),
            nn.Conv2d(16, 16, kernel_size=(3, 1), stride=(1, 1), bias=False),
            nn.BatchNorm2d(16, momentum=0.05, affine=True),
            nn.LeakyReLU(0.2),
            nn.UpsamplingNearest2d(scale_factor=(2, 1)),

            nn.Conv2d(16, 8, kernel_size=(3, 1), stride=(1, 1), bias=False),
            nn.BatchNorm2d(8, momentum=0.05, affine=True),
            nn.LeakyReLU(0.2),
            nn.Conv2d(8, N_channels_per_mod, kernel_size=(3, 1), stride=(1, 1), bias=False),
            nn.Tanh(),
        ) for _ in range(self.N_modalities)])

        self.apply(self.init_weights)

    def init_weights(self, module):
        if isinstance(module, nn.Conv1d) or isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, z, y):

        z = torch.cat((z, y), dim=1)

        g = self.shared_fc(z)
        g = g.view(-1, 32, self.start, 1)
        g = self.shared_conv(g)

        g = [self.mod_conv[i](g) for i in range(self.N_modalities)]
        g = [g_mod.squeeze(-1) for g_mod in g]
        return g

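# Shape sketch for G_conv (not in the original file; using UTD_MHAD_arm, where
# seq_len=100 and self.start=32): z||y [B, 100+21] -> shared_fc -> [B, 32, 32, 1]
# -> shared_conv (two 3x1 convs, then x2 upsampling) -> [B, 32, 56, 1] -> each
# per-modality branch -> [B, 3, 100, 1] -> squeeze -> a list of [B, 3, 100] tensors,
# one per modality, in the tanh range [-1, 1].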
class D_conv(nn.Module):
    def __init__(self, in_dim, seq_len, N_modalities, N_channels_per_mod):
        super().__init__()
        self.N_modalities = N_modalities
        self.N_channels_per_mod = N_channels_per_mod
        self.kernel_sizes = [11, 11, 7, 7, 5, 5, 3]
        self.padding = [5, 5, 3, 3, 2, 2, 1]
        self.strides = [1, 2, 1, 2, 1, 2, 1]
        self.kernel_num = [32, 32, 64, 64, 128, 128, 128]
        if seq_len == 100:
            feat_dim1 = 13
            feat_dim2 = 4
        elif seq_len == 60:
            feat_dim1 = 8
            feat_dim2 = 2
        elif seq_len == 128:
            feat_dim1 = 16
            feat_dim2 = 4

        self.mod_conv = nn.ModuleList([nn.Sequential(
            nn.Conv1d(in_dim, self.kernel_num[0], self.kernel_sizes[0], self.strides[0], self.padding[0], bias=False),
            nn.LeakyReLU(0.2),
        ) for _ in range(self.N_modalities)])

        for m in range(self.N_modalities):
            for i in range(1, len(self.kernel_sizes)):
                self.mod_conv[m].add_module(str(len(self.mod_conv[m])),
                    nn.Conv1d(self.kernel_num[i - 1], self.kernel_num[i], self.kernel_sizes[i], self.strides[i], self.padding[i], bias=False))
                self.mod_conv[m].add_module(str(len(self.mod_conv[m])),
                    nn.LeakyReLU(0.2))

        self.mod_out = nn.ModuleList([nn.Sequential(
            nn.Linear(feat_dim1 * self.kernel_num[-1], 1024, bias=False),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 1, bias=False)
        ) for _ in range(self.N_modalities)])

        self.shared_conv = nn.Sequential(
            nn.Conv1d(self.N_modalities * self.kernel_num[-1], 32, kernel_size=3, stride=1, padding=1, bias=False),
            nn.LeakyReLU(0.2),
            nn.Conv1d(32, 32, kernel_size=3, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2),
            nn.Conv1d(32, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.LeakyReLU(0.2),
            nn.Conv1d(64, 64, kernel_size=3, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2),
        )

        self.shared_out = nn.Sequential(
            nn.Linear(feat_dim2 * 64, 1024, bias=False),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 1, bias=False)
        )

        self.apply(self.init_weights)

    def init_weights(self, module):
        if isinstance(module, nn.Conv1d) or isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight) # xavier_uniform_ (the non-underscore form is deprecated)

    def forward(self, x, y): # a list of [batch_size, channels, time_steps] for each modality
        mod_x = [None for _ in range(self.N_modalities)]
        mod_prob = [None for _ in range(self.N_modalities)]
        for i in range(self.N_modalities):
            mod_x[i] = label_concat(x[i], y)
            mod_x[i] = self.mod_conv[i](mod_x[i])
            mod_prob[i] = torch.flatten(mod_x[i], start_dim=1)
            mod_prob[i] = self.mod_out[i](mod_prob[i])

        glb_x = torch.cat(mod_x, dim=1)
        glb_x = self.shared_conv(glb_x)
        glb_x = torch.flatten(glb_x, start_dim=1)
        glb_prob = self.shared_out(glb_x)

        return mod_prob, glb_prob

def label_concat(x, y): # x [batch_size, channels, time_steps] y [batch_size, num_classes] onehot
    x_shape = list(x.shape)
    label_shape = list(y.shape)
    y = y.view(label_shape[0], label_shape[1], 1)
    label_shape = list(y.shape)
    y = y * torch.ones(label_shape[0], label_shape[1], x_shape[2]).cuda()
    x = torch.cat((x, y), 1)
    return x
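# A small worked example of label_concat (a sketch, not in the original file):
# broadcasting the one-hot label over time turns it into extra input channels, e.g.
#
#     x = torch.randn(4, 3, 100).cuda()                       # one modality: 3 channels
#     y = torch.eye(21)[torch.tensor([0, 5, 5, 7])].cuda()    # one-hot labels
#     label_concat(x, y).shape                                # torch.Size([4, 24, 100])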

--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch import optim
from torch.nn import init
from torch.autograd import Variable
from model_DeepSense import ActivityClassifier_DPS
import torchmetrics
import numpy as np

def init_weights(module):
    if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear) or isinstance(module, nn.Conv3d):
        init.xavier_uniform_(module.weight, gain=1)

class Predictor(nn.Module):
    def __init__(self, N_channels):
        super(Predictor, self).__init__()

        self.GRU = nn.GRU(input_size=N_channels,
                          hidden_size=120,
                          num_layers=2,
                          batch_first=True)

        self.out = nn.Sequential(
            nn.Linear(120, N_channels),
            nn.Sigmoid()
        )

        self.apply(init_weights)

    def forward(self, x): # [batch_size, seq_len, channel_nums]
        x, _ = self.GRU(x)
        x = self.out(x)
        return x

def get_predictive_score(args, real_loader, gen_loader, N_channels):
    """
    Args:
        - real_loader: original data, batches of [batch_size, seq_len, channel_nums]
        - gen_loader: generated synthetic data, batches of [batch_size, seq_len, channel_nums]
    Returns:
        - predictive_score: MAE of one-step-ahead predictions on the original data, for a predictor trained on the synthetic data (lower is better)
    """

    predictor = Predictor(N_channels)
    predictor.cuda()
    opt_p = optim.Adam(predictor.parameters(), lr=args.lr_pred)
    predictor.train()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    for _ in range(args.N_epochs_pred):
        for _, (x, _) in enumerate(gen_loader):
            x = x.cuda() # [batch_size, seq_len, channel_nums]
            x1 = x[:, :-1, :]
            x2 = x[:, 1:, :]

            opt_p.zero_grad()

            x_pred = predictor(x1)
            loss = nn.L1Loss()(x_pred, x2)
            loss.backward()
            opt_p.step()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    predictor.eval()
    predictive_score = 0
    for _, (x, _) in enumerate(real_loader):
        x = x.cuda() # [batch_size, seq_len, channel_nums]
        x1 = x[:, :-1, :]
        x2 = x[:, 1:, :]

        x_pred = predictor(x1)
        loss = nn.L1Loss()(x_pred, x2).item()
        predictive_score += loss

    predictive_score /= len(real_loader)

    return predictive_score

class Discriminator(nn.Module):
    def __init__(self, N_channels):
        super(Discriminator, self).__init__()

        N_layers = 2
        self.GRU = nn.GRU(input_size=N_channels,
                          hidden_size=120,
                          num_layers=N_layers,
                          batch_first=True)

        self.out = nn.Sequential(
            nn.Linear(120*N_layers, 1),
            nn.Sigmoid()
        )

        self.apply(init_weights)

    def forward(self, x): # [batch_size, seq_len, channel_nums]
        _, x = self.GRU(x)
        x = torch.flatten(x.permute(1,0,2), start_dim=1)
        x = self.out(x)
        return x.squeeze()

def get_discriminative_score(args, train_d_loader, test_d_loader, N_channels):
    """
    Args:
        - train_d_loader: mixed real/generated windows with real-vs-fake labels, for training the post-hoc discriminator
        - test_d_loader: held-out mixed real/generated windows for testing it
    Returns:
        - discriminative_score: |0.5 - accuracy| of the post-hoc discriminator (0 means real and generated data are indistinguishable)
    """

    D = Discriminator(N_channels)
    D.cuda()
    opt_d = optim.Adam(D.parameters(), lr=args.lr_GAN)
    D.train()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    for _ in range(args.N_epochs_disc):
        for _, (x, y_d) in enumerate(train_d_loader):
            x = x.cuda()
            y_d = y_d.cuda()

            opt_d.zero_grad()

            probs_d = D(x)
            D_loss = torch.nn.BCELoss()(probs_d, y_d) # D already ends in a Sigmoid, so BCEWithLogitsLoss would apply a second sigmoid

            D_loss.backward()
            opt_d.step()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    test_d_acc = torchmetrics.Accuracy().cuda()
    D.eval()
    for _, (x, y_d) in enumerate(test_d_loader):
        x = x.cuda()
        y_d = y_d.cuda()

        probs_d = D(x)
        test_d_acc(probs_d, y_d.long())

    disc_acc = test_d_acc.compute().item()
    discriminative_score = np.abs(0.5 - disc_acc)

    return discriminative_score, disc_acc

def get_TSTR_score(args, real_loader, gen_loader, N_modalities, N_channels_per_mod, N_classes, N_intervals, len_intervals, CM=False):
    """
    Args:
        - real_loader: original data, batches of [batch_size, seq_len, channel_nums]
        - gen_loader: generated synthetic data, batches of [batch_size, seq_len, channel_nums]
    Returns:
        - TSTR score: classification accuracy on the original data of a classifier trained only on the synthetic data (train-on-synthetic, test-on-real)
    """

    C = ActivityClassifier_DPS(N_modalities, N_classes, N_intervals, len_intervals, 0)
    C.cuda()
    opt_c = optim.Adam(C.parameters(), lr=args.lr_C)
    C.train()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    for _ in range(args.N_epochs_C):
        for _, (x, y) in enumerate(gen_loader):
            x = Variable(x.cuda())
            x = x.permute(0, 2, 1)
            x = torch.split(x, N_channels_per_mod, dim=1)
            y = y.long().cuda()

            opt_c.zero_grad()

            logits_c = C(x)
            loss = nn.CrossEntropyLoss()(logits_c, y)
            loss.backward()
            opt_c.step()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    C.eval()
    test_c_acc = torchmetrics.Accuracy().cuda()
    if CM:
        all_y_true = np.empty([0], dtype=int)
        all_y_pred = np.empty([0], dtype=int)
    for _, (x, y) in enumerate(real_loader):
        x = Variable(x.cuda())
        x = x.permute(0, 2, 1)
        x = torch.split(x, N_channels_per_mod, dim=1)
        y = y.long().cuda()

        logits_c = C(x)
        test_c_acc(logits_c, y)
        if CM:
            y_pred = logits_c.data.max(1)[1]
            all_y_pred = np.concatenate((all_y_pred, y_pred.cpu().numpy()), axis=0)
            all_y_true = np.concatenate((all_y_true, y.cpu().numpy()), axis=0)

    if CM:
        return test_c_acc.compute().item(), all_y_true, all_y_pred
    else:
        return test_c_acc.compute().item()

--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
import os
import shutil
import args_space
import numpy as np
from scipy import stats
from pandas import Series
from sliding_window import sliding_window
from sklearn.model_selection import KFold
import pandas as pd
import scipy.io

def preprocess_UTD_MHAD(window, stride, data_dir, K=5):
    dataset_path = data_dir + 'UTD_MHAD/'
    N_channels = 6

    upper_bound_arm = np.array([3.652832, 7.725342, 6.398193, 1000.519084, 606.778626, 1000.519084]) # max
    lower_bound_arm = np.array([-8.0, -8.0, -8.0, -1000.549618, -1000.549618, -741.557252]) # min

    acts_arm = range(1,22)

    if os.path.exists(dataset_path + 'arm/processed_data/'):
        shutil.rmtree(dataset_path + 'arm/processed_data/')
    os.makedirs(dataset_path + 'arm/processed_data/') # makedirs, so the arm/ parent directory is created on the first run

    acts = acts_arm

    upper_bound = upper_bound_arm
    lower_bound = lower_bound_arm

    x_all = np.empty([0, window, N_channels], dtype=float)
    y_all = np.empty([0], dtype=int)
    for user in range(1,9):
        print( "process arm activity data... user{}".format(user-1) )
user{}".format(user-1)) 34 | time_windows = np.empty([0, window, N_channels], dtype=np.float) 35 | act_labels = np.empty([0], dtype=np.int) 36 | 37 | for act in acts: 38 | for trial in range(1,5): 39 | file = dataset_path + 'Inertial/a{}_s{}_t{}_inertial.mat'.format(act, user, trial) 40 | 41 | if not os.path.exists(file): 42 | continue 43 | 44 | data = scipy.io.loadmat(file)['d_iner'] # [?, 6] around 150 time steps 45 | 46 | # normalization 47 | diff = upper_bound - lower_bound 48 | data = 2 * (data - lower_bound) / diff - 1 49 | 50 | data[ data > 1 ] = 1.0 51 | data[ data < -1 ] = -1.0 52 | 53 | #sliding window 54 | data = sliding_window(data, (window, N_channels), (stride, 1)) 55 | if len(data.shape) == 2: 56 | data = data.reshape(1,data.shape[0],data.shape[1]) 57 | 58 | act_min = 1 59 | label = np.ones(len(data)) * (act-act_min) 60 | 61 | time_windows = np.concatenate((time_windows, data), axis=0) 62 | act_labels = np.concatenate((act_labels, label), axis=0) 63 | 64 | x_all = np.concatenate((x_all, time_windows), axis=0) 65 | y_all = np.concatenate((y_all, act_labels), axis=0) 66 | 67 | np.save(dataset_path + 'arm/processed_data/features', x_all) 68 | np.save(dataset_path + 'arm/processed_data/labels', y_all) 69 | # save the K fold idx 70 | kf = KFold(n_splits=K, shuffle=True, random_state=0) 71 | for i, (train_index, test_index) in enumerate(kf.split(x_all)): 72 | np.save(dataset_path + 'arm/processed_data/' + 'fold{}_train_idx'.format(i), train_index) 73 | np.save(dataset_path + 'arm/processed_data/' + 'fold{}_test_idx'.format(i), test_index) 74 | 75 | def preprocess_OPPORTUNITY(window, overlap, data_dir, K=5): 76 | dataset_path = data_dir + 'OPPORTUNITY/' 77 | N_channels = 9 78 | 79 | file_list = [ ['S1-Drill.dat', 80 | 'S1-ADL1.dat', 81 | 'S1-ADL2.dat', 82 | 'S1-ADL3.dat', 83 | 'S1-ADL4.dat', 84 | 'S1-ADL5.dat'] , 85 | ['S2-Drill.dat', 86 | 'S2-ADL1.dat', 87 | 'S2-ADL2.dat', 88 | 'S2-ADL3.dat', 89 | 'S2-ADL4.dat', 90 | 'S2-ADL5.dat'] , 91 | ['S3-Drill.dat', 92 | 'S3-ADL1.dat', 93 | 'S3-ADL2.dat', 94 | 'S3-ADL3.dat', 95 | 'S3-ADL4.dat', 96 | 'S3-ADL5.dat'] , 97 | ['S4-Drill.dat', 98 | 'S4-ADL1.dat', 99 | 'S4-ADL2.dat', 100 | 'S4-ADL3.dat', 101 | 'S4-ADL4.dat', 102 | 'S4-ADL5.dat'] ] 103 | 104 | upper_bound = np.array([498.0, 1809.0, 1723.1842000000179, 6794.719200000167, 5843.026200000197, 4011.30700000003, 1678.122800000012, 1225.0, 1446.061400000006])# 0.9999 quantile 105 | lower_bound = np.array([-1435.0, -832.0, -617.0, -2939.0, -1795.0, -2158.0, -660.0, -1096.0, -928.0])# 0.005 quantile 106 | 107 | if os.path.exists( dataset_path + 'processed_data/' ): 108 | shutil.rmtree( dataset_path + 'processed_data/' ) 109 | os.mkdir( dataset_path + 'processed_data/' ) 110 | 111 | time_windows_all = [] 112 | act_labels_all = [] 113 | for usr_idx in range( 4 ): 114 | 115 | print( "process data... 
user{}".format( usr_idx ) ) 116 | time_windows = np.empty( [0, window, N_channels], dtype=np.float ) 117 | act_labels = np.empty( [0], dtype=np.int ) 118 | 119 | for file_idx in range( len(file_list[0]) ): 120 | 121 | filename = file_list[ usr_idx ][ file_idx ] 122 | 123 | file = dataset_path + filename 124 | signals = pd.read_csv(file, delimiter=' ', header=None) 125 | 126 | signals = signals.loc[:, [50, 51, 52, 53, 54, 55, 56, 57, 58, 249]] # RUA acc xyz gyro xyz mag xyz 127 | signals.dropna(inplace=True) 128 | 129 | data = signals.values[:,:9] 130 | label = signals.values[:,-1].astype( np.int ) 131 | 132 | label[ label == 0 ] = -1 133 | 134 | # ML_Both_Arms 135 | label[ label == 406516 ] = 0 # Open Door 1 136 | label[ label == 406517 ] = 1 # Open Door 2 137 | label[ label == 404516 ] = 2 # Close Door 1 138 | label[ label == 404517 ] = 3 # Close Door 2 139 | label[ label == 406520 ] = 4 # Open Fridge 140 | label[ label == 404520 ] = 5 # Close Fridge 141 | label[ label == 406505 ] = 6 # Open Dishwasher 142 | label[ label == 404505 ] = 7 # Close Dishwasher 143 | label[ label == 406519 ] = 8 # Open Drawer 1 144 | label[ label == 404519 ] = 9 # Close Drawer 1 145 | label[ label == 406511 ] = 10 # Open Drawer 2 146 | label[ label == 404511 ] = 11 # Close Drawer 2 147 | label[ label == 406508 ] = 12 # Open Drawer 3 148 | label[ label == 404508 ] = 13 # Close Drawer 3 149 | label[ label == 408512 ] = 14 # Clean Table 150 | label[ label == 407521 ] = 15 # Drink from Cup 151 | label[ label == 405506 ] = 16 # Toggle Switch 152 | 153 | # fill missing values using Linear Interpolation 154 | data = np.array( [Series(i).interpolate(method='linear') for i in data.T] ).T 155 | data[ np.isnan( data ) ] = 0. 156 | 157 | # normalization 158 | diff = upper_bound - lower_bound 159 | data = ( data - lower_bound ) / diff 160 | 161 | data[ data > 1 ] = 1.0 162 | data[ data < 0 ] = 0.0 163 | 164 | #sliding window 165 | data = sliding_window( data, (window, N_channels), (overlap, 1) ) 166 | label = sliding_window( label, window, overlap ) 167 | label = stats.mode( label, axis=1 )[0][:,0] 168 | 169 | #remove non-interested time windows (label==-1) 170 | invalid_idx = np.nonzero( label < 0 )[0] 171 | data = np.delete( data, invalid_idx, axis=0 ) 172 | label = np.delete( label, invalid_idx, axis=0 ) 173 | 174 | time_windows = np.concatenate( (time_windows, data), axis=0 ) 175 | act_labels = np.concatenate( (act_labels, label), axis=0 ) 176 | 177 | time_windows_all.append(time_windows) 178 | act_labels_all.append(act_labels) 179 | 180 | time_windows_all = np.concatenate(time_windows_all, axis=0) 181 | act_labels_all = np.concatenate(act_labels_all, axis=0) 182 | 183 | np.save(dataset_path + '/processed_data/features', time_windows_all) 184 | np.save(dataset_path + '/processed_data/labels', act_labels_all) 185 | # save the K fold idx 186 | kf = KFold(n_splits=K, shuffle=True, random_state=0) 187 | for i, (train_index, test_index) in enumerate(kf.split(time_windows_all)): 188 | np.save(dataset_path + '/processed_data/' + 'fold{}_train_idx'.format(i), train_index) 189 | np.save(dataset_path + '/processed_data/' + 'fold{}_test_idx'.format(i), test_index) 190 | 191 | def load_data_UCIHAR(dataset_path, file_list, type): 192 | x_data_list = [] 193 | for item in file_list: 194 | item_data = np.array(pd.read_csv(dataset_path + type + '/Inertial Signals/' + item + type + '.txt', delim_whitespace=True, header=None)) 195 | x_data_list.append(item_data) 196 | x = np.stack(x_data_list, -1) 197 | 198 | y = 
    return x, y

def preprocess_UCIHAR(data_dir, K=5):
    dataset_path = data_dir + 'UCI_HAR/'

    # get the data from txt files into a pandas dataframe
    file_list = ['body_acc_x_', 'body_acc_y_', 'body_acc_z_', 'body_gyro_x_', 'body_gyro_y_', 'body_gyro_z_']
    x_train, y_train = load_data_UCIHAR(dataset_path, file_list, 'train')
    x_test, y_test = load_data_UCIHAR(dataset_path, file_list, 'test')

    x = np.concatenate([x_train, x_test], axis=0)
    y = np.concatenate([y_train, y_test], axis=0)

    # the data are already preprocessed and filtered, with no missing values (np.isnan(x).sum() == 0), so we can use the min/max values here
    lower_bound = np.array([-0.7270811, -0.8285408496200001, -0.72422586782, -2.5482600237, -2.3043757869, -1.698266])
    upper_bound = np.array([1.072854472100006, 0.620366, 0.6387655, 2.643864, 3.4056461708005163, 1.60952])
    diff = upper_bound - lower_bound

    if os.path.exists( dataset_path + 'processed_data/' ):
        shutil.rmtree( dataset_path + 'processed_data/' )
    os.mkdir( dataset_path + 'processed_data/' )

    x = 2 * (x - lower_bound) / diff - 1 # the last dimension must remain the channel dimension for the broadcasted subtraction

    x[ x > 1 ] = 1.0
    x[ x < -1 ] = -1.0

    y = y - 1 # pytorch requires labels to lie within [0, C)

    np.save(dataset_path + '/processed_data/features', x)
    np.save(dataset_path + '/processed_data/labels', y)
    # save the K fold idx
    kf = KFold(n_splits=K, shuffle=True, random_state=0)
    for i, (train_index, test_index) in enumerate(kf.split(x)):
        np.save(dataset_path + '/processed_data/' + 'fold{}_train_idx'.format(i), train_index)
        np.save(dataset_path + '/processed_data/' + 'fold{}_test_idx'.format(i), test_index)

--------------------------------------------------------------------------------
/solver_HMGAN.py:
--------------------------------------------------------------------------------
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import autograd
from torch.autograd import Variable
from torch.utils.data import DataLoader, TensorDataset

from model_DeepSense import ActivityClassifier_DPS
from model_HMGAN import G_conv, D_conv
from get_data import get_data
import torchmetrics
from metrics import get_predictive_score
from metrics import get_discriminative_score
from metrics import get_TSTR_score
from sklearn.model_selection import train_test_split

class DASolver_HMGAN(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.seed = args.seed
        self.batch_size = args.batch_size
        self.N_epochs_GAN = args.N_epochs_GAN
        self.N_epochs_ALL = args.N_epochs_ALL
        self.N_steps_D = args.N_steps_D
        self.N_epochs_C = args.N_epochs_C
        self.N_epochs_DA = args.N_epochs_DA
        self.lr_G = args.lr_G
        self.lr_D = args.lr_D
        self.lr_C = args.lr_C

        self.latent_dim = args.latent_dim
        self.w_mg = args.w_mg
        self.w_gp = args.w_gp
        self.w_mod = args.w_mod
        self.w_gc = args.w_gc

        self.train_loader, self.valid_loader, self.test_loader = get_data(args)

        self.to_save = args.to_save
        self.dataset = args.dataset
        self.tag = args.dataset + '_' + args.model_type + '_fold' + str(args.test_fold)

        self.model_path = args.data_dir + 'checkpoints/' + self.tag
        if not os.path.exists(self.model_path):
            os.makedirs(self.model_path) # makedirs, so the checkpoints/ parent directory is created on the first run

        self.N_channels_per_mod = args.N_channels_per_mod
        if 'UTD_MHAD' in args.dataset:
            self.N_modalities = args.N_modalities_UM
            self.N_channels = args.N_channels_UM
            self.N_classes = args.N_classes_UM_arm
            self.seq_len = args.window_UM
            self.N_intervals = args.N_intervals_UM
            self.len_intervals = int(args.window_UM / args.N_intervals_UM)
        elif args.dataset == 'OPPORTUNITY':
            self.N_modalities = args.N_modalities_O
            self.N_channels = args.N_channels_O
            self.N_classes = args.N_classes_O
            self.seq_len = args.window_O
            self.N_intervals = args.N_intervals_O
            self.len_intervals = int(args.window_O / args.N_intervals_O)
        elif args.dataset == 'UCI_HAR':
            self.N_modalities = args.N_modalities_U
            self.N_channels = args.N_channels_U
            self.N_classes = args.N_classes_U
            self.seq_len = args.window_U
            self.N_intervals = args.N_intervals_U
            self.len_intervals = int(args.window_U / args.N_intervals_U)

        self.G = G_conv(self.latent_dim+self.N_classes, self.seq_len, self.N_modalities, args.N_channels_per_mod)
        self.D = D_conv(args.N_channels_per_mod+self.N_classes, self.seq_len, self.N_modalities, self.N_channels_per_mod)
        self.C = ActivityClassifier_DPS(self.N_modalities, self.N_classes, self.N_intervals, self.len_intervals, args.p_drop)

        self.G.cuda()
        self.D.cuda()
        self.C.cuda()

        self.opt_g = optim.Adam(self.G.parameters(), lr=args.lr_G, betas=(0.5, 0.999))
        self.opt_d = optim.Adam(self.D.parameters(), lr=args.lr_D, betas=(0.5, 0.999))
        self.opt_gc = optim.Adam(self.G.parameters(), lr=args.lr_G, betas=(0.5, 0.999))
        self.opt_c = optim.Adam(self.C.parameters(), lr=args.lr_C)

    def reset_grad(self):
        self.opt_gc.zero_grad()
        self.opt_g.zero_grad()
        self.opt_d.zero_grad()
        self.opt_c.zero_grad()

    def sample_z(self):
        z = Variable(torch.randn((self.batch_size, self.latent_dim), dtype=torch.float32).cuda())
        return z

    def get_D_loss(self, logits_d_mod_r, logits_d_glb_r, logits_d_mod_g, logits_d_glb_g, x_r, x_g, y_inter):
        eps = torch.zeros(self.args.batch_size, 1, 1).uniform_().cuda()
        x_inter = [eps * x_r[i] + (1 - eps) * x_g[i] for i in range(self.N_modalities)]
        logits_d_mod_inter, logits_d_glb_inter = self.D(x_inter, y_inter)
        d_loss_mod = [self.modal_D_loss(logits_d_mod_r[i], logits_d_mod_g[i], x_inter[i], logits_d_mod_inter[i]) for i in range(self.N_modalities)]
        d_loss_glb = self.global_D_loss(logits_d_glb_r, logits_d_glb_g, x_inter, logits_d_glb_inter)
        d_loss_mod_sum = sum([d_loss_mod[i] * self.w_mod[i] for i in range(self.N_modalities)])
        d_loss = d_loss_glb * self.w_mg + d_loss_mod_sum * (1 - self.w_mg)
        return d_loss

    def get_G_loss(self, logits_d_mod_g, logits_d_glb_g):
        g_loss_mod = [self.single_G_loss(logits_d_mod_g[i]) for i in range(self.N_modalities)]
        g_loss_glb = self.single_G_loss(logits_d_glb_g)
        g_loss_mod_sum = sum([g_loss_mod[i] * self.w_mod[i] for i in range(self.N_modalities)])
        g_loss = g_loss_glb * self.w_mg + g_loss_mod_sum * (1 - self.w_mg)
        return g_loss

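    # A worked example of the hierarchical weighting (a sketch, not in the original
    # file; with the UTD_MHAD_arm settings w_mg=0.9 and w_mod=[0.5, 0.5]):
    #
    #     g_loss = 0.9 * g_loss_glb + (1 - 0.9) * (0.5 * g_loss_mod[0] + 0.5 * g_loss_mod[1])
    #
    # i.e. the global multi-modal critic dominates and the two per-modality critics
    # share the remaining 10% evenly; get_D_loss combines its terms the same way.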
    def modal_D_loss(self, logits_d_r, logits_d_g, x_inter, logits_inter):
        grads = autograd.grad(outputs=logits_inter, inputs=x_inter,
                              grad_outputs=torch.ones_like(logits_inter),
                              create_graph=True, retain_graph=True,
                              only_inputs=True)[0]
        grad_pen = torch.pow(grads.norm(2, dim=1) - 1, 2).mean()

        d_loss = -logits_d_r.mean() + logits_d_g.mean() + self.w_gp * grad_pen
        return d_loss

    def global_D_loss(self, logits_d_r, logits_d_g, x_inter, logits_inter):
        grads = [autograd.grad(outputs=logits_inter, inputs=x_inter[i],
                               grad_outputs=torch.ones_like(logits_inter),
                               create_graph=True, retain_graph=True,
                               only_inputs=True)[0] for i in range(self.N_modalities)]
        grads = torch.cat(grads, dim=1)
        grad_pen = torch.pow(grads.norm(2, dim=1) - 1, 2).mean()

        d_loss = -logits_d_r.mean() + logits_d_g.mean() + self.w_gp * grad_pen
        return d_loss

    def single_G_loss(self, logits_d_g):
        g_loss = -logits_d_g.mean()
        return g_loss

    def forward_pass(self, x_r, y_r, type):
        z_g = self.sample_z()
        x_g = self.G(z_g, y_r)
        if type == 'get_x_g':
            return x_g

        if type != 'train_C':
            logits_d_mod_g, logits_d_glb_g = self.D(x_g, y_r)
            if 'train_G' not in type:
                logits_d_mod_r, logits_d_glb_r = self.D(x_r, y_r)

        if type != 'train_D':
            if type != 'train_G':
                logits_c_g = self.C(x_g)
            if 'train_G' not in type:
                logits_c_r = self.C(x_r)

        if type == 'train_D':
            return logits_d_mod_r, logits_d_glb_r, logits_d_mod_g, logits_d_glb_g, x_g
        elif type == 'train_C':
            return logits_c_r, logits_c_g
        elif type == 'train_GC':
            return logits_d_mod_g, logits_d_glb_g, logits_c_g
        elif type == 'train_G':
            return logits_d_mod_g, logits_d_glb_g

    def train(self):
        self.train_GAN()
        self.train_all()
        test_acc, test_f1 = self.train_C(training=False)
        return test_acc, test_f1

    def train_GAN(self):
        print('\n>>> Start Training GAN...')

        # losses
        Loss_g = torchmetrics.MeanMetric().cuda()
        Loss_d = torchmetrics.MeanMetric().cuda()

        torch.manual_seed(self.seed)
        torch.cuda.manual_seed(self.seed)

        for epoch in range(self.N_epochs_GAN):

            self.G.train()
            self.D.train()

            for batch_idx, (x_r, y_r) in enumerate(self.train_loader):
                x_r = Variable(x_r.cuda())
                x_r = x_r.permute(0, 2, 1)
                x_r = torch.split(x_r, self.N_channels_per_mod, dim=1)
                y_r = F.one_hot(y_r.long(), num_classes=self.N_classes)
                y_r = Variable(y_r.float().cuda())

                self.reset_grad()

                ''' train discriminator '''
                for _ in range(self.N_steps_D):
                    logits_d_mod_r, logits_d_glb_r, logits_d_mod_g, logits_d_glb_g, x_g = self.forward_pass(x_r, y_r, 'train_D')
                    D_loss = self.get_D_loss(logits_d_mod_r, logits_d_glb_r, logits_d_mod_g, logits_d_glb_g, x_r, x_g, y_r)
                    D_loss.backward()
                    self.opt_d.step()
                    self.reset_grad()

                ''' train generator '''
                G_loss = 0
                for _ in range(2):
                    logits_d_mod_g, logits_d_glb_g = self.forward_pass(x_r, y_r, 'train_G')
                    G_loss += self.get_G_loss(logits_d_mod_g, logits_d_glb_g)
                G_loss.backward()
                self.opt_g.step()
                self.reset_grad()

                # track training losses and metrics after optimization
                Loss_d(D_loss)
                Loss_g(G_loss)

            print('Train Epoch {}: Train: Loss_d:{:.6f} Loss_g:{:.6f}'.format(
                epoch, Loss_d.compute().item(), Loss_g.compute().item()))

            Loss_g.reset()
            Loss_d.reset()

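    # The overall schedule (a sketch of what train() runs, with the epoch counts set
    # in main.py): train_GAN() first warms up G and D alone for N_epochs_GAN=150
    # epochs; train_all() then trains G, D and C jointly for N_epochs_ALL=100 epochs
    # (the classifier starts consuming generated samples once epoch >= N_epochs_DA=30);
    # finally train_C() retrains the classifier on real plus generated data.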
    def train_all(self):
        print('\n>>> Start Training GAN and Classifier...')
        max_tstr_score = 0

        criterion_c = nn.CrossEntropyLoss().cuda()

        # losses
        Loss_g = torchmetrics.MeanMetric().cuda()
        Loss_d = torchmetrics.MeanMetric().cuda()
        Loss_c = torchmetrics.MeanMetric().cuda()

        # classification accuracies of real and generated data
        train_c_acc_r = torchmetrics.Accuracy().cuda()
        train_c_acc_g = torchmetrics.Accuracy().cuda()

        torch.manual_seed(self.seed)
        torch.cuda.manual_seed(self.seed)

        for epoch in range(self.N_epochs_ALL):

            self.G.train()
            self.D.train()
            self.C.train()

            for batch_idx, (x_r, y_r) in enumerate(self.train_loader):
                x_r = Variable(x_r.cuda())
                x_r = x_r.permute(0, 2, 1)
                x_r = torch.split(x_r, self.N_channels_per_mod, dim=1)
                y_r = F.one_hot(y_r.long(), num_classes=self.N_classes)
                y_r = Variable(y_r.float().cuda())

                ''' train discriminator '''
                for _ in range(self.N_steps_D):
                    self.reset_grad()
                    logits_d_mod_r, logits_d_glb_r, logits_d_mod_g1, logits_d_glb_g1, x_g = self.forward_pass(x_r, y_r, 'train_D')
                    D_loss = self.get_D_loss(logits_d_mod_r, logits_d_glb_r, logits_d_mod_g1, logits_d_glb_g1, x_r, x_g, y_r)
                    D_loss.backward()
                    self.opt_d.step()

                ''' train classifier '''
                self.reset_grad()
                logits_c_r, logits_c_g = self.forward_pass(x_r, y_r, 'train_C')
                C_loss_r = criterion_c(logits_c_r, y_r)
                if epoch >= self.N_epochs_DA:
                    C_loss_g = criterion_c(logits_c_g, y_r)
                    C_loss = (C_loss_r + C_loss_g) / 2
                    if epoch == self.N_epochs_DA and batch_idx == 0:
                        print('DA!')
                else:
                    C_loss = C_loss_r
                C_loss.backward()
                self.opt_c.step()

                ''' train generator '''
                self.reset_grad()
                G_loss_GAN = 0
                G_loss_C = 0
                for _ in range(2):
                    logits_d_mod_g, logits_d_glb_g, logits_c_g = self.forward_pass(x_r, y_r, 'train_GC')
                    G_loss_GAN += self.get_G_loss(logits_d_mod_g, logits_d_glb_g)
                    G_loss_C += criterion_c(logits_c_g, y_r)
                G_loss = G_loss_GAN + self.w_gc * G_loss_C
                G_loss.backward()
                self.opt_gc.step()
                self.reset_grad()

                # track training losses and metrics after optimization
                Loss_d(D_loss)
                Loss_c(C_loss)
                Loss_g(G_loss)
                train_c_acc_r(logits_c_r.softmax(dim=-1), y_r.long())
                train_c_acc_g(logits_c_g.softmax(dim=-1), y_r.long())

            if (epoch+1) % 10 == 0:
                test_tstr_score = self.eval_tstr(training=True)
                if self.to_save and test_tstr_score > max_tstr_score:
                    max_tstr_score = test_tstr_score
                    torch.save(self.G.state_dict(), self.model_path + '/g.pkl')
                    torch.save(self.D.state_dict(), self.model_path + '/d.pkl')
                    print('best tstr model saved!')

            print('Train Epoch {}: Train: c_acc_r:{:.6f} c_acc_g:{:.6f} Loss_d:{:.6f} Loss_c:{:.6f} Loss_g:{:.6f}'.format(
                epoch, train_c_acc_r.compute().item(), train_c_acc_g.compute().item(), Loss_d.compute().item(), Loss_c.compute().item(), Loss_g.compute().item()))

            Loss_g.reset()
            Loss_d.reset()
            Loss_c.reset()
            train_c_acc_r.reset()
            train_c_acc_g.reset()

        test_c_acc, test_c_f1 = self.eval_C(training=True, test_loader=self.test_loader)
        test_acc = test_c_acc
        test_f1 = test_c_f1

        print('>>> Training Finished!')
        return test_acc, test_f1

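    # A hedged sampling sketch (not in the original file; here `solver` stands for a
    # DASolver_HMGAN instance): after training, class-conditional windows can be
    # drawn straight from the generator, e.g. for class 3:
    #
    #     z = solver.sample_z()                             # [batch_size, latent_dim]
    #     y = F.one_hot(torch.full((solver.batch_size,), 3), solver.N_classes).float().cuda()
    #     x_g = solver.G(z, y)  # a list of [batch_size, 3, seq_len] tensors, one per modality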
    def train_C(self, training=False):
        print('\n>>> Start Training Classifier...')

        aug_loader = self.get_gen_dataset(training, type='aug')

        criterion_c = nn.CrossEntropyLoss().cuda()

        train_c_acc = torchmetrics.Accuracy().cuda()
        Loss_c = torchmetrics.MeanMetric().cuda()

        torch.manual_seed(self.seed)
        torch.cuda.manual_seed(self.seed)

        for epoch in range(self.N_epochs_C):

            self.C.train()

            for batch_idx, (x, y) in enumerate(aug_loader):
                x = Variable(x.cuda())
                x = x.permute(0, 2, 1)
                x = torch.split(x, self.N_channels_per_mod, dim=1)
                y = Variable(y.long().cuda())

                self.reset_grad()

                ''' train classifier '''
                logits_c = self.C(x)

                loss_c = criterion_c(logits_c, y)
                loss_c.backward()
                self.opt_c.step()
                self.reset_grad()

                # track training losses and metrics after optimization
                Loss_c(loss_c)
                train_c_acc(logits_c.softmax(dim=-1), y)

            print('Train Epoch {}: Train: c_acc:{:.6f} Loss_c:{:.6f}'.format(
                epoch, train_c_acc.compute().item(), Loss_c.compute().item()))

            train_c_acc.reset()
            Loss_c.reset() # reset both running metrics each epoch, as in train_GAN and train_all

        test_c_acc, test_c_f1 = self.eval_C(training=True, test_loader=self.test_loader)
        test_acc = test_c_acc
        test_f1 = test_c_f1

        if self.to_save:
            torch.save(self.C.state_dict(), self.model_path + '/c.pkl')

        print('>>> Training Finished!')
        return test_acc, test_f1

    def eval_C(self, training, test_loader):
        '''
        training==True: the model is tested during training, use the current model and print the test result in the training info
        training==False: the model is tested after training, load the saved model and print the test result alone
        '''
        torch.manual_seed(self.seed)
        torch.cuda.manual_seed(self.seed)

        if not training:
            self.C.load_state_dict(torch.load((self.model_path + '/c.pkl')))

        test_c_acc = torchmetrics.Accuracy().cuda()
        test_c_f1 = torchmetrics.F1Score(num_classes=self.N_classes, average='macro').cuda()

        self.C.eval()

        for _, (x, y) in enumerate(test_loader):
            x = Variable(x.cuda())
            x = x.permute(0, 2, 1)
            x = torch.split(x, self.N_channels_per_mod, dim=1)
            y = Variable(y.long().cuda())

            logits_c = self.C(x)

            # track test metrics
            test_c_acc(logits_c.softmax(dim=-1), y)
            test_c_f1(logits_c.softmax(dim=-1), y)

        if not training:
            print('\n>>> Start Testing ...')
            print(self.tag + ' test acc:{:.6f} test f1:{:.6f}'.format(
                test_c_acc.compute().item(), test_c_f1.compute().item()))
        return test_c_acc.compute().item(), test_c_f1.compute().item()

    def eval_gen_data(self, training=True):
        gen_loader = self.get_gen_dataset(training)

        predictive_score = get_predictive_score(self.args, self.train_loader, gen_loader, self.N_channels)

        train_d_loader, test_d_loader = self.get_disc_dataset(training)
        discriminative_score, disc_acc = get_discriminative_score(self.args, train_d_loader, test_d_loader, self.N_channels)

        tstr_score = get_TSTR_score(self.args, self.train_loader, gen_loader, self.N_modalities, self.N_channels_per_mod, self.N_classes, self.N_intervals, self.len_intervals)

        return predictive_score, discriminative_score, disc_acc, tstr_score
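    # Score semantics (a sketch summarizing metrics.py): predictive_score is the MAE
    # of a next-step predictor trained on generated windows and tested on real ones;
    # discriminative_score is |0.5 - accuracy| of a post-hoc real-vs-generated
    # classifier (0 means indistinguishable; e.g. disc_acc=0.73 gives 0.23); and
    # tstr_score is train-on-synthetic, test-on-real classification accuracy.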

    def eval_tstr(self, training=True):
        gen_loader = self.get_gen_dataset(training)
        tstr_score = get_TSTR_score(self.args, self.train_loader, gen_loader, self.N_modalities, self.N_channels_per_mod, self.N_classes, self.N_intervals, self.len_intervals)
        return tstr_score

    def get_gen_dataset(self, training=False, type='gen'):
        torch.manual_seed(self.seed)
        torch.cuda.manual_seed(self.seed)

        if not training:
            self.G.load_state_dict(torch.load((self.model_path + '/g.pkl')))

        self.G.eval()

        data_g = []
        label_g = []
        if type == 'aug':
            data_r = []
            label_r = []
        for _, (x_r, y_r) in enumerate(self.train_loader):
            if type == 'aug':
                data_r.append(x_r)
                label_r.append(y_r)
            x_r = Variable(x_r.cuda())
            x_r = x_r.permute(0, 2, 1)
            x_r = torch.split(x_r, self.N_channels_per_mod, dim=1)
            y_g = F.one_hot(y_r.long(), num_classes=self.N_classes)
            y_g = Variable(y_g.float().cuda())

            for _ in range(self.args.N_aug):
                x_g = self.forward_pass(x_r, y_g, 'get_x_g')
                x_g = [x_g_mod.permute(0, 2, 1) for x_g_mod in x_g]
                x_g = torch.cat(x_g, dim=-1)
                data_g.append(x_g.detach().cpu())
                label_g.append(y_r)

        data_g = torch.concat(data_g)
        label_g = torch.concat(label_g)
        if type == 'aug':
            data_r = torch.concat(data_r)
            label_r = torch.concat(label_r)

        if type == 'gen':
            gen_dataset = TensorDataset(data_g, label_g)
            gen_loader = DataLoader(gen_dataset, batch_size=self.batch_size, shuffle=True, drop_last=True, num_workers=0)

            return gen_loader
        elif type == 'aug':
            data_rg = torch.concat([data_r, data_g])
            label_rg = torch.concat([label_r, label_g])

            aug_dataset = TensorDataset(data_rg, label_rg)
            aug_loader = DataLoader(aug_dataset, batch_size=self.batch_size, shuffle=True, drop_last=True, num_workers=0)

            return aug_loader

    def get_disc_dataset(self, training=True):
        torch.manual_seed(self.seed)
        torch.cuda.manual_seed(self.seed)

        if not training:
            self.G.load_state_dict(torch.load((self.model_path + '/g.pkl')))

        self.G.eval()

        data_aug = []
        yd_aug = []
        for _, (x_r, y_r) in enumerate(self.train_loader):
            data_aug.append(x_r)
            x_r = Variable(x_r.cuda())
            x_r = x_r.permute(0, 2, 1)
            x_r = x_r.unsqueeze(-1)
            y_r = F.one_hot(y_r.long(), num_classes=self.N_classes)
            y_r = Variable(y_r.float().cuda())
            yd = torch.concat([torch.ones(self.batch_size), torch.zeros(self.batch_size)])

            x_g = self.forward_pass(x_r, y_r, 'get_x_g')

            x_g = [x_g_mod.permute(0, 2, 1) for x_g_mod in x_g]
            x_g = torch.cat(x_g, dim=-1)
            data_aug.append(x_g.detach().cpu())
            yd_aug.append(yd)
        data_aug = torch.concat(data_aug)
        yd_aug = torch.concat(yd_aug)

        x_train, x_test, yd_train, yd_test = train_test_split(data_aug, yd_aug, train_size = 0.8, random_state = 0)

        train_d_dataset = TensorDataset(x_train, yd_train)
        train_d_loader = DataLoader(train_d_dataset, batch_size=self.batch_size, shuffle=True, drop_last=True, num_workers=0)

        test_d_dataset = TensorDataset(x_test, yd_test)
        test_d_loader = DataLoader(test_d_dataset, batch_size=self.batch_size, shuffle=True, drop_last=True, num_workers=0)

        return train_d_loader, test_d_loader
--------------------------------------------------------------------------------