├── figure.jpg ├── src ├── data_config.json ├── .DS_Store ├── experiments │ ├── .DS_Store │ ├── in-tandem-regularization │ │ ├── results │ │ │ └── .gitignore │ │ ├── configs │ │ │ ├── regularizers.json │ │ │ └── datasets.json │ │ └── main.py │ ├── stand-alone-regularization │ │ ├── results │ │ │ └── .gitignore │ │ ├── configs │ │ │ ├── regularizers.json │ │ │ └── datasets.json │ │ └── main.py │ └── larger-data │ │ └── main.py ├── legacy │ ├── comment.md │ └── legacy.py ├── regularizers.py ├── models.py ├── losses.py └── load_data.py ├── requirements.txt ├── LICENSE ├── README.md └── TANGOS_quickstart.ipynb /figure.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alanjeffares/TANGOS/HEAD/figure.jpg -------------------------------------------------------------------------------- /src/data_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "path_to_data": "path/to/your/data/folder/" 3 | } -------------------------------------------------------------------------------- /src/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alanjeffares/TANGOS/HEAD/src/.DS_Store -------------------------------------------------------------------------------- /src/experiments/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alanjeffares/TANGOS/HEAD/src/experiments/.DS_Store -------------------------------------------------------------------------------- /src/experiments/in-tandem-regularization/results/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- 
/src/experiments/stand-alone-regularization/results/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- /src/experiments/in-tandem-regularization/configs/regularizers.json: -------------------------------------------------------------------------------- 1 | { 2 | "no_reg": { 3 | "placeholder": [0] 4 | }, 5 | "l2": { 6 | "weight": [0.1, 0.01, 0.001] 7 | }, 8 | "l1": { 9 | "weight": [0.1, 0.01, 0.001] 10 | }, 11 | "dropout": { 12 | "p": [0.2, 0.5] 13 | }, 14 | "input_noise": { 15 | "std": [0.1, 0.01] 16 | }, 17 | "mixup": { 18 | "alpha": [1] 19 | }, 20 | "TANGOS": { 21 | "lambda_1": [1, 10, 100], 22 | "lambda_2":[0.1, 1], 23 | "param_schedule": [3], 24 | "subsample": [50] 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | liac-arff==2.5.0 2 | charset-normalizer==3.0.1 3 | functorch==1.13.1 4 | idna==3.4 5 | joblib==1.2.0 6 | numpy==1.24.2 7 | nvidia-cublas-cu11==11.10.3.66 8 | nvidia-cuda-nvrtc-cu11==11.7.99 9 | nvidia-cuda-runtime-cu11==11.7.99 10 | nvidia-cudnn-cu11==8.5.0.96 11 | pandas==1.5.3 12 | Pillow==9.4.0 13 | python-dateutil==2.8.2 14 | pytz==2022.7.1 15 | requests==2.28.2 16 | scikit-learn==1.2.1 17 | scipy==1.10.0 18 | six==1.16.0 19 | threadpoolctl==3.1.0 20 | torch==1.13.1 21 | torchaudio==0.13.1 22 | torchvision==0.14.1 23 | typing_extensions==4.4.0 24 | urllib3==1.26.14 25 | -------------------------------------------------------------------------------- /src/experiments/stand-alone-regularization/configs/regularizers.json: -------------------------------------------------------------------------------- 1 | { 2 | "no_reg": { 3 | "placeholder": [0] 4 | }, 5 | "l2": { 6 | "weight": [0.1, 0.01, 
0.001] 7 | }, 8 | "l1": { 9 | "weight": [0.1, 0.01, 0.001] 10 | }, 11 | "dropout": { 12 | "p": [0.2, 0.5] 13 | }, 14 | "input_noise": { 15 | "std": [0.1, 0.01] 16 | }, 17 | "batch_norm": { 18 | "placeholder": [0] 19 | }, 20 | "mixup": { 21 | "alpha": [1] 22 | }, 23 | "TANGOS": { 24 | "lambda_1": [1, 10, 100], 25 | "lambda_2":[0.1, 1], 26 | "param_schedule": [3], 27 | "subsample": [50] 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /src/legacy/comment.md: -------------------------------------------------------------------------------- 1 | **Comment on functorch and batch norm** 2 | 3 | The original implementation of TANGOS used a different, less efficient method for the calculation of gradient attributions. This was updated after running the in tandem experiments. Although both methods produce identical results, there is a compatibility issue between functorch (the new library for calculating attributions) and batch norm. More details on this issue are discussed [here](https://pytorch.org/functorch/stable/batch_norm.html) and [here](https://github.com/pytorch/functorch/issues/384). We have therefore removed the combination of TANGOS and batch norm from the config for this experiment by default. In case this particular combination is required by someone in the future, we have included [the original implementation](https://github.com/alanjeffares/TANGOS/blob/main/src/legacy/legacy.py) for calculating the attribution loss in this folder. 4 | 5 | -------------------------------------------------------------------------------- /src/regularizers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | device = 'cuda:0' 5 | 6 | def l1(model): 7 | l1_regularisation = 0. 
def add_input_noise(input, std, mean=0):
    """Return ``input`` perturbed by additive Gaussian noise.

    The noise is ``mean + std * eps`` with ``eps`` drawn from a standard
    normal distribution of the same size as ``input``. The noise tensor is
    placed on the device of ``input`` itself, so the function works for CPU
    tensors and for any CUDA device rather than only for ``'cuda:0'``.

    Args:
        input: Tensor to perturb.
        std: Scale applied to the unit-variance noise.
        mean: Constant offset added on top of the scaled noise (default 0).

    Returns:
        A new tensor; ``input`` is not modified in place.
    """
    # Sample with the CPU generator first (as the original code did) so that
    # seeded runs draw the same values, then move the noise next to the data.
    noise = torch.randn(input.size()).to(input.device)
    return input + noise * std + mean
h_output = self.relu2(out) 29 | h_output = self.dropout(h_output) 30 | out = self.fc3(h_output) 31 | return out, h_output -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2023, Alan Jeffares 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
def MSE(output, label):
    """Mean squared error between model outputs and regression targets.

    The prediction is aligned to the shape of ``label`` before the loss is
    taken. A blanket ``squeeze()`` has two pitfalls that this avoids:

    * with a batch of one, ``(1, 1)`` collapses to a 0-d tensor that no
      longer matches a ``(1,)`` target;
    * ``(N,)`` predictions against ``(N, 1)`` targets silently broadcast to
      an ``N x N`` error matrix and yield a wrong loss.

    Args:
        output: Model predictions, e.g. of shape ``(N, 1)``.
        label: Targets, e.g. of shape ``(N,)``.

    Returns:
        0-d tensor holding the mean squared error.
    """
    if output.numel() == label.numel():
        # Same number of elements: make the layouts agree exactly.
        output = output.reshape(label.shape)
    else:
        # Element counts differ: keep the historical squeeze-and-broadcast
        # behaviour so existing callers are unaffected.
        output = output.squeeze()
    return nn.MSELoss()(output, label)
def cosine_similarity(w1, w2):
    """Absolute cosine similarity of two 1-D tensors.

    Computes ``|w1 . w2| / (||w1||_2 * ||w2||_2)``. No epsilon is added to
    the denominator, so a zero vector produces NaN.
    """
    overlap = torch.dot(w1, w2).abs()
    scale = torch.norm(w1, 2) * torch.norm(w2, 2)
    return overlap / scale
def attr_loss(forward_func, data_input, device='cpu', subsample=-1):
    """Compute the attribution sparsity and correlation penalties (legacy path).

    Each hidden unit's attribution is the gradient of that unit with respect
    to the input, obtained with one ``torch.autograd.grad`` call per unit.

    Args:
        forward_func: Callable returning ``(output, h_output)``; only
            ``h_output`` (indexed as ``[batch, hidden_unit]``) is used.
        data_input: Input batch; a clone with ``requires_grad`` enabled is
            differentiated, the argument itself is left untouched.
        device: Device on which the correlation accumulator is created.
        subsample: If positive and smaller than the number of distinct unit
            pairs, that many pairs are drawn with ``np.random.choice`` instead
            of enumerating every pair.

    Returns:
        ``(sparsity_loss, correlation_loss)``:
        * sparsity: L1 norm of all attributions divided by
          ``batch_size * h_dim * num_features``;
        * correlation: per-sample absolute cosine similarity between the
          attributions of two units, summed over the batch and the pairs and
          divided by ``batch_size * number_of_pairs_used``. With fewer than
          two hidden units there are no pairs and the value is 0 (previously
          a 0/0 division produced NaN).
    """
    data_input = data_input.clone().requires_grad_(True)
    _, h_output = forward_func(data_input)

    batch_size = data_input.shape[0]
    h_dim = h_output.shape[1]

    neuron_attr = []
    for neuron in range(h_dim):
        # Select a single hidden unit for every sample. Building the selector
        # from h_output keeps it on the right device and in the right dtype.
        grad_outputs = torch.zeros_like(h_output)
        grad_outputs[:, neuron] = 1
        grad = torch.autograd.grad(outputs=h_output, inputs=data_input,
                                   grad_outputs=grad_outputs,
                                   create_graph=True)[0]
        neuron_attr.append(grad)

    # h_dim x batch_size x (input dims...)
    neuron_attr = torch.stack(neuron_attr)

    if neuron_attr.dim() > 3:
        # h_dim x batch_size x features
        neuron_attr = neuron_attr.flatten(start_dim=2)

    sparsity_loss = torch.norm(neuron_attr, p=1) / (batch_size * h_dim * neuron_attr.shape[2])

    correlation_loss = torch.tensor(0., requires_grad=True).to(device)
    num_pairs = h_dim * (h_dim - 1) // 2
    if num_pairs == 0:
        # A single hidden unit has nothing to be correlated with.
        return sparsity_loss, correlation_loss

    cos = nn.CosineSimilarity(dim=1, eps=1e-6)
    if 0 < subsample < num_pairs:
        pairs = [list(np.random.choice(h_dim, size=(2), replace=False)) for _ in range(subsample)]
        normaliser = batch_size * subsample
    else:
        pairs = [(neuron_i, neuron_j) for neuron_i in range(1, h_dim) for neuron_j in range(neuron_i)]
        normaliser = batch_size * num_pairs

    for first, second in pairs:
        # cos(...) has one value per sample; the L1 norm sums their magnitudes.
        pairwise_corr = cos(neuron_attr[first, :, :], neuron_attr[second, :, :]).norm(p=1)
        correlation_loss = correlation_loss + pairwise_corr

    return sparsity_loss, correlation_loss / normaliser
"loader": "load_heart", 68 | "num_features": 20, 69 | "num_outputs": 2, 70 | "lr": 0.01 71 | }, 72 | "breast": { 73 | "type": "classification", 74 | "loader": "load_breast", 75 | "num_features": 9, 76 | "num_outputs": 2, 77 | "lr": 0.01 78 | }, 79 | "cervical": { 80 | "type": "classification", 81 | "loader": "load_cervical", 82 | "num_features": 136, 83 | "num_outputs": 5, 84 | "lr": 0.01 85 | }, 86 | "credit": { 87 | "type": "classification", 88 | "loader": "load_credit", 89 | "num_features": 40, 90 | "num_outputs": 2, 91 | "lr": 0.001 92 | }, 93 | "hcv": { 94 | "type": "classification", 95 | "loader": "load_hcv", 96 | "num_features": 12, 97 | "num_outputs": 4, 98 | "lr": 0.001 99 | }, 100 | "tumor": { 101 | "type": "classification", 102 | "loader": "load_tumor", 103 | "num_features": 25, 104 | "num_outputs": 22, 105 | "lr": 0.001 106 | }, 107 | "soybean": { 108 | "type": "classification", 109 | "loader": "load_soybean", 110 | "num_features": 484, 111 | "num_outputs": 19, 112 | "lr": 0.001 113 | }, 114 | "australian": { 115 | "type": "classification", 116 | "loader": "load_australian", 117 | "num_features": 55, 118 | "num_outputs": 2, 119 | "lr": 0.001 120 | }, 121 | "entrance": { 122 | "type": "classification", 123 | "loader": "load_entrance", 124 | "num_features": 38, 125 | "num_outputs": 4, 126 | "lr": 0.001 127 | }, 128 | "thoracic": { 129 | "type": "classification", 130 | "loader": "load_thoracic", 131 | "num_features": 24, 132 | "num_outputs": 2, 133 | "lr": 0.001 134 | } 135 | } -------------------------------------------------------------------------------- /src/experiments/stand-alone-regularization/configs/datasets.json: -------------------------------------------------------------------------------- 1 | { 2 | "student": { 3 | "type": "regression", 4 | "loader": "load_student", 5 | "num_features": 56, 6 | "num_outputs": 1, 7 | "lr": 0.0001 8 | }, 9 | "bioconcentration": { 10 | "type": "regression", 11 | "loader": "load_bioconcentration", 12 | 
"num_features": 45, 13 | "num_outputs": 1, 14 | "lr": 0.001 15 | }, 16 | "facebook": { 17 | "type": "regression", 18 | "loader": "load_facebook", 19 | "num_features": 21, 20 | "num_outputs": 1, 21 | "lr": 0.01 22 | }, 23 | "wine": { 24 | "type": "regression", 25 | "loader": "load_wine", 26 | "num_features": 11, 27 | "num_outputs": 1, 28 | "lr": 0.001 29 | }, 30 | "abalone": { 31 | "type": "regression", 32 | "loader": "load_abalone", 33 | "num_features": 9, 34 | "num_outputs": 1, 35 | "lr": 0.01 36 | }, 37 | "skillcraft": { 38 | "type": "regression", 39 | "loader": "load_skillcraft", 40 | "num_features": 18, 41 | "num_outputs": 1, 42 | "lr": 0.01 43 | }, 44 | "weather": { 45 | "type": "regression", 46 | "loader": "load_weather", 47 | "num_features": 45, 48 | "num_outputs": 1, 49 | "lr": 0.01 50 | }, 51 | "forest": { 52 | "type": "regression", 53 | "loader": "load_forest", 54 | "num_features": 39, 55 | "num_outputs": 1, 56 | "lr": 0.0001 57 | }, 58 | "protein": { 59 | "type": "regression", 60 | "loader": "load_protein", 61 | "num_features": 9, 62 | "num_outputs": 1, 63 | "lr": 0.01 64 | }, 65 | "heart": { 66 | "type": "classification", 67 | "loader": "load_heart", 68 | "num_features": 20, 69 | "num_outputs": 2, 70 | "lr": 0.01 71 | }, 72 | "breast": { 73 | "type": "classification", 74 | "loader": "load_breast", 75 | "num_features": 9, 76 | "num_outputs": 2, 77 | "lr": 0.01 78 | }, 79 | "cervical": { 80 | "type": "classification", 81 | "loader": "load_cervical", 82 | "num_features": 136, 83 | "num_outputs": 5, 84 | "lr": 0.01 85 | }, 86 | "credit": { 87 | "type": "classification", 88 | "loader": "load_credit", 89 | "num_features": 40, 90 | "num_outputs": 2, 91 | "lr": 0.001 92 | }, 93 | "hcv": { 94 | "type": "classification", 95 | "loader": "load_hcv", 96 | "num_features": 12, 97 | "num_outputs": 4, 98 | "lr": 0.001 99 | }, 100 | "tumor": { 101 | "type": "classification", 102 | "loader": "load_tumor", 103 | "num_features": 25, 104 | "num_outputs": 22, 105 | "lr": 
0.001 106 | }, 107 | "soybean": { 108 | "type": "classification", 109 | "loader": "load_soybean", 110 | "num_features": 484, 111 | "num_outputs": 19, 112 | "lr": 0.001 113 | }, 114 | "australian": { 115 | "type": "classification", 116 | "loader": "load_australian", 117 | "num_features": 55, 118 | "num_outputs": 2, 119 | "lr": 0.001 120 | }, 121 | "entrance": { 122 | "type": "classification", 123 | "loader": "load_entrance", 124 | "num_features": 38, 125 | "num_outputs": 4, 126 | "lr": 0.001 127 | }, 128 | "thoracic": { 129 | "type": "classification", 130 | "loader": "load_thoracic", 131 | "num_features": 24, 132 | "num_outputs": 2, 133 | "lr": 0.001 134 | } 135 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## TANGOS: Regularizing Tabular Neural Networks through Gradient Orthogonalization and Specialization 2 | 3 | [![pdf](https://img.shields.io/badge/PDF-ICLR%202023-red)](https://openreview.net/forum?id=n6H86gW8u0d) 4 | [![License: BSD 3-Clause](https://img.shields.io/badge/License-BSD-blue.svg)](https://github.com/alanjeffares/TANGOS/blob/main/LICENSE) 5 | 6 | ![TANGOS](figure.jpg?raw=true "TANGOS") 7 | 8 | This repository contains the code associated with [our ICLR 2023 paper](https://openreview.net/forum?id=n6H86gW8u0d) where we introduce a novel regularizer for training deep neural networks. Tabular Neural Gradient Orthogonalization and Specialization (TANGOS) provides a framework for regularization in the tabular setting built on latent unit attributions. For further details, please see our paper. 9 | 10 | 11 | ### Getting Started With TANGOS 12 | To quickly get started with integrating TANGOS into a PyTorch workflow, we have provided a handy quickstart guide. This consists of a simple MLP training routine with a drop-in function that calculates and applies TANGOS regularization. 
This example notebook can be found in [TANGOS_quickstart.ipynb](https://github.com/alanjeffares/TANGOS/blob/main/TANGOS_quickstart.ipynb). 13 | 14 | ### Experiments 15 | **Setup** 16 | 17 | Clone this repository and navigate to the root folder. 18 | ``` 19 | git clone https://github.com/alanjeffares/TANGOS.git 20 | cd TANGOS 21 | ``` 22 | Ensure PYTHONPATH is also set to the root folder. 23 | ``` 24 | export PYTHONPATH="/your/path/to/TANGOS" 25 | ``` 26 | Using conda, create and activate a new environment. 27 | ``` 28 | conda create -n <env_name> pip python 29 | conda activate <env_name> 30 | ``` 31 | Then install the repository requirements. 32 | ``` 33 | pip install -r requirements.txt 34 | ``` 35 | 36 | **Data** 37 | 38 | Datasets can be downloaded using `wget` and the `<dataset_url>` described in Appendix L of the paper. 39 | ``` 40 | wget -P /path/to/data/folder/ https://archive.ics.uci.edu/ml/machine- 41 | learning-databases/<dataset_folder>/<file_name> 42 | ``` 43 | Then set the path to your local data folder in `src/data_config.json`. 44 | ``` 45 | {"path_to_data": "path/to/data/folder/"} 46 | ``` 47 | 48 | 49 | **Running** 50 | 51 | These folders are associated with the commented experiments from the paper. 52 | ``` 53 | └── src 54 | └── experiments 55 | ├── behavior-analysis # TANGOS Behavior Analysis. 56 | ├── compute # Approximation and Algorithm. 57 | ├── in-tandem-regularization # Generalization: In Tandem Regularization. 58 | ├── larger-data # Performance With Increasing Data Size. 59 | └── stand-alone-regularization # Generalization: Stand-Alone Regularization. 60 | ``` 61 | 62 | The main experiments can be run by navigating to the root folder and running the following command. 63 | 64 | ```python src/experiments/<experiment_name>/main.py``` 65 | 66 | Results and hyperparameters of these experiments are saved in json format to the results folder. 67 | 68 | ```src/experiments/<experiment_name>/results``` 69 | 70 | The behavior analysis and compute experiments are included in ```.ipynb``` notebooks with instructions included. 
Please note that all jupyter notebooks are self contained and designed to be run in colab by clicking the link at the top of each notebook (e.g. [![open in colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/alanjeffares/TANGOS/blob/main/TANGOS_quickstart.ipynb)). 71 | 72 | _Note: To run in tandem experiments with batch norm please see [our comment](https://github.com/alanjeffares/TANGOS/blob/main/src/legacy/comment.md)._ 73 | 74 | ### Citation 75 | If you use this code, please cite the associated paper. 76 | ``` 77 | @inproceedings{ 78 | jeffares2023tangos, 79 | title={{TANGOS}: Regularizing Tabular Neural Networks through Gradient Orthogonalization and Specialization}, 80 | author={Alan Jeffares and Tennison Liu and Jonathan Crabb{\'e} and Fergus Imrie and Mihaela van der Schaar}, 81 | booktitle={International Conference on Learning Representations}, 82 | year={2023}, 83 | url={https://openreview.net/forum?id=n6H86gW8u0d} 84 | } 85 | ``` 86 | -------------------------------------------------------------------------------- /src/experiments/stand-alone-regularization/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import itertools 4 | import argparse 5 | import json 6 | import copy 7 | from pathlib import Path 8 | import numpy as np 9 | import random 10 | import torch.nn as nn 11 | from torch import optim 12 | from sklearn.model_selection import KFold 13 | from src.losses import parameter_schedule, attr_loss, MSE 14 | from src.models import UCI_MLP 15 | from src.regularizers import l1, add_input_noise, mixup_data, mixup_criterion 16 | import src.load_data 17 | 18 | import warnings 19 | 20 | warnings.filterwarnings("ignore") 21 | 22 | EPOCHS = 200 23 | TRAINING_PATIENCE = 30 24 | BATCH_SIZE = 32 25 | K_FOLDS = 5 26 | DEVICE = 'cuda:0' 27 | 28 | 29 | def train_epoch(model, train_loader, optimiser, loss_func, params, regulariser, 
def evaluate(model, test_loader, loss_func, device=DEVICE):
    """Average ``loss_func`` over every batch of ``test_loader``.

    The model is switched to eval mode once and the forward passes run under
    ``torch.no_grad()``, so no autograd graph is built while scoring.

    Args:
        model: Module whose forward returns ``(output, hidden)``; only
            ``output`` is scored.
        test_loader: Iterable of ``(data, label)`` batches.
        loss_func: Callable ``loss_func(output, label)`` returning a scalar
            tensor.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(mean_pred_loss, mean_loss)`` averaged over batches. No
        regularisation term is added during evaluation, so the two values
        are currently identical.

    Raises:
        ValueError: If ``test_loader`` yields no batches. This previously
            surfaced as an ``UnboundLocalError`` on the loop index.
    """
    model.eval()
    running_loss, running_pred_loss = 0, 0
    num_batches = 0
    with torch.no_grad():
        for data, label in test_loader:
            data, label = data.to(device), label.to(device)

            output, _ = model(data)

            # compute metric
            pred_loss = loss_func(output, label)
            loss = pred_loss

            running_loss += loss.item()
            running_pred_loss += pred_loss.item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError('evaluate() received an empty data loader')

    return running_pred_loss / num_batches, running_loss / num_batches
def run_fold(fold_name, model, trainloader, valloader, config, config_dataset, seed):
    """Train one cross-validation fold with early stopping and record its result.

    Args:
        fold_name: Colon-separated identifier
            ``'<tag>:<dataset>:<regulariser>:<params>:<fold>'``.
        model: Freshly initialised network to train.
        trainloader: Training batches for this fold.
        valloader: Validation batches for this fold.
        config: Hyperparameters of the current regulariser. It is copied on
            entry, so the per-epoch ``lambda_*_curr`` entries added for TANGOS
            no longer leak into the caller's dict (and from there into the
            saved parameter record).
        config_dataset: Dataset settings; ``'type'`` and ``'lr'`` are read.
        seed: Seed used in the name of the results file.

    Returns:
        ``(best_val_loss, best_model, last_update)``: the best validation
        loss, a deep copy of the model at that point, and the epoch at which
        it was recorded.
    """
    tag, dataset, regulariser, params, _ = fold_name.split(':')
    # Private copy: the TANGOS schedule below writes scratch keys into it.
    config = dict(config)
    loss_func = MSE if config_dataset['type'] == 'regression' else nn.CrossEntropyLoss()
    # L2 regularisation is applied through Adam's weight decay.
    l2_weight = config['weight'] if regulariser == 'l2' else 0
    optimiser = optim.Adam(model.parameters(), lr=config_dataset['lr'], weight_decay=l2_weight)
    if regulariser == 'TANGOS':
        parameter_scheduler = parameter_schedule(config['lambda_1'], config['lambda_2'], config['param_schedule'])
    else:
        parameter_scheduler = None
    best_val_loss = np.inf
    last_update = 0
    for epoch in range(EPOCHS):
        if regulariser == 'TANGOS':
            # The schedule returns (0, 0) until its switch epoch is reached.
            lambda_1, lambda_2 = parameter_scheduler.get_reg(epoch)
            config['lambda_1_curr'] = lambda_1
            config['lambda_2_curr'] = lambda_2

        model = train_epoch(model, trainloader, optimiser, loss_func, config, regulariser, device=DEVICE)
        val_loss, _ = evaluate(model, valloader, loss_func, device=DEVICE)

        # During the first five epochs the checkpoint is always refreshed,
        # even when the validation loss got worse.
        if (val_loss < best_val_loss) or (epoch < 5):
            best_val_loss = val_loss
            best_model = copy.deepcopy(model)
            last_update = epoch

        # early stopping criteria
        if epoch - last_update == TRAINING_PATIENCE:
            break

    # save best model results for this fold
    results = load_results(f'experiment_{tag}_seed_{seed}')
    results[dataset][regulariser][params]['val_loss'].append(best_val_loss)
    save_results(results, f'experiment_{tag}_seed_{seed}')

    return best_val_loss, best_model, last_update
in enumerate(kfold.split(loaders['train'])): 179 | torch.manual_seed(seed); np.random.seed(seed) 180 | trainloader, valloader = ids_to_dataloader_split(loaders['train'], train_ids, val_ids, seed=seed) 181 | fold_name = run_name + f':{fold}' 182 | model = UCI_MLP(num_features=config_dataset['num_features'], num_outputs=config_dataset['num_outputs'], 183 | dropout=dropout, batch_norm=batch_norm).to(DEVICE) 184 | fold_loss, fold_model, fold_epoch = run_fold(fold_name, model, trainloader, valloader, params, 185 | config_dataset, seed=seed) 186 | if fold_loss < best_loss: 187 | best_loss = fold_loss 188 | best_model = copy.deepcopy(fold_model) 189 | best_epoch = fold_epoch 190 | 191 | # evalutate best performing model on held out test set 192 | loss_func = MSE if config_dataset['type'] == 'regression' else nn.CrossEntropyLoss() 193 | test_loss, _ = evaluate(best_model, loaders['test'], loss_func) 194 | tag, dataset, regulariser, params = run_name.split(':') 195 | results = load_results(f'experiment_{tag}_seed_{seed}') 196 | results[dataset][regulariser][params]['test_loss'] = test_loss 197 | results[dataset][regulariser][params]['train_final_epoch'] = best_epoch 198 | print(test_loss) 199 | save_results(results, f'experiment_{tag}_seed_{seed}') 200 | 201 | def grid_search_iterable(parameter_dict: dict) -> list: 202 | """Generate an iterable list of hyperparameters from a dictionary containing the values to be considered""" 203 | keys, values = zip(*parameter_dict.items()) 204 | parameter_grid = [dict(zip(keys, v)) for v in itertools.product(*values)] 205 | return parameter_grid 206 | 207 | def load_config(name): 208 | curr_dir = os.path.dirname(__file__) 209 | config_dir = os.path.join(curr_dir, f'configs/{name}.json') 210 | with open(config_dir) as f: 211 | config_dict = json.load(f) 212 | config_keys = config_dict.keys() 213 | return config_dict, config_keys 214 | 215 | def run_experiment(seeds: list, tag: str): 216 | # load config files 217 | config_regs, 
def run_experiment(seeds: list, tag: str):
    """Run the full grid of datasets x regularisers x hyperparameters per seed.

    For every seed the results file is re-initialised, each hyperparameter
    combination is cross-validated with ``run_cv`` and the parameters used are
    stored in ``params_record`` under ``id_:<seed>:<run name>``.

    Args:
        seeds: integer seeds to repeat the experiment with.
        tag: name that prefixes the result files of this set of experiments.
    """
    # load config files
    config_regs, regularisers = load_config('regularizers')
    config_data, datasets = load_config('datasets')

    for seed in seeds:
        # initialise results file
        init_results(tag, seed, datasets, config_regs, overwrite=True)
        for dataset in datasets:
            for regulariser in regularisers:
                parameter_iterable = grid_search_iterable(config_regs[regulariser])
                for idx, param_set in enumerate(parameter_iterable):
                    run_name = f'{tag}:{dataset}:{regulariser}:{idx}'
                    # run CV on this combination
                    print(run_name)
                    run_cv(config_data[dataset], regulariser, param_set, run_name, seed)
                    # save record of parameters used for this run
                    param_record = load_results('params_record')
                    param_record[f'id_:{seed}:{run_name}'] = param_set
                    save_results(param_record, 'params_record')


def _parse_args(argv=None):
    """Parse the command line (``argv=None`` reads ``sys.argv``).

    ``-seeds`` accepts one or more integers. Without ``nargs``/``type`` a value
    such as ``-seeds 12`` arrived as the string ``'12'`` and was then iterated
    character by character by ``run_experiment``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-seeds', nargs='+', type=int, default=[0],
                        help='Set of seeds to use for experiments')
    parser.add_argument('-tag', default='tag', help='Tag name for set of experiments')
    return parser.parse_args(argv)


if __name__ == '__main__':
    args = _parse_args()
    print(args.seeds)
    run_experiment(seeds=args.seeds, tag=args.tag)
# Raw table: column 0 holds the class label, the remaining columns the features.
d = pd.read_csv('path/to/data/dionis', header=None)
TRAINING_RATIO = 0.1  # change this for different ratios of training data

y = d.iloc[:, 0]
X = d.iloc[:, 1:]

SEED = 0
BATCH_SIZE = 256
# 80/20 train/test split, followed by an 80/20 train/validation split.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=SEED)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=SEED)

# Keep only the leading TRAINING_RATIO fraction of the training rows.
num = int(X_train.shape[0] * TRAINING_RATIO)
X_train, y_train = X_train[:num], y_train[:num]

# Scaling statistics come from the (sub-sampled) training rows only.
scaler_train = StandardScaler()
X_train = scaler_train.fit_transform(X_train)
X_val = scaler_train.transform(X_val)
X_test = scaler_train.transform(X_test)

train_dataset, val_dataset, test_dataset = (
    TensorDataset(torch.Tensor(features), torch.Tensor(targets.to_numpy()))
    for features, targets in ((X_train, y_train), (X_val, y_val), (X_test, y_test))
)

# Only the training loader shuffles; all loaders share one batch size.
loaders = {
    split: DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=(split == 'train'), num_workers=1)
    for split, dataset in (('train', train_dataset), ('val', val_dataset), ('test', test_dataset))
}
class UCI_MLP(nn.Module):
    """MLP with hidden widths 400 -> 100 -> 10 followed by a linear output layer.

    ``forward`` returns ``(out, h_output)`` where ``h_output`` is the 10-unit
    ReLU activation feeding the output layer.

    Args:
        num_features: number of input features.
        num_outputs: width of the output layer.
        dropout: dropout probability applied after the first two hidden layers.
        batch_norm: apply batch normalisation after ``fc1`` and ``fc2``.
    """

    def __init__(self, num_features, num_outputs, dropout=0, batch_norm=False):
        super(UCI_MLP, self).__init__()
        self.dropout = torch.nn.Dropout(p=dropout)
        self.batch_norm = batch_norm
        self.fc1 = nn.Linear(num_features, 400)
        # BatchNorm1d has to be sized to the layer it normalises. It used to be
        # built with ``num_features + 1`` channels, which raised a shape error
        # as soon as ``batch_norm=True`` was used.
        self.bn1 = nn.BatchNorm1d(400)
        self.relu1 = nn.ReLU(inplace=False)
        self.fc2 = nn.Linear(400, 100)
        self.bn2 = nn.BatchNorm1d(100)
        self.relu2 = nn.ReLU(inplace=False)
        self.fc3 = nn.Linear(100, 10)
        self.relu3 = nn.ReLU(inplace=False)
        self.fc4 = nn.Linear(10, num_outputs)

    def forward(self, x):
        # batch norm is skipped when the leading dimension is 1
        use_bn = self.batch_norm and x.shape[0] > 1
        out = self.fc1(x)
        if use_bn:
            out = self.bn1(out)
        out = self.dropout(self.relu1(out))
        out = self.fc2(out)
        if use_bn:
            out = self.bn2(out)
        out = self.dropout(self.relu2(out))
        h_output = self.relu3(self.fc3(out))
        out = self.fc4(h_output)
        return out, h_output


def attr_loss(forward_func, data_input, device='cpu', subsample=-1):
    """Compute the sparsity and correlation penalties on hidden-unit attributions.

    The attribution of a hidden unit is the gradient of its activation (second
    output of ``forward_func``) with respect to the input, taken per sample.

    Args:
        forward_func: callable returning ``(output, hidden_activation)``.
        data_input: batch of inputs; the first dimension is the batch.
        device: device on which the correlation accumulator is created.
        subsample: number of random unit pairs used for the correlation term;
            all pairs are used when it is non-positive or not smaller than the
            number of distinct pairs.

    Returns:
        ``(sparsity_loss, correlation_loss)``: the mean absolute attribution,
        and the mean absolute cosine similarity between the attributions of
        pairs of hidden units.
    """
    batch_size = data_input.shape[0]

    def hidden_activation(input_):
        _, h_out = forward_func(input_)
        return h_out

    data_input = data_input.clone().requires_grad_(True)
    jacobian = vmap(jacrev(hidden_activation))(data_input)
    neuron_attr = jacobian.swapaxes(0, 1)
    h_dim = neuron_attr.shape[0]

    if len(neuron_attr.shape) > 3:
        # h_dim x batch_size x features
        neuron_attr = neuron_attr.flatten(start_dim=2)

    sparsity_loss = torch.norm(neuron_attr, p=1) / (batch_size * h_dim * neuron_attr.shape[2])

    cos = nn.CosineSimilarity(dim=1, eps=1e-6)
    correlation_loss = torch.tensor(0., requires_grad=True).to(device)
    num_pairs = h_dim * (h_dim - 1) / 2
    if num_pairs == 0:
        # a single hidden unit has no pairs; previously this produced 0 / 0 = NaN
        return sparsity_loss, correlation_loss

    if 0 < subsample < num_pairs:
        tensor_pairs = [list(np.random.choice(h_dim, size=(2), replace=False)) for _ in range(subsample)]
        num_terms = subsample
    else:
        tensor_pairs = [(neuron_i, neuron_j) for neuron_i in range(1, h_dim) for neuron_j in range(neuron_i)]
        num_terms = num_pairs

    for first, second in tensor_pairs:
        pairwise_corr = cos(neuron_attr[first, :, :], neuron_attr[second, :, :]).norm(p=1)
        correlation_loss = correlation_loss + pairwise_corr
    correlation_loss = correlation_loss / (batch_size * num_terms)

    return sparsity_loss, correlation_loss
def _squeeze_outputs(output):
    """Drop a trailing singleton output dimension but never the batch dimension."""
    if output.dim() > 1 and output.shape[-1] == 1:
        return output.squeeze(-1)
    return output


def train_epoch(model, loader, loss_func, optimiser, epoch,
                lambda_1=0, lambda_2=0, device='cpu', subsample=-1):
    """Train ``model`` for one pass over ``loader`` and return it.

    The loss is ``pred_loss + lambda_1 * sparsity_loss + lambda_2 *
    correlation_loss``; the attribution terms are only computed when
    ``lambda_1 + lambda_2 > 0``. Labels are cast to ``LongTensor``. A running
    mean of the loss is printed every 100 steps.
    """
    use_attr_reg = lambda_1 + lambda_2 > 0
    running_loss = 0
    model.train()
    for i, (data, label) in enumerate(loader):
        data, label = data.to(device), label.type(torch.LongTensor).to(device)
        optimiser.zero_grad()
        output, _ = model(data)

        # A bare ``squeeze()`` also removed the batch dimension of a final
        # batch holding a single sample, which made the loss call fail.
        pred_loss = loss_func(_squeeze_outputs(output), label)

        if use_attr_reg:
            sparsity_loss, correlation_loss = attr_loss(model, data, device=device, subsample=subsample)
        else:
            sparsity_loss, correlation_loss = 0, 0

        loss = pred_loss + lambda_1 * sparsity_loss + lambda_2 * correlation_loss
        running_loss += loss.item()

        loss.backward()
        optimiser.step()

        if (i + 1) % 100 == 0:
            print('Epoch [{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch + 1, i + 1, len(loader), running_loss / (i + 1)))
            print(f"Lambda1: {lambda_1}, Lambda2: {lambda_2}")

    return model


def evaluate(model, loader, loss_func, epoch,
             lambda_1=0, lambda_2=0, device='cpu', subsample=-1, log_set='test'):
    """Evaluate ``model`` on ``loader`` and print a summary line.

    Returns:
        ``(average prediction loss per batch, accuracy)``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    correct, total = 0, 0
    running_loss, running_pred_loss = 0, 0
    sparsity_loss = correlation_loss = None

    model.eval()
    for data, label in loader:
        data, label = data.to(device), label.type(torch.LongTensor).to(device)

        output, _ = model(data)
        pred_loss = loss_func(_squeeze_outputs(output), label)

        # attribution losses are always computed here so they can be reported
        sparsity_loss, correlation_loss = attr_loss(model, data, device=device, subsample=subsample)

        loss = pred_loss + lambda_1 * sparsity_loss + lambda_2 * correlation_loss

        running_loss += loss.item()
        running_pred_loss += pred_loss.item()

        # argmax on the raw outputs: a sigmoid keeps the ordering but saturates
        # to exactly 1.0 in float32, creating ties that argmax breaks wrongly
        pred_y = torch.argmax(output, 1)
        correct += (pred_y == label).sum().item()
        total += float(label.size()[0])

    if total == 0:
        raise ValueError('evaluate() received a loader without any samples')

    accuracy = correct / total

    average_loss = running_loss / len(loader)
    averge_pred_loss = running_pred_loss / len(loader)

    # the attribution losses reported are those of the last batch
    print(f'[Test] Epoch: {epoch + 1}, accuracy: {accuracy:.4f}, ' \
          f'average test loss: {average_loss:.4f}, ' \
          f'pred loss: {averge_pred_loss:.4f}, ' \
          f'sparsity loss: {sparsity_loss.item():.4f}, correlation loss: {correlation_loss.item():.4f}')

    return averge_pred_loss, accuracy
def train_full(seed, lambda_1, lambda_2, LR=0):
    """Train one model on the module-level ``loaders`` and return its test accuracy.

    Training runs for at most 100 epochs with early stopping on validation
    accuracy (patience 5, counted from the second epoch). The weights of the
    best validation epoch are stored in ``model_weights`` and restored before
    the test evaluation.

    Args:
        seed: torch seed set before the model is created.
        lambda_1: weight of the sparsity term passed to training/evaluation.
        lambda_2: weight of the correlation term passed to training/evaluation.
        LR: despite its name this is passed to Adam as ``weight_decay``; the
            learning rate is fixed at 0.001.
    """
    EPOCHS = 100
    TRAINING_PATIENCE = 5
    # fall back to the CPU instead of failing when no GPU is present
    DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    runs = 1

    learning_rate = 0.001
    weight_decay = LR
    num_features = 60
    num_outputs = 355
    subsample = 50
    model_save_path = 'model_weights'
    min_epoch = 1
    best_acc = 0
    accuracy_val_ls = []
    checkpoint_written = False

    for _ in range(runs):
        torch.random.manual_seed(seed)

        model = UCI_MLP(num_features, num_outputs, dropout=0, batch_norm=False).to(DEVICE)
        print(f'Training on {DEVICE}...')

        loss_func = nn.CrossEntropyLoss()
        optimiser = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

        patience = 0

        for epoch in range(EPOCHS):

            model = train_epoch(model, loaders['train'], loss_func, optimiser,
                                epoch=epoch, lambda_1=lambda_1, lambda_2=lambda_2,
                                device=DEVICE, subsample=subsample)

            val_loss, accuracy = evaluate(model, loaders['val'], loss_func, epoch=epoch,
                                          lambda_1=lambda_1, lambda_2=lambda_2, device=DEVICE, subsample=subsample)
            accuracy_val_ls.append(accuracy)

            if epoch >= min_epoch:
                if best_acc < accuracy:
                    print(f'Epoch {epoch + 1} - Validation performance improved, saving model...')
                    best_acc = accuracy
                    torch.save(model.state_dict(), model_save_path)
                    checkpoint_written = True
                    patience = 0
                else:
                    patience += 1

            if patience == TRAINING_PATIENCE:
                print(f'Epoch {epoch + 1} - Early stopping since no improvement after {patience} epochs')
                break

    # Restore the best validation weights. When this run never improved, no
    # checkpoint was written and the file (if present) belongs to an earlier
    # run, so the current weights are kept instead of loading stale ones.
    if checkpoint_written:
        model.load_state_dict(torch.load(model_save_path))
    averge_pred_loss, accuracy = evaluate(model, loaders['test'], loss_func, epoch=0,
                                          lambda_1=lambda_1, lambda_2=lambda_2, device=DEVICE, subsample=subsample,
                                          log_set='target')
    return accuracy
# main logic for training baseline, tangos regularization and l2 regularization
baseline_ls, TANGOS_ls, l2_ls = [], [], []

# (output file stem, accuracy list, (lambda_1, lambda_2), weight decay passed as LR)
_EXPERIMENTS = (
    ('baseline', baseline_ls, (0, 0), 0),
    ('TANGOS', TANGOS_ls, (1, 0.01), 0),
    ('l2', l2_ls, (0, 0), 0.0001),
)

for name, accuracies, (lambda_1, lambda_2), weight_decay in _EXPERIMENTS:
    # six seeds per setting; test accuracies are written once the setting is done
    for seed in range(6):
        acc = train_full(seed, lambda_1, lambda_2, LR=weight_decay)
        accuracies.append(acc)
    with open(f'src/experiments/larger-data/{name}.json', 'w') as f:
        json.dump({'test_acc': accuracies}, f)
import warnings

warnings.filterwarnings("ignore")

EPOCHS = 200
TRAINING_PATIENCE = 30
BATCH_SIZE = 32
K_FOLDS = 5
# use the first GPU when there is one, otherwise run on the CPU
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'


def train_epoch(model, train_loader, optimiser, loss_func, params, params_ls, regulariser, device='cpu'):
    """Train ``model`` for one pass over ``train_loader`` and return it.

    Args:
        model: network whose forward pass returns ``(output, hidden)``.
        train_loader: iterable of ``(data, label)`` batches.
        optimiser: optimiser stepping the model parameters.
        loss_func: prediction loss ``loss_func(output, label)``.
        params: hyperparameters of ``regulariser`` (``std`` for input noise,
            ``weight`` for l1, ``alpha`` for mixup).
        params_ls: TANGOS settings; reads ``lambda_1_curr``, ``lambda_2_curr``
            and ``subsample``.
        regulariser: name of the regulariser combined with TANGOS.
        device: device the batches are moved to.
    """
    use_tangos = (params_ls['lambda_1_curr'] > 0) or (params_ls['lambda_2_curr'] > 0)
    model.train()
    for data, label in train_loader:
        data, label = data.to(device), label.to(device)
        if regulariser == 'input_noise':
            data = add_input_noise(data, params['std'])
        optimiser.zero_grad()
        output, _ = model(data)
        pred_loss = loss_func(output, label)

        # The penalties are accumulated. The l1 / mixup term used to be
        # assigned over the TANGOS term, so those two regularisers silently
        # trained without TANGOS.
        reg_loss = 0
        if use_tangos:
            sparsity_loss, correlation_loss = attr_loss(model, data, device=device, subsample=params_ls['subsample'])
            reg_loss = reg_loss + params_ls['lambda_1_curr'] * sparsity_loss \
                + params_ls['lambda_2_curr'] * correlation_loss

        if regulariser == 'l1':
            reg_loss = reg_loss + params['weight'] * l1(model)

        elif regulariser == 'mixup':
            X_mixup, y_a, y_b, lam = mixup_data(data, label, alpha=params['alpha'], device=device)
            output_mixup, _ = model(X_mixup)
            reg_loss = reg_loss + mixup_criterion(loss_func, output_mixup, y_a, y_b, lam)

        loss = pred_loss + reg_loss
        loss.backward()
        optimiser.step()
    return model


def evaluate(model, test_loader, loss_func, device=DEVICE):
    """Return the mean per-batch prediction loss of ``model`` on ``test_loader``.

    The value is returned twice, ``(pred_loss, loss)``, because no
    regularisation term is added at evaluation time.

    Raises:
        ValueError: if ``test_loader`` yields no batches (this used to surface
            as a ``NameError``).
    """
    model.eval()
    total_loss, num_batches = 0.0, 0
    # no gradients are needed for evaluation
    with torch.no_grad():
        for data, label in test_loader:
            data, label = data.to(device), label.to(device)
            output, _ = model(data)
            total_loss += loss_func(output, label).item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError('evaluate() received a loader without any batches')

    mean_loss = total_loss / num_batches
    return mean_loss, mean_loss
def seed_worker(worker_id):
    """Seed numpy and ``random`` from torch's initial seed (DataLoader ``worker_init_fn``)."""
    derived_seed = torch.initial_seed() % (2 ** 32)
    for seed_fn in (np.random.seed, random.seed):
        seed_fn(derived_seed)


def ids_to_dataloader_split(data, train_ids, val_ids, seed):
    """Build the train and validation loaders of one fold.

    Both loaders draw from ``data`` through a ``SubsetRandomSampler`` over the
    given indices and share one ``torch.Generator`` seeded with ``seed``.
    """
    g = torch.Generator()
    g.manual_seed(seed)

    def build_loader(indices):
        sampler = torch.utils.data.SubsetRandomSampler(indices)
        return torch.utils.data.DataLoader(
            data,
            batch_size=BATCH_SIZE, sampler=sampler, worker_init_fn=seed_worker, generator=g)

    trainloader = build_loader(train_ids)
    valloader = build_loader(val_ids)
    return trainloader, valloader


def init_results(tag, seed, datasets, config_regs, latent_sniper_regs, overwrite=False):
    """Create and save the empty results structure of one ``tag`` / ``seed``.

    For every dataset and regulariser one entry ``{'val_loss': []}`` is made
    per combination of the regulariser grid and the TANGOS grid.

    Raises:
        ValueError: if results already exist and ``overwrite`` is false.
    """
    file_name = f'experiment_{tag}_seed_{seed}'
    results = load_results(file_name)

    def grid_size(grid):
        size = 1
        for candidates in grid.values():
            size *= len(candidates)
        return size

    num_combinations_ls = grid_size(latent_sniper_regs)

    if results and not overwrite:
        raise ValueError('Results already exist, to overwrite pass overwrite = True')

    results = {
        dataset: {
            regulariser: {
                i: {'val_loss': []}
                for i in range(grid_size(config_regs[regulariser]) * num_combinations_ls)
            }
            for regulariser in config_regs.keys()
        }
        for dataset in datasets
    }
    save_results(results, file_name)


def load_results(file_name):
    """Load ``results/<file_name>.json`` next to this file, or ``{}`` if it is missing."""
    results_dir = os.path.join(os.path.dirname(__file__), f'results/{file_name}.json')
    if not Path(results_dir).is_file():
        print(f'{file_name}.json not found in results folder, generating new file.')
        return {}
    with open(results_dir) as f:
        return json.load(f)
def save_results(results, file_name):
    """Write ``results`` as JSON to ``results/<file_name>.json`` next to this file.

    The ``results`` directory is created when it does not exist yet.
    """
    results_dir = os.path.join(os.path.dirname(__file__), 'results')
    os.makedirs(results_dir, exist_ok=True)
    with open(os.path.join(results_dir, f'{file_name}.json'), 'w') as f:
        json.dump(results, f)


def run_fold(fold_name, model, trainloader, valloader, config, config_ls, config_dataset, seed):
    """Train ``model`` on one fold with early stopping and log the best validation loss.

    Args:
        fold_name: ``'<tag>:<dataset>:<regulariser>:<param index>:<fold>'``.
        model: freshly initialised network for this fold.
        trainloader: training batches of the fold.
        valloader: validation batches of the fold.
        config: hyperparameters of the regulariser (``weight`` is used as Adam
            weight decay for ``l2``).
        config_ls: TANGOS settings; reads ``lambda_1``, ``lambda_2`` and
            ``param_schedule`` and forwards the rest to ``train_epoch``.
        config_dataset: dataset entry; reads ``type`` and ``lr``.
        seed: seed identifying the results file.

    Returns:
        ``(best_val_loss, best_model, last_update)`` where ``last_update`` is
        the epoch at which the best model was taken.
    """
    tag, dataset, regulariser, params, fold = fold_name.split(':')
    loss_func = MSE if config_dataset['type'] == 'regression' else nn.CrossEntropyLoss()
    l2_weight = config['weight'] if regulariser == 'l2' else 0
    optimiser = optim.Adam(model.parameters(), lr=config_dataset['lr'], weight_decay=l2_weight)
    parameter_scheduler = parameter_schedule(config_ls['lambda_1'], config_ls['lambda_2'], config_ls['param_schedule'])

    # The per-epoch weights go into a private copy, so the grid entry shared
    # with the caller is not polluted with ``lambda_*_curr`` keys.
    epoch_config = dict(config_ls)
    best_val_loss = np.inf
    best_model = None
    last_update = 0
    for epoch in range(EPOCHS):
        lambda_1, lambda_2 = parameter_scheduler.get_reg(epoch)
        epoch_config['lambda_1_curr'] = lambda_1
        epoch_config['lambda_2_curr'] = lambda_2

        model = train_epoch(model, trainloader, optimiser, loss_func, config, epoch_config, regulariser, device=DEVICE)
        val_loss, _ = evaluate(model, valloader, loss_func, device=DEVICE)

        # during the first five epochs the latest model is always taken
        if (val_loss < best_val_loss) or (epoch < 5):
            best_val_loss = val_loss
            best_model = copy.deepcopy(model)
            last_update = epoch

        # early stopping criteria
        if epoch - last_update == TRAINING_PATIENCE:
            break

    # save best model results for this fold
    results = load_results(f'experiment_{tag}_seed_{seed}')
    results[dataset][regulariser][params]['val_loss'].append(best_val_loss)
    save_results(results, f'experiment_{tag}_seed_{seed}')

    return best_val_loss, best_model, last_update
def run_cv(config_dataset: dict, regulariser: str, params: dict, params_ls: dict, run_name: str, seed: int):
    """Cross-validate one regulariser + TANGOS setting and record its test loss.

    A fresh ``UCI_MLP`` is trained on every fold through ``run_fold``. The fold
    model with the lowest validation loss is evaluated on the held-out test
    loader and the outcome is written to ``experiment_<tag>_seed_<seed>``.

    Args:
        config_dataset: dataset entry of ``datasets.json``; the keys read here
            are ``loader``, ``num_features``, ``num_outputs`` and ``type``.
        regulariser: name of the regulariser combined with TANGOS.
        params: hyperparameters of that regulariser.
        params_ls: TANGOS hyperparameters for this run.
        run_name: ``'<tag>:<dataset>:<regulariser>:<param index>'``.
        seed: seed for torch / numpy and for the fold data loaders.
    """
    data_fetcher = getattr(src.load_data, config_dataset['loader'])
    # NOTE(review): the train/test split is always drawn with seed 0, not
    # ``seed`` - presumably so that all runs share one test set; confirm.
    loaders = data_fetcher(seed=0)
    dropout = params['p'] if regulariser == 'dropout' else 0
    batch_norm = regulariser == 'batch_norm'
    kfold = KFold(n_splits=K_FOLDS, shuffle=False)

    best_loss, best_model, best_epoch = np.inf, None, None
    # loop through folds
    for fold, (train_ids, val_ids) in enumerate(kfold.split(loaders['train'])):
        torch.manual_seed(seed)
        np.random.seed(seed)
        trainloader, valloader = ids_to_dataloader_split(loaders['train'], train_ids, val_ids, seed=seed)
        fold_name = f'{run_name}:{fold}'
        model = UCI_MLP(num_features=config_dataset['num_features'],
                        num_outputs=config_dataset['num_outputs'],
                        dropout=dropout, batch_norm=batch_norm).to(DEVICE)
        fold_loss, fold_model, fold_epoch = run_fold(fold_name, model, trainloader, valloader, params, params_ls,
                                                     config_dataset, seed=seed)
        # The first fold is always kept as a fallback and a NaN incumbent is
        # replaced by any finite loss, so a diverged fold can no longer leave
        # ``best_model`` undefined when the test evaluation is reached.
        replaces_nan = np.isnan(best_loss) and not np.isnan(fold_loss)
        if best_model is None or fold_loss < best_loss or replaces_nan:
            best_loss = fold_loss
            best_model = copy.deepcopy(fold_model)
            best_epoch = fold_epoch

    # evaluate best performing model on held out test set
    loss_func = MSE if config_dataset['type'] == 'regression' else nn.CrossEntropyLoss()
    test_loss, _ = evaluate(best_model, loaders['test'], loss_func)
    # the result file is keyed by the string components of the run name
    tag, dataset, regulariser_key, params_key = run_name.split(':')
    results = load_results(f'experiment_{tag}_seed_{seed}')
    entry = results[dataset][regulariser_key][params_key]
    entry['test_loss'] = test_loss
    entry['train_final_epoch'] = best_epoch
    print(test_loss)
    save_results(results, f'experiment_{tag}_seed_{seed}')


def grid_search_iterable(parameter_dict: dict) -> list:
    """Expand a ``{name: [candidate values]}`` mapping into a list of settings.

    Every returned dict assigns one candidate to each name; the last name
    varies fastest. An empty mapping yields the single empty setting ``[{}]``
    (previously this raised ``ValueError``).
    """
    keys = list(parameter_dict)
    candidates = (parameter_dict[key] for key in keys)
    return [dict(zip(keys, combination)) for combination in itertools.product(*candidates)]
def load_config(name):
    """Read ``configs/<name>.json`` located next to this file.

    Returns:
        ``(config_dict, config_keys)`` where ``config_keys`` is the list of
        top-level keys in file order.
    """
    curr_dir = os.path.dirname(__file__)
    config_dir = os.path.join(curr_dir, 'configs', f'{name}.json')
    with open(config_dir) as f:
        config_dict = json.load(f)
    config_keys = list(config_dict)
    return config_dict, config_keys


def run_experiment(seeds: list, tag: str):
    """Run every regulariser in tandem with every TANGOS setting, per seed and dataset.

    The ``TANGOS`` entry of ``regularizers.json`` is split off and its grid is
    crossed with the grid of each remaining regulariser; ``idx`` numbers the
    combined settings. The parameters of every run are stored in
    ``params_record`` under ``id_:<seed>:<run name>``.

    Args:
        seeds: integer seeds to repeat the experiment with.
        tag: name that prefixes the result files of this set of experiments.

    Raises:
        ValueError: if ``regularizers.json`` has no ``TANGOS`` entry.
    """
    # load config files
    config_regs, regularisers = load_config('regularizers')
    config_data, datasets = load_config('datasets')
    latent_sniper_regs = config_regs.pop('TANGOS', None)
    if latent_sniper_regs is None:
        raise ValueError("regularizers.json needs a 'TANGOS' entry for in-tandem experiments")
    regularisers = [name for name in regularisers if name != 'TANGOS']
    latent_sniper_iterable = grid_search_iterable(latent_sniper_regs)
    for seed in seeds:
        # initialise results file
        init_results(tag, seed, datasets, config_regs, latent_sniper_regs, overwrite=True)
        for dataset in datasets:
            for regulariser in regularisers:
                parameter_iterable = grid_search_iterable(config_regs[regulariser])
                idx = 0
                for param_set in parameter_iterable:
                    for param_set_ls in latent_sniper_iterable:
                        run_name = f'{tag}:{dataset}:{regulariser}:{idx}'
                        # run CV on this combination
                        print(run_name)
                        run_cv(config_data[dataset], regulariser, param_set, param_set_ls, run_name, seed)
                        # Save record of parameters used for this run. The
                        # TANGOS settings are stored as well: ``idx`` counts
                        # combinations of both grids, so without them several
                        # run ids had identical records.
                        record = dict(param_set)
                        record['TANGOS'] = {key: param_set_ls[key] for key in latent_sniper_regs}
                        param_record = load_results('params_record')
                        param_record[f'id_:{seed}:{run_name}'] = record
                        save_results(param_record, 'params_record')
                        idx += 1


def _parse_args(argv=None):
    """Parse the command line (``argv=None`` reads ``sys.argv``).

    ``-seeds`` accepts one or more integers. Without ``nargs``/``type`` a value
    such as ``-seeds 12`` arrived as the string ``'12'`` and was then iterated
    character by character by ``run_experiment``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-seeds', nargs='+', type=int, default=[0],
                        help='Set of seeds to use for experiments')
    parser.add_argument('-tag', default='tag', help='Tag name for set of experiments')
    return parser.parse_args(argv)


if __name__ == '__main__':
    args = _parse_args()
    print(args.seeds)
    run_experiment(seeds=args.seeds, tag=args.tag)
def get_path():
    """Get path to data dir.

    The value of ``path_to_data`` is read from ``src/data_config.json``
    relative to the working directory; when that file is absent the
    ``data_config.json`` next to this module is used instead, so the loaders
    also work when the process is not started from the repository root.
    """
    config_path = 'src/data_config.json'
    if not os.path.isfile(config_path):
        config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data_config.json')
    with open(config_path) as f:
        results = json.load(f)
    return results['path_to_data']


def load_wine(seed, train_prop=0.8, batch_size=64):
    """Load the first 1000 rows of ``winequality-red.csv`` (target ``quality``).

    Features and target are standardised with statistics of the training split.

    Args:
        seed: random state of the train/test split.
        train_prop: fraction of rows used for training.
        batch_size: batch size of the test loader.

    Returns:
        dict with ``'train'`` (a ``TensorDataset``) and ``'test'`` (a
        non-shuffled ``DataLoader``).
    """
    # os.path.join works whether or not the configured folder ends with a slash
    data = pd.read_csv(os.path.join(get_path(), 'winequality-red.csv'))
    data = data[:1000]

    X = data.drop('quality', axis=1)
    y = data.quality
    X, y = X.to_numpy(), y.to_numpy()
    X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=train_prop, random_state=seed)

    X_scaler = StandardScaler()
    X_train = X_scaler.fit_transform(X_train)
    X_test = X_scaler.transform(X_test)
    y_scaler = StandardScaler()
    y_train = y_scaler.fit_transform(y_train.reshape(-1, 1)).reshape(-1)
    y_test = y_scaler.transform(y_test.reshape(-1, 1)).reshape(-1)

    train_dataset = TensorDataset(torch.Tensor(X_train), torch.Tensor(y_train))
    test_dataset = TensorDataset(torch.Tensor(X_test), torch.Tensor(y_test))

    loaders = {
        'train': train_dataset,
        'test': DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=1),
    }
    return loaders
def load_facebook(seed, train_prop=0.8, batch_size=64):
    """Load ``dataset_facebook.csv`` (target ``Lifetime Post Total Impressions``).

    Rows with missing values are dropped and ``Type`` is one-hot encoded.
    Features and target are standardised with statistics of the training split.

    Args:
        seed: random state of the train/test split.
        train_prop: fraction of rows used for training.
        batch_size: batch size of the test loader.

    Returns:
        dict with ``'train'`` (a ``TensorDataset``) and ``'test'`` (a
        non-shuffled ``DataLoader``).
    """
    # os.path.join works whether or not the configured folder ends with a slash
    data = pd.read_csv(os.path.join(get_path(), 'dataset_facebook.csv'), sep=';')
    data.dropna(inplace=True)  # drop missing values
    one_hot = pd.get_dummies(data['Type'])  # onehotencode categorical column
    data = data.drop('Type', axis=1)
    data = data.join(one_hot)
    X = data.drop('Lifetime Post Total Impressions', axis=1)
    y = data['Lifetime Post Total Impressions']
    X, y = X.to_numpy(), y.to_numpy()

    X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=train_prop, random_state=seed)

    X_scaler = StandardScaler()
    X_train = X_scaler.fit_transform(X_train)
    X_test = X_scaler.transform(X_test)
    y_scaler = StandardScaler()
    y_train = y_scaler.fit_transform(y_train.reshape(-1, 1)).reshape(-1)
    y_test = y_scaler.transform(y_test.reshape(-1, 1)).reshape(-1)

    train_dataset = TensorDataset(torch.Tensor(X_train), torch.Tensor(y_train))
    test_dataset = TensorDataset(torch.Tensor(X_test), torch.Tensor(y_test))

    loaders = {
        'train': train_dataset,
        'test': DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=1),
    }
    return loaders


def load_bioconcentration(seed, train_prop=0.8, batch_size=64):
    """Load ``Grisoni_et_al_2016_EnvInt88.csv`` (target ``logBCF``).

    Nine descriptor columns are used; ``nHM``, ``N-072``, ``B02[C-N]`` and
    ``F04[C-O]`` are one-hot encoded. Features and target are standardised
    with statistics of the training split.

    Args:
        seed: random state of the train/test split.
        train_prop: fraction of rows used for training.
        batch_size: batch size of the test loader.

    Returns:
        dict with ``'train'`` (a ``TensorDataset``) and ``'test'`` (a
        non-shuffled ``DataLoader``).
    """
    data = pd.read_csv(os.path.join(get_path(), 'Grisoni_et_al_2016_EnvInt88.csv'), sep=',')

    X = data[['nHM', 'piPC09', 'PCD', 'X2Av', 'MLOGP', 'ON1V', 'N-072', 'B02[C-N]', 'F04[C-O]']]
    for var in ['nHM', 'N-072', 'B02[C-N]', 'F04[C-O]']:
        one_hot = pd.get_dummies(X[var], prefix=var)  # onehotencode categorical column
        X = X.drop(var, axis=1)
        X = X.join(one_hot)

    y = data['logBCF']
    X, y = X.to_numpy(), y.to_numpy()

    X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=train_prop, random_state=seed)

    X_scaler = StandardScaler()
    X_train = X_scaler.fit_transform(X_train)
    X_test = X_scaler.transform(X_test)
    y_scaler = StandardScaler()
    y_train = y_scaler.fit_transform(y_train.reshape(-1, 1)).reshape(-1)
    y_test = y_scaler.transform(y_test.reshape(-1, 1)).reshape(-1)

    train_dataset = TensorDataset(torch.Tensor(X_train), torch.Tensor(y_train))
    test_dataset = TensorDataset(torch.Tensor(X_test), torch.Tensor(y_test))

    loaders = {
        'train': train_dataset,
        'test': DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=1),
    }
    return loaders
def load_student(seed, train_prop=0.8, batch_size=64):
    """Load ``student-por.csv`` (target ``G3``; ``G1`` and ``G2`` are dropped).

    The categorical columns are one-hot encoded. Features and target are
    standardised with statistics of the training split.

    Args:
        seed: random state of the train/test split.
        train_prop: fraction of rows used for training.
        batch_size: batch size of the test loader.

    Returns:
        dict with ``'train'`` (a ``TensorDataset``) and ``'test'`` (a
        non-shuffled ``DataLoader``).
    """
    # os.path.join works whether or not the configured folder ends with a slash
    data = pd.read_csv(os.path.join(get_path(), 'student-por.csv'), sep=';')

    X = data.drop(['G1', 'G2', 'G3'], axis=1)
    categorical = ['school', 'sex', 'address', 'famsize', 'Pstatus', 'Mjob', 'Fjob',
                   'reason', 'guardian', 'schoolsup', 'famsup', 'paid', 'activities',
                   'nursery', 'higher', 'internet', 'romantic']
    for var in categorical:
        one_hot = pd.get_dummies(X[var], prefix=var)  # onehotencode categorical column
        X = X.drop(var, axis=1)
        X = X.join(one_hot)

    y = data['G3']
    X, y = X.to_numpy(), y.to_numpy()

    X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=train_prop, random_state=seed)

    X_scaler = StandardScaler()
    X_train = X_scaler.fit_transform(X_train)
    X_test = X_scaler.transform(X_test)
    y_scaler = StandardScaler()
    y_train = y_scaler.fit_transform(y_train.reshape(-1, 1)).reshape(-1)
    y_test = y_scaler.transform(y_test.reshape(-1, 1)).reshape(-1)

    train_dataset = TensorDataset(torch.Tensor(X_train), torch.Tensor(y_train))
    test_dataset = TensorDataset(torch.Tensor(X_test), torch.Tensor(y_test))

    loaders = {
        'train': train_dataset,
        'test': DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=1),
    }
    return loaders
def load_abalone(seed, train_prop=0.8, batch_size=64):
    """Load the first 1000 rows of ``abalone.data`` (target: column 8).

    Column 0 is one-hot encoded with the first level dropped. Features and
    target are standardised with statistics of the training split.

    Args:
        seed: random state of the train/test split.
        train_prop: fraction of rows used for training.
        batch_size: batch size of the test loader.

    Returns:
        dict with ``'train'`` (a ``TensorDataset``) and ``'test'`` (a
        non-shuffled ``DataLoader``).
    """
    # os.path.join works whether or not the configured folder ends with a slash
    data = pd.read_csv(os.path.join(get_path(), 'abalone.data'), sep=',', header=None)
    data = data[:1000]
    one_hot = pd.get_dummies(data[0], drop_first=True)  # onehotencode categorical column
    data = data.drop(0, axis=1)
    data = data.join(one_hot)

    X = data.drop(8, axis=1)
    y = data[8]

    X, y = X.to_numpy(), y.to_numpy()

    X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=train_prop, random_state=seed)

    X_scaler = StandardScaler()
    X_train = X_scaler.fit_transform(X_train)
    X_test = X_scaler.transform(X_test)
    y_scaler = StandardScaler()
    y_train = y_scaler.fit_transform(y_train.reshape(-1, 1)).reshape(-1)
    y_test = y_scaler.transform(y_test.reshape(-1, 1)).reshape(-1)

    train_dataset = TensorDataset(torch.Tensor(X_train), torch.Tensor(y_train))
    test_dataset = TensorDataset(torch.Tensor(X_test), torch.Tensor(y_test))

    loaders = {
        'train': train_dataset,
        'test': DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=1),
    }
    return loaders


def load_skillcraft(seed, train_prop=0.8, batch_size=64):
    """Load ``SkillCraft1_Dataset.csv`` (target ``LeagueIndex``).

    ``'?'`` marks missing values; rows containing one are dropped and the
    first 1000 remaining rows are kept. ``GameID`` is not used as a feature.
    Features and target are standardised with statistics of the training split.

    Args:
        seed: random state of the train/test split.
        train_prop: fraction of rows used for training.
        batch_size: batch size of the test loader.

    Returns:
        dict with ``'train'`` (a ``TensorDataset``) and ``'test'`` (a
        non-shuffled ``DataLoader``).
    """
    data = pd.read_csv(os.path.join(get_path(), 'SkillCraft1_Dataset.csv'), sep=',')
    # ``np.nan`` instead of the ``np.NaN`` alias, which NumPy 2 removed
    data = data.replace('?', np.nan)
    data = data.dropna()
    data = data[:1000]
    X = data.drop(['GameID', 'LeagueIndex'], axis=1)
    y = data['LeagueIndex']

    X, y = X.to_numpy(), y.to_numpy()

    X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=train_prop, random_state=seed)

    X_scaler = StandardScaler()
    X_train = X_scaler.fit_transform(X_train)
    X_test = X_scaler.transform(X_test)
    y_scaler = StandardScaler()
    y_train = y_scaler.fit_transform(y_train.reshape(-1, 1)).reshape(-1)
    y_test = y_scaler.transform(y_test.reshape(-1, 1)).reshape(-1)

    train_dataset = TensorDataset(torch.Tensor(X_train), torch.Tensor(y_train))
    test_dataset = TensorDataset(torch.Tensor(X_test), torch.Tensor(y_test))

    loaders = {
        'train': train_dataset,
        'test': DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=1),
    }
    return loaders
def load_weather(seed, train_prop=0.8, batch_size=64):
    """Load ``Bias_correction_ucl.csv`` (target ``Next_Tmax``).

    Rows with missing values are dropped, the first 1000 remaining rows are
    kept and ``station`` is one-hot encoded with the first level dropped.
    ``Date`` and ``Next_Tmin`` are not used as features. Features and target
    are standardised with statistics of the training split.

    Args:
        seed: random state of the train/test split.
        train_prop: fraction of rows used for training.
        batch_size: batch size of the test loader.

    Returns:
        dict with ``'train'`` (a ``TensorDataset``) and ``'test'`` (a
        non-shuffled ``DataLoader``).
    """
    # os.path.join works whether or not the configured folder ends with a slash
    data = pd.read_csv(os.path.join(get_path(), 'Bias_correction_ucl.csv'), sep=',')
    data = data.dropna()
    data = data[:1000]
    one_hot = pd.get_dummies(data['station'], drop_first=True)  # onehotencode categorical column
    data = data.drop('station', axis=1)
    data = data.join(one_hot)
    X = data.drop(['Date', 'Next_Tmax', 'Next_Tmin'], axis=1)
    y = data['Next_Tmax']

    X, y = X.to_numpy(), y.to_numpy()

    X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=train_prop, random_state=seed)

    X_scaler = StandardScaler()
    X_train = X_scaler.fit_transform(X_train)
    X_test = X_scaler.transform(X_test)
    y_scaler = StandardScaler()
    y_train = y_scaler.fit_transform(y_train.reshape(-1, 1)).reshape(-1)
    y_test = y_scaler.transform(y_test.reshape(-1, 1)).reshape(-1)

    train_dataset = TensorDataset(torch.Tensor(X_train), torch.Tensor(y_train))
    test_dataset = TensorDataset(torch.Tensor(X_test), torch.Tensor(y_test))

    loaders = {
        'train': train_dataset,
        'test': DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=1),
    }
    return loaders


def load_forest(seed, train_prop=0.8, batch_size=64):
    """Load ``forestfires.csv`` (target ``log(area + 1)``).

    ``X``, ``Y``, ``month`` and ``day`` are one-hot encoded with the first
    level dropped. Features and target are standardised with statistics of the
    training split.

    Args:
        seed: random state of the train/test split.
        train_prop: fraction of rows used for training.
        batch_size: batch size of the test loader.

    Returns:
        dict with ``'train'`` (a ``TensorDataset``) and ``'test'`` (a
        non-shuffled ``DataLoader``).
    """
    data = pd.read_csv(os.path.join(get_path(), 'forestfires.csv'), sep=',')
    X = data.drop('area', axis=1)
    for var in ['X', 'Y', 'month', 'day']:
        one_hot = pd.get_dummies(X[var], prefix=var, drop_first=True)  # onehotencode categorical column
        X = X.drop(var, axis=1)
        X = X.join(one_hot)
    y = np.log(data['area'] + 1)

    X, y = X.to_numpy(), y.to_numpy()

    X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=train_prop, random_state=seed)

    X_scaler = StandardScaler()
    X_train = X_scaler.fit_transform(X_train)
    X_test = X_scaler.transform(X_test)
    y_scaler = StandardScaler()
    y_train = y_scaler.fit_transform(y_train.reshape(-1, 1)).reshape(-1)
    y_test = y_scaler.transform(y_test.reshape(-1, 1)).reshape(-1)

    train_dataset = TensorDataset(torch.Tensor(X_train), torch.Tensor(y_train))
    test_dataset = TensorDataset(torch.Tensor(X_test), torch.Tensor(y_test))

    loaders = {
        'train': train_dataset,
        'test': DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=1),
    }
    return loaders
def load_heart(seed, train_prop=0.8, batch_size=64):
    """Build the train/test data for the classification task in ``heart.dat``.

    Reads the space-separated, header-less ``heart.dat`` from the folder
    returned by ``get_path()``. Column 13 minus 1 is the label, column 9 is
    replaced by ``log(x + 1)``, and columns 1, 2, 5, 6, 8, 10, 11 and 12 are
    one-hot encoded with the first level dropped.

    Args:
        seed: ``random_state`` passed to ``train_test_split``.
        train_prop: Fraction of rows assigned to the training split.
        batch_size: Batch size of the test ``DataLoader``.

    Returns:
        dict with two entries: ``'train'`` is the training ``TensorDataset``
        (not wrapped in a ``DataLoader``) and ``'test'`` is an unshuffled
        ``DataLoader`` over the test ``TensorDataset``. Labels are stored as
        ``torch.long`` tensors.
    """
    data = pd.read_csv(get_path() + 'heart.dat', sep=' ', header=None)
    # The transform has to happen before X is derived: DataFrame.drop returns
    # a new frame, so a later change to `data` never reaches the features.
    data[9] = np.log(data[9] + 1)
    y = data[13] - 1
    X = data.drop(13, axis=1)
    for var in [1, 2, 5, 6, 8, 10, 11, 12]:
        one_hot = pd.get_dummies(X[var], prefix=var, drop_first=True)  # onehotencode categorical column
        X = X.drop(var, axis=1)
        X = X.join(one_hot)

    X, y = X.to_numpy(), y.to_numpy()

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, train_size=train_prop, random_state=seed)

    # The scaler is fitted on the training split only and then applied to the
    # test split; class labels are left untouched.
    X_scaler = StandardScaler()
    X_train = X_scaler.fit_transform(X_train)
    X_test = X_scaler.transform(X_test)

    train_dataset = TensorDataset(torch.Tensor(X_train), torch.tensor(y_train, dtype=torch.long))
    test_dataset = TensorDataset(torch.Tensor(X_test), torch.tensor(y_test, dtype=torch.long))

    loaders = {
        'train': train_dataset,

        'test': DataLoader(test_dataset,
                           batch_size=batch_size,
                           shuffle=False,
                           num_workers=1)
    }
    return loaders
Contraceptives', 'Hormonal Contraceptives (years)', 'IUD', 386 | 'IUD (years)', 'STDs', 'STDs (number)', 'STDs:condylomatosis', 387 | 'STDs:cervical condylomatosis', 'STDs:vaginal condylomatosis', 388 | 'STDs:vulvo-perineal condylomatosis', 'STDs:syphilis', 389 | 'STDs:pelvic inflammatory disease', 'STDs:genital herpes', 390 | 'STDs:molluscum contagiosum', 'STDs:AIDS', 'STDs:HIV', 391 | 'STDs:Hepatitis B', 'STDs:HPV', 'STDs: Number of diagnosis', 392 | 'STDs: Time since first diagnosis', 'STDs: Time since last diagnosis', 393 | 'Dx:Cancer', 'Dx:CIN', 'Dx:HPV', 'Dx']] 394 | X = X.replace('?', np.nan) 395 | X.fillna((X.median()), inplace=True) 396 | mapping = {'Hinselmann': 1, 'Schiller': 2, 'Citology': 3, 'Biopsy': 4} 397 | y = data[['Hinselmann', 'Schiller', 'Citology', 'Biopsy']].idxmax(axis=1) 398 | y = y.replace(mapping) 399 | 400 | for var in ['Smokes', 'Smokes (years)', 'Smokes (packs/year)', 'Hormonal Contraceptives', 401 | 'IUD', 'STDs', 'STDs:condylomatosis', 'STDs:cervical condylomatosis', 402 | 'STDs:vaginal condylomatosis', 'STDs:vulvo-perineal condylomatosis', 403 | 'STDs:syphilis','STDs:pelvic inflammatory disease', 'STDs:genital herpes', 404 | 'STDs:molluscum contagiosum', 'STDs:AIDS', 'STDs:HIV', 'STDs:Hepatitis B', 405 | 'STDs:HPV', 'Dx:Cancer', 'Dx:CIN', 'Dx:HPV', 'Dx']: 406 | one_hot = pd.get_dummies(X[var], prefix=var, drop_first=True) # onehotencode categorical column 407 | X = X.drop(var, axis=1) 408 | X = X.join(one_hot) 409 | 410 | X, y = X.to_numpy(), y.to_numpy() 411 | 412 | X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=train_prop, random_state=seed) 413 | 414 | X_scaler = StandardScaler() 415 | X_train = X_scaler.fit_transform(X_train) 416 | X_test = X_scaler.transform(X_test) 417 | 418 | train_dataset = TensorDataset(torch.Tensor(X_train), torch.tensor(y_train, dtype=torch.long)) 419 | test_dataset = TensorDataset(torch.Tensor(X_test), torch.tensor(y_test, dtype=torch.long)) 420 | 421 | loaders = { 422 | 'train': 
def load_credit(seed, train_prop=0.8, batch_size=64):
    """Build the train/test data for the classification task in ``crx.data``.

    Reads the header-less ``crx.data`` from the folder returned by
    ``get_path()`` and drops rows whose column 13 is ``'?'``. Columns 13, 14
    and 7 are replaced by ``log(x + 1)``, ``'?'`` entries of column 1 are
    filled with that column's median, and columns 0, 3, 4, 5, 6, 8, 9, 11 and
    12 are one-hot encoded with the first level dropped. Column 15 is the
    label, mapped ``'+' -> 1`` and ``'-' -> 0``.

    Args:
        seed: ``random_state`` passed to ``train_test_split``.
        train_prop: Fraction of rows assigned to the training split.
        batch_size: Batch size of the test ``DataLoader``.

    Returns:
        dict with two entries: ``'train'`` is the training ``TensorDataset``
        (not wrapped in a ``DataLoader``) and ``'test'`` is an unshuffled
        ``DataLoader`` over the test ``TensorDataset``. Labels are stored as
        ``torch.long`` tensors.
    """
    data = pd.read_csv(get_path() + "crx.data", header=None)
    # Work on an explicit copy of the filtered rows so the column assignments
    # below modify a frame we own instead of a slice of the frame read from
    # disk.
    data = data[data[13] != '?'].copy()
    data[13] = np.log(data[13].astype(int) + 1)
    data[14] = np.log(data[14] + 1)
    data[7] = np.log(data[7] + 1)
    # Assign results back instead of calling replace/fillna with inplace=True
    # on a column selection: the chained form depends on the selection sharing
    # memory with its parent frame, which pandas does not guarantee.
    data[1] = data[1].replace('?', np.nan)
    data[1] = data[1].fillna(data[1].median())
    X = data.drop(15, axis=1)
    mapping = {'+': 1, '-': 0}
    y = data[15].replace(mapping)
    for var in [0, 3, 4, 5, 6, 8, 9, 11, 12]:
        one_hot = pd.get_dummies(X[var], prefix=var, drop_first=True)  # onehotencode categorical column
        X = X.drop(var, axis=1)
        X = X.join(one_hot)

    X, y = X.to_numpy(), y.to_numpy()

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, train_size=train_prop, random_state=seed)

    # The scaler is fitted on the training split only and then applied to the
    # test split; class labels are left untouched.
    X_scaler = StandardScaler()
    X_train = X_scaler.fit_transform(X_train)
    X_test = X_scaler.transform(X_test)

    train_dataset = TensorDataset(torch.Tensor(X_train), torch.tensor(y_train, dtype=torch.long))
    test_dataset = TensorDataset(torch.Tensor(X_test), torch.tensor(y_test, dtype=torch.long))

    loaders = {
        'train': train_dataset,

        'test': DataLoader(test_dataset,
                           batch_size=batch_size,
                           shuffle=False,
                           num_workers=1)
    }
    return loaders
def load_soybean(seed, train_prop=0.8, batch_size=64):
    """Build the train/test data for the soybean classification task.

    Concatenates the header-less ``soybean-large.data`` and
    ``soybean-large.test`` files from the folder returned by ``get_path()``.
    Column 0 is the label, converted to integer codes starting at 0 via a
    dense descending rank. Every remaining column is one-hot encoded with the
    first level dropped.

    Args:
        seed: ``random_state`` passed to ``train_test_split``.
        train_prop: Fraction of rows assigned to the training split.
        batch_size: Batch size of the test ``DataLoader``.

    Returns:
        dict with two entries: ``'train'`` is the training ``TensorDataset``
        (not wrapped in a ``DataLoader``) and ``'test'`` is an unshuffled
        ``DataLoader`` over the test ``TensorDataset``. Labels are stored as
        ``torch.long`` tensors.
    """
    data0 = pd.read_csv(get_path() + 'soybean-large.data', header=None)
    data1 = pd.read_csv(get_path() + 'soybean-large.test', header=None)
    data = pd.concat([data0, data1], axis=0)
    # drop=True discards the old per-file row labels. Without it they are
    # inserted as an 'index' column, which the loop below would one-hot encode
    # and hand to the model as if row numbers were features.
    data.reset_index(drop=True, inplace=True)
    y = data[0].rank(method='dense', ascending=False).astype(int) - 1
    X = data.drop(0, axis=1)
    # X is rebound inside the loop, so iterate over the column labels as they
    # are before any encoding happens.
    for var in X.columns:
        one_hot = pd.get_dummies(X[var], prefix=f'dum_{var}', drop_first=True)  # onehotencode categorical column
        X = X.drop(var, axis=1)
        X = X.join(one_hot)

    X, y = X.to_numpy(), y.to_numpy()

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, train_size=train_prop, random_state=seed)

    # The scaler is fitted on the training split only and then applied to the
    # test split; class labels are left untouched.
    X_scaler = StandardScaler()
    X_train = X_scaler.fit_transform(X_train)
    X_test = X_scaler.transform(X_test)

    train_dataset = TensorDataset(torch.Tensor(X_train), torch.tensor(y_train, dtype=torch.long))
    test_dataset = TensorDataset(torch.Tensor(X_test), torch.tensor(y_test, dtype=torch.long))

    loaders = {
        'train': train_dataset,

        'test': DataLoader(test_dataset,
                           batch_size=batch_size,
                           shuffle=False,
                           num_workers=1)
    }
    return loaders
def load_entrance(seed, train_prop=0.8, batch_size=64):
    """Build the train/test data for the classification task in ``CEE_DATA.arff``.

    Parses ``CEE_DATA.arff`` from the folder returned by ``get_path()``.
    Column 0 is the label, converted to integer codes starting at 0 via a
    dense descending rank. Every remaining column is one-hot encoded with the
    first level dropped.

    Args:
        seed: ``random_state`` passed to ``train_test_split``.
        train_prop: Fraction of rows assigned to the training split.
        batch_size: Batch size of the test ``DataLoader``.

    Returns:
        dict with two entries: ``'train'`` is the training ``TensorDataset``
        (not wrapped in a ``DataLoader``) and ``'test'`` is an unshuffled
        ``DataLoader`` over the test ``TensorDataset``. Labels are stored as
        ``torch.long`` tensors.
    """
    # Close the file as soon as it has been parsed instead of leaving the
    # handle open until garbage collection.
    with open(get_path() + 'CEE_DATA.arff') as arff_file:
        data_arff = arff.load(arff_file)
    data = pd.DataFrame(data_arff['data'])
    y = data[0].rank(method='dense', ascending=False).astype(int) - 1
    X = data.drop(0, axis=1)
    # X is rebound inside the loop, so iterate over the column labels as they
    # are before any encoding happens.
    for var in X.columns:
        one_hot = pd.get_dummies(X[var], prefix=var, drop_first=True)  # onehotencode categorical column
        X = X.drop(var, axis=1)
        X = X.join(one_hot)
    X, y = X.to_numpy(), y.to_numpy()

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, train_size=train_prop, random_state=seed)

    # The scaler is fitted on the training split only and then applied to the
    # test split; class labels are left untouched.
    X_scaler = StandardScaler()
    X_train = X_scaler.fit_transform(X_train)
    X_test = X_scaler.transform(X_test)

    train_dataset = TensorDataset(torch.Tensor(X_train), torch.tensor(y_train, dtype=torch.long))
    test_dataset = TensorDataset(torch.Tensor(X_test), torch.tensor(y_test, dtype=torch.long))

    loaders = {
        'train': train_dataset,

        'test': DataLoader(test_dataset,
                           batch_size=batch_size,
                           shuffle=False,
                           num_workers=1)
    }
    return loaders
| X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=train_prop, random_state=seed) 641 | 642 | X_scaler = StandardScaler() 643 | X_train = X_scaler.fit_transform(X_train) 644 | X_test = X_scaler.transform(X_test) 645 | 646 | train_dataset = TensorDataset(torch.Tensor(X_train), torch.tensor(y_train, dtype=torch.long)) 647 | test_dataset = TensorDataset(torch.Tensor(X_test), torch.tensor(y_test, dtype=torch.long)) 648 | 649 | loaders = { 650 | 'train': train_dataset, 651 | 652 | 'test': DataLoader(test_dataset, 653 | batch_size=batch_size, 654 | shuffle=False, 655 | num_workers=1) 656 | } 657 | return loaders -------------------------------------------------------------------------------- /TANGOS_quickstart.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [] 7 | }, 8 | "kernelspec": { 9 | "name": "python3", 10 | "display_name": "Python 3" 11 | }, 12 | "language_info": { 13 | "name": "python" 14 | }, 15 | "accelerator": "GPU", 16 | "gpuClass": "standard" 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "markdown", 21 | "source": [ 22 | "[![open in colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/alanjeffares/TANGOS/blob/main/TANGOS_quickstart.ipynb)" 23 | ], 24 | "metadata": { 25 | "id": "8mgNk6CZwB2r" 26 | } 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "source": [ 31 | "# TANGOS quickstart guide\n", 32 | "This script provides a simple example of applying TANGOS as a drop in regularizer in a standard pytorch workflow. We begin by defining a dataloader and a simple MLP architecture before providing a straightforward function for calculating the two TANGOS loss terms - specialization loss and orthogonalization loss. We then provide an example of this loss being applied to train a model in a standard training loop." 
33 | ], 34 | "metadata": { 35 | "id": "zorOJBNUG7Cr" 36 | } 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 1, 41 | "metadata": { 42 | "id": "2FGTU5LWGYXR" 43 | }, 44 | "outputs": [], 45 | "source": [ 46 | "# !pip install functorch\n", 47 | "from functorch import jacrev\n", 48 | "from functorch import vmap\n", 49 | "import pandas as pd\n", 50 | "from sklearn.model_selection import train_test_split\n", 51 | "from sklearn.preprocessing import StandardScaler\n", 52 | "from torch.utils.data import DataLoader, TensorDataset\n", 53 | "import torch\n", 54 | "import torch.nn as nn\n", 55 | "from torch import optim\n", 56 | "import numpy as np\n", 57 | "import matplotlib.pyplot as plt\n", 58 | "from typing import Callable" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "source": [ 64 | "# download a dataset from the UCI repository\n", 65 | "!wget https://archive.ics.uci.edu/ml/machine-learning-databases/00510/Grisoni_et_al_2016_EnvInt88.csv" 66 | ], 67 | "metadata": { 68 | "id": "uJ3SBBJmGkPX" 69 | }, 70 | "execution_count": null, 71 | "outputs": [] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "source": [], 76 | "metadata": { 77 | "id": "ZM3mM_rxGw7E" 78 | }, 79 | "execution_count": 2, 80 | "outputs": [] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "source": [ 85 | "## First define a simple data loader " 86 | ], 87 | "metadata": { 88 | "id": "1MJLYMjnGyFK" 89 | } 90 | }, 91 | { 92 | "cell_type": "code", 93 | "source": [ 94 | "# a data loader for the data\n", 95 | "def load_bioconcentration(seed, train_prop=0.8, batch_size=64):\n", 96 | " \"\"\"Returns dataloaders for the bioconcentration dataset\"\"\"\n", 97 | " data = pd.read_csv('Grisoni_et_al_2016_EnvInt88.csv', sep=',')\n", 98 | "\n", 99 | " # apply onehotencoding where appropriate\n", 100 | " X = data[['nHM', 'piPC09', 'PCD', 'X2Av', 'MLOGP', 'ON1V', 'N-072', 'B02[C-N]', 'F04[C-O]']]\n", 101 | " for var in ['nHM', 'N-072', 'B02[C-N]', 'F04[C-O]']:\n", 102 | " one_hot = 
pd.get_dummies(X[var], prefix=var) # onehotencode categorical columns\n", 103 | " X = X.drop(var, axis=1)\n", 104 | " X = X.join(one_hot)\n", 105 | "\n", 106 | " y = data['logBCF']\n", 107 | " X, y = X.to_numpy(), y.to_numpy()\n", 108 | "\n", 109 | " # split data\n", 110 | " X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=train_prop, random_state=seed)\n", 111 | "\n", 112 | " # rescale data\n", 113 | " X_scaler = StandardScaler()\n", 114 | " X_train = X_scaler.fit_transform(X_train)\n", 115 | " X_test = X_scaler.transform(X_test)\n", 116 | " y_scaler = StandardScaler()\n", 117 | " y_train = y_scaler.fit_transform(y_train.reshape(-1,1)).reshape(-1)\n", 118 | " y_test = y_scaler.transform(y_test.reshape(-1,1)).reshape(-1)\n", 119 | "\n", 120 | " train_dataset = TensorDataset(torch.Tensor(X_train), torch.Tensor(y_train))\n", 121 | " test_dataset = TensorDataset(torch.Tensor(X_test), torch.Tensor(y_test))\n", 122 | "\n", 123 | " loaders = {\n", 124 | " 'train': DataLoader(train_dataset,\n", 125 | " batch_size=batch_size,\n", 126 | " shuffle=True,\n", 127 | " num_workers=1),\n", 128 | "\n", 129 | " 'test': DataLoader(test_dataset,\n", 130 | " batch_size=batch_size,\n", 131 | " shuffle=False,\n", 132 | " num_workers=1)\n", 133 | " }\n", 134 | " return loaders\n", 135 | " \n" 136 | ], 137 | "metadata": { 138 | "id": "ACmrnDmbGxUN" 139 | }, 140 | "execution_count": 3, 141 | "outputs": [] 142 | }, 143 | { 144 | "cell_type": "markdown", 145 | "source": [ 146 | "# Next design a simple MLP architecture" 147 | ], 148 | "metadata": { 149 | "id": "b5nbQ-5ZdV7y" 150 | } 151 | }, 152 | { 153 | "cell_type": "code", 154 | "source": [ 155 | "class SimpleMLP(nn.Module):\n", 156 | " def __init__(self, num_features):\n", 157 | " super(SimpleMLP, self).__init__()\n", 158 | " d = num_features + 1\n", 159 | " num_outputs = 1\n", 160 | " self.fc1 = nn.Linear(num_features, d)\n", 161 | " self.bn1 = nn.BatchNorm1d(d)\n", 162 | " self.relu1 = nn.ReLU(inplace=False)\n", 163 
| " self.fc2 = nn.Linear(d, d)\n", 164 | " self.bn2 = nn.BatchNorm1d(d)\n", 165 | " self.relu2 = nn.ReLU(inplace=False)\n", 166 | " self.fc3 = nn.Linear(d, num_outputs)\n", 167 | "\n", 168 | " def forward(self, x):\n", 169 | " out = self.fc1(x)\n", 170 | " out = self.relu1(out)\n", 171 | " out = self.fc2(out)\n", 172 | " h_output = self.relu2(out)\n", 173 | " out = self.fc3(h_output)\n", 174 | " return out, h_output # note that we ensure the model outputs both predictions and a latent representation" 175 | ], 176 | "metadata": { 177 | "id": "aFjcO1xHHhvM" 178 | }, 179 | "execution_count": 4, 180 | "outputs": [] 181 | }, 182 | { 183 | "cell_type": "markdown", 184 | "source": [ 185 | "# Finally, we define a drop in function that calculates the TANGOS loss - outputting both the specialization and the orthogonalization components." 186 | ], 187 | "metadata": { 188 | "id": "S29uMPHFdkLR" 189 | } 190 | }, 191 | { 192 | "cell_type": "code", 193 | "source": [ 194 | "def TANGOS_loss(forward_func: Callable, data_input: torch.tensor, \n", 195 | " subsample: int = 50, device: str ='cpu'):\n", 196 | " \"\"\"\n", 197 | " A drop in function for calculating the TANGOS regularization loss. 
The loss\n", 198 | " consists of two components (specialization and orthogonalization) which are\n", 199 | " described in more detail in the main paper.\n", 200 | "\n", 201 | " Args:\n", 202 | " forward_func (Callable): The forward function from a pytorch model with \n", 203 | " an output tuple consisting of (_, latent_representation).\n", 204 | " data_input (torch.tensor): A batch of data.\n", 205 | " subsample (int): Number of pairs to subsample for the orthogonalization\n", 206 | " component.\n", 207 | " device (str): Indicating what device to run on.\n", 208 | "\n", 209 | " Returns:\n", 210 | " tuple containing the specialization loss and the orthogonalization loss\n", 211 | " both in torch tensor format.\n", 212 | " \"\"\"\n", 213 | "\n", 214 | " batch_size = data_input.shape[0]\n", 215 | " def wrapper(input_):\n", 216 | " \"\"\"A simple wrapper required by functorch\"\"\"\n", 217 | " _, h_out = forward_func(input_)\n", 218 | " return h_out\n", 219 | " data_input = data_input.clone().requires_grad_(True)\n", 220 | " jacobian = vmap(jacrev(wrapper))(data_input)\n", 221 | " neuron_attr = jacobian.swapaxes(0,1)\n", 222 | " h_dim = neuron_attr.shape[0]\n", 223 | " \n", 224 | " if len(neuron_attr.shape) > 3:\n", 225 | " # h_dim x batch_size x features\n", 226 | " neuron_attr = neuron_attr.flatten(start_dim=2)\n", 227 | "\n", 228 | " # calculate specialization loss component\n", 229 | " spec_loss = torch.norm(neuron_attr, p=1)/(batch_size*h_dim*neuron_attr.shape[2])\n", 230 | "\n", 231 | " cos = nn.CosineSimilarity(dim=1, eps=1e-6) \n", 232 | " orth_loss = torch.tensor(0., requires_grad=True).to(device)\n", 233 | " \n", 234 | " # apply subsampling routine for orthogonalization loss\n", 235 | " if subsample > 0 and subsample < h_dim*(h_dim-1)/2:\n", 236 | " tensor_pairs = [list(np.random.choice(h_dim, size=(2), replace=False)) for i in range(subsample)]\n", 237 | " for tensor_pair in tensor_pairs:\n", 238 | " pairwise_corr = cos(neuron_attr[tensor_pair[0], :, :], 
\n", 239 | " neuron_attr[tensor_pair[1], :, :]).norm(p=1)\n", 240 | " orth_loss = orth_loss + pairwise_corr\n", 241 | "\n", 242 | " orth_loss = orth_loss/(batch_size*subsample)\n", 243 | "\n", 244 | " else:\n", 245 | " for neuron_i in range(1, h_dim):\n", 246 | " for neuron_j in range(0, neuron_i):\n", 247 | " pairwise_corr = cos(neuron_attr[neuron_i, :, :],\n", 248 | " neuron_attr[neuron_j, :, :]).norm(p=1)\n", 249 | " orth_loss = orth_loss + pairwise_corr\n", 250 | " num_pairs = h_dim*(h_dim-1)/2\n", 251 | " orth_loss = orth_loss/(batch_size*num_pairs)\n", 252 | "\n", 253 | " return spec_loss, orth_loss\n" 254 | ], 255 | "metadata": { 256 | "id": "8FEPP5KKH9bt" 257 | }, 258 | "execution_count": 5, 259 | "outputs": [] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "source": [], 264 | "metadata": { 265 | "id": "S5Yr13ebH3p-" 266 | }, 267 | "execution_count": 5, 268 | "outputs": [] 269 | }, 270 | { 271 | "cell_type": "markdown", 272 | "source": [ 273 | "# Train a model with TANGOS regularization and another with L2 regularization. 
" 274 | ], 275 | "metadata": { 276 | "id": "Rr1EkGx4okOr" 277 | } 278 | }, 279 | { 280 | "cell_type": "code", 281 | "source": [ 282 | "# set seed for reproducablility\n", 283 | "torch.manual_seed(0)\n", 284 | "torch.cuda.manual_seed(0)\n", 285 | "np.random.seed(0)\n", 286 | "\n", 287 | "loss_func = nn.MSELoss()\n", 288 | "data = load_bioconcentration(0)\n", 289 | "train_loader = data['train']\n", 290 | "val_loader = data['test']\n", 291 | "\n", 292 | "lambda_1, lambda_2 = 100, 0.1\n", 293 | "lr = 0.001\n", 294 | "device = 'cuda'\n", 295 | "n_epochs = 100\n", 296 | "\n", 297 | "# instantiate models and optimimizers \n", 298 | "TANGOS_model = SimpleMLP(num_features=45).to(device)\n", 299 | "L2_model = SimpleMLP(num_features=45).to(device)\n", 300 | "TANGOS_optimiser = optim.Adam(TANGOS_model.parameters(), lr=lr, weight_decay=0)\n", 301 | "L2_optimiser = optim.Adam(L2_model.parameters(), lr=lr, weight_decay=0.1)\n", 302 | "\n", 303 | "TANGOS_train_loss_ls = []; L2_train_loss_ls = []\n", 304 | "TANGOS_val_loss_ls = []; L2_val_loss_ls = []\n", 305 | "\n", 306 | "for epoch in range(n_epochs):\n", 307 | " TANGOS_running_loss = 0; L2_running_loss = 0\n", 308 | " # training epoch\n", 309 | " for data, label in train_loader:\n", 310 | " TANGOS_model.train(); L2_model.train()\n", 311 | " data, label = data.to(device), label.to(device)\n", 312 | " TANGOS_optimiser.zero_grad(); L2_optimiser.zero_grad()\n", 313 | "\n", 314 | " # forward and backward pass for TANGOS model\n", 315 | " TANGOS_output, _ = TANGOS_model(data)\n", 316 | " MSE_loss = loss_func(TANGOS_output.squeeze(), label)\n", 317 | "\n", 318 | " spec_loss, orth_loss = TANGOS_loss(TANGOS_model, data, subsample=50,\n", 319 | " device=device) # calculate TANGOS loss\n", 320 | " TANGOS_reg_loss = lambda_1 * spec_loss + lambda_2 * orth_loss # weight the two terms\n", 321 | " TANGOS_loss_val = MSE_loss + TANGOS_reg_loss # add TANGOS loss to MSE loss\n", 322 | "\n", 323 | " TANGOS_running_loss += MSE_loss.item()\n", 324 | " 
TANGOS_loss_val.backward()\n", 325 | " TANGOS_optimiser.step()\n", 326 | "\n", 327 | " # forward and backward pass for L2 model\n", 328 | " L2_output, _ = L2_model(data)\n", 329 | " MSE_loss = loss_func(L2_output.squeeze(), label)\n", 330 | "\n", 331 | " L2_running_loss += MSE_loss.item()\n", 332 | " MSE_loss.backward()\n", 333 | " L2_optimiser.step()\n", 334 | "\n", 335 | " TANGOS_train_loss_ls.append(TANGOS_running_loss)\n", 336 | " L2_train_loss_ls.append(L2_running_loss)\n", 337 | "\n", 338 | " TANGOS_running_val_loss = 0; L2_running_val_loss = 0\n", 339 | " # validation epoch\n", 340 | " for data, label in val_loader:\n", 341 | " TANGOS_model.eval(); L2_model.eval()\n", 342 | " data, label = data.to(device), label.to(device)\n", 343 | "\n", 344 | " # evaluate TANGOS model\n", 345 | " TANGOS_output, _ = TANGOS_model(data)\n", 346 | " TANGOS_reg_loss = loss_func(TANGOS_output.squeeze(), label)\n", 347 | " TANGOS_running_val_loss += TANGOS_reg_loss.item()\n", 348 | "\n", 349 | " # evaluate l2 model\n", 350 | " L2_output, _ = L2_model(data)\n", 351 | " L2_loss = loss_func(L2_output.squeeze(), label)\n", 352 | " L2_running_val_loss += L2_loss.item()\n", 353 | "\n", 354 | " TANGOS_val_loss_ls.append(TANGOS_running_val_loss)\n", 355 | " L2_val_loss_ls.append(L2_running_val_loss)" 356 | ], 357 | "metadata": { 358 | "id": "P-Ky9Km1H3sQ" 359 | }, 360 | "execution_count": 6, 361 | "outputs": [] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "source": [], 366 | "metadata": { 367 | "id": "tZM6XeeHK4f-" 368 | }, 369 | "execution_count": 6, 370 | "outputs": [] 371 | }, 372 | { 373 | "cell_type": "markdown", 374 | "source": [ 375 | "# Plot the training and validation loss plots" 376 | ], 377 | "metadata": { 378 | "id": "Yu_PgrtZow1w" 379 | } 380 | }, 381 | { 382 | "cell_type": "code", 383 | "source": [ 384 | "plt.figure(figsize=(15,4))\n", 385 | "\n", 386 | "plt.subplot(1, 2, 1)\n", 387 | "plt.title('Train Loss')\n", 388 | "plt.plot(TANGOS_train_loss_ls, 
label='TANGOS')\n", 389 | "plt.plot(L2_train_loss_ls, label='L2')\n", 390 | "plt.xlabel('Epochs')\n", 391 | "plt.ylabel('Loss')\n", 392 | "plt.ylim(1.5,6)\n", 393 | "plt.legend()\n", 394 | "\n", 395 | "plt.subplot(1, 2, 2)\n", 396 | "plt.title('Val Loss')\n", 397 | "plt.plot(TANGOS_val_loss_ls, label='TANGOS')\n", 398 | "plt.plot(L2_val_loss_ls, label='L2')\n", 399 | "plt.legend()\n", 400 | "plt.xlabel('Epochs')\n", 401 | "plt.ylabel('Loss')\n", 402 | "plt.ylim(.6,1.8)\n", 403 | "plt.show()" 404 | ], 405 | "metadata": { 406 | "colab": { 407 | "base_uri": "https://localhost:8080/", 408 | "height": 295 409 | }, 410 | "id": "PubGcagQJj5_", 411 | "outputId": "9e253623-9b5b-4cb9-88f3-759dfe2972a5" 412 | }, 413 | "execution_count": 7, 414 | "outputs": [ 415 | { 416 | "output_type": "display_data", 417 | "data": { 418 | "text/plain": [ 419 | "
" 420 | ], 421 | "image/png": "\n" 422 | }, 423 | "metadata": { 424 | "needs_background": "light" 425 | } 426 | } 427 | ] 428 | }, 429 | { 430 | "cell_type": "code", 431 | "source": [], 432 | "metadata": { 433 | "id": "gbixC9QHSxF1" 434 | }, 435 | "execution_count": 7, 436 | "outputs": [] 437 | } 438 | ] 439 | } --------------------------------------------------------------------------------