├── random_seeds.pt ├── data ├── mnist3 │ ├── full_mnist_analysis_test.txt │ ├── train_subset_10pc.txt │ ├── train_subset_20pc.txt │ ├── full_mnist_analysis_train.txt │ ├── train_subset_50pc.txt │ └── test_subset.txt └── cifar10 │ ├── train_subset_10pc.txt │ ├── train_subset_20pc.txt │ ├── train_subset_30pc.txt │ ├── train_subset_40pc.txt │ ├── test_subset.txt │ ├── train_subset_50pc.txt │ └── train_subset_60pc.txt ├── LICENSE ├── run_cnn_training.py ├── run_compute_pvalues.py ├── run_subset_generation.py ├── run_vit_training.py ├── tda_cnn_scripts ├── run_cnn_s_test.py ├── run_cnn_ats.py ├── run_cnn_if.py ├── run_cnn_loo.py └── run_cnn_gd_and_gc.py ├── tda_vit_scripts ├── run_vit_s_test.py ├── run_vit_ats.py ├── run_vit_gd_and_gc.py ├── run_vit_loo.py └── run_vit_if.py ├── run_correlation_analysis.py ├── README.md ├── req.txt ├── nn_influence_utils_vit.py ├── utils.py └── nn_influence_utils.py /random_seeds.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ElisaNguyen/bayesian-tda/HEAD/random_seeds.pt -------------------------------------------------------------------------------- /data/mnist3/full_mnist_analysis_test.txt: -------------------------------------------------------------------------------- 1 | 8765 2 | 3373 3 | 7234 4 | 322 5 | 9290 6 | 5583 7 | 6248 8 | 1317 9 | 2867 10 | 437 11 | -------------------------------------------------------------------------------- /data/mnist3/train_subset_10pc.txt: -------------------------------------------------------------------------------- 1 | 17411 2 | 34296 3 | 55805 4 | 54351 5 | 27484 6 | 36519 7 | 28875 8 | 17134 9 | 3809 10 | 41606 11 | 290 12 | 36098 13 | 59519 14 | 24447 15 | 24617 16 | 59692 17 | 51372 18 | 34796 19 | 27207 20 | 18474 21 | 32763 22 | 36685 23 | 53941 24 | 29667 25 | 41197 26 | 32973 27 | 50733 28 | 13523 29 | 10502 30 | 29019 31 | -------------------------------------------------------------------------------- 
/data/mnist3/train_subset_20pc.txt: -------------------------------------------------------------------------------- 1 | 26427 2 | 11428 3 | 54351 4 | 27484 5 | 20983 6 | 58458 7 | 30827 8 | 36213 9 | 34097 10 | 11754 11 | 32526 12 | 28875 13 | 11834 14 | 12575 15 | 49800 16 | 51464 17 | 41606 18 | 43834 19 | 29934 20 | 702 21 | 102 22 | 58309 23 | 25235 24 | 47004 25 | 22066 26 | 5834 27 | 47063 28 | 3509 29 | 24626 30 | 53930 31 | 59519 32 | 36968 33 | 51372 34 | 10378 35 | 39403 36 | 11143 37 | 290 38 | 2692 39 | 34796 40 | 45677 41 | 50733 42 | 9748 43 | 48982 44 | 41082 45 | 36685 46 | 57658 47 | 45638 48 | 36949 49 | 189 50 | 8446 51 | 29019 52 | 41197 53 | 33685 54 | 15829 55 | 41260 56 | 32022 57 | 33361 58 | 17821 59 | 47649 60 | 17927 61 | -------------------------------------------------------------------------------- /data/cifar10/train_subset_10pc.txt: -------------------------------------------------------------------------------- 1 | 4830 2 | 5029 3 | 2963 4 | 2832 5 | 3129 6 | 434 7 | 6021 8 | 5852 9 | 7384 10 | 9454 11 | 9790 12 | 1311 13 | 6286 14 | 6366 15 | 7217 16 | 1213 17 | 985 18 | 9948 19 | 5188 20 | 120 21 | 3903 22 | 8638 23 | 3444 24 | 6717 25 | 5224 26 | 785 27 | 244 28 | 7454 29 | 4784 30 | 9802 31 | 8808 32 | 1009 33 | 8090 34 | 3118 35 | 8940 36 | 1421 37 | 7551 38 | 7593 39 | 965 40 | 3388 41 | 8138 42 | 1970 43 | 440 44 | 6879 45 | 2773 46 | 5718 47 | 9843 48 | 3149 49 | 4980 50 | 8955 51 | 1732 52 | 5092 53 | 6193 54 | 5801 55 | 5488 56 | 9754 57 | 3970 58 | 9526 59 | 485 60 | 7522 61 | 8212 62 | 6378 63 | 1818 64 | 9134 65 | 5615 66 | 319 67 | 3287 68 | 4552 69 | 6068 70 | 9774 71 | 4871 72 | 4039 73 | 5105 74 | 8437 75 | 6354 76 | 194 77 | 3405 78 | 435 79 | 263 80 | 8142 81 | 2581 82 | 9715 83 | 8370 84 | 7537 85 | 983 86 | 8375 87 | 3402 88 | 3706 89 | 1292 90 | 3041 91 | 7657 92 | 1758 93 | 9137 94 | 7928 95 | 6708 96 | 1995 97 | 1382 98 | 7171 99 | 3969 100 | 9467 101 | 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Elisa Nguyen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
def main():
    """Train one CNN per (task, subset size, seed) combination.

    For each task ('mnist3', 'cifar10') and each subset size (10, 20 or 50
    samples per class) the pre-generated training subset is loaded once from
    ``data/{task}/train_subset_{num_per_class}pc.pt`` below the current
    working directory. A fresh model is then trained for every seed returned
    by ``load_seeds()``.

    ``train_model`` receives ``models/cnn/{task}_{num_per_class}pc/{seed}`` as
    its ``save_path``; what it writes there is decided by ``train_model``
    itself (presumably the last checkpoints -- confirm in utils).
    """
    seeds = load_seeds()
    for task in ['mnist3', 'cifar10']:
        # Train models for 15 epochs for MNIST, 30 for CIFAR
        num_epochs = 15 if task == 'mnist3' else 30

        for num_per_class in [10, 20, 50]:
            # The subset generation script stores the subsets with a "pc"
            # suffix (train_subset_10pc.pt, ...); the name must match here,
            # otherwise torch.load fails with a missing file.
            train_dataset = torch.load(
                f'{os.getcwd()}/data/{task}/train_subset_{num_per_class}pc.pt')

            for seed in seeds:
                # Seed before the model is built so that the weight
                # initialisation and the shuffling depend on the seed.
                torch.manual_seed(seed)

                # Set up the model, data loader and optimizer.
                # If you want to train three-layer CNNs, change the model
                # class in the next line.
                # The first dimension of a sample is 3 for RGB images.
                model = NetRGB() if train_dataset[0][0].shape[0] == 3 else NetBW()
                criterion = nn.CrossEntropyLoss(reduction='none')
                optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.005)
                train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

                train_model(model,
                            train_loader=train_loader,
                            optimizer=optimizer,
                            criterion=criterion,
                            num_epochs=num_epochs,
                            save_path=f'{os.getcwd()}/models/cnn/{task}_{num_per_class}pc/{seed}')
def main():
    """Compute p-values of the attribution scores for one experiment.

    For every attribution type from ``load_attribution_types()`` and for
    1 to 5 checkpoints, the per-pair means and standard deviations returned by
    ``get_mu_and_sigma`` are turned into z-scores (mean / std) and two-sided
    p-values under a standard normal. Each p-value table is written to
    ``results/{experiment}/`` and the mean p-value per setting is collected in
    ``mean_p_values.csv``.
    """
    # Looks like <model>_<task>_<num_per_class>pc; the second field is used
    # as the data folder name below.
    experiment = 'cnn_mnist3_10pc'
    tau_types = load_attribution_types()

    # Make sure the output folder exists before anything is written
    results_dir = f'{os.getcwd()}/results/{experiment}/'
    if not os.path.exists(results_dir):
        os.makedirs(results_dir)

    # Load the test subset and collect the ids of its samples
    task = experiment.split("_")[1]
    test_subset = torch.load(f'{os.getcwd()}/data/{task}/test_subset.pt')
    test_idx = [idx for _data, _label, idx in test_subset]

    ckpts = range(1, 6)

    # One row per number of checkpoints, one column per attribution type
    df_res = pd.DataFrame()
    df_res['num_ckpts'] = ckpts

    for tau in tau_types:
        per_setting_means = []
        for num_ckpts in ckpts:
            # Means and standard deviations of each train-test pair across
            # random seeds and num_ckpts checkpoints
            all_means, all_stds = get_mu_and_sigma(tau=tau,
                                                   num_ckpts=num_ckpts,
                                                   experiment=experiment,
                                                   test_idx=test_idx)

            # Two-sided p-value of the z-score mean / std
            two_sided = 2 * stats.norm.sf((all_means / all_stds).abs())
            p_values = pd.DataFrame(two_sided, columns=all_means.columns)

            p_values.to_csv(f'{results_dir}pvalues_{tau}_across_{num_ckpts}_ckpts.csv', index=False)

            # Mean p-value over all train-test pairs of this setting
            per_setting_means.append(p_values.values.mean())
        df_res[tau] = per_setting_means

    df_res.to_csv(f'{results_dir}mean_p_values.csv', index=False)
def get_subset_indices(dataset, num_per_class):
    """Draw a class-balanced random subset and return its dataset positions.

    Args:
        dataset: Indexable collection with ``len()`` whose items unpack into
            three values; the second one is used as the class label.
        num_per_class: Number of positions to draw for every label.

    Returns:
        A flat list of positions, grouped by label in order of first
        appearance, with ``num_per_class`` entries per label.

    Raises:
        ValueError: Raised by ``np.random.choice`` when a label has fewer
            than ``num_per_class`` samples (sampling is without replacement).
    """
    # Group the positions of all samples by their label
    positions_by_label = {}
    for position in range(len(dataset)):
        _sample, label, _sample_id = dataset[position]
        positions_by_label.setdefault(label, []).append(position)

    # Sample the same number of positions from every label group
    chosen = []
    for candidates in positions_by_label.values():
        picked = np.random.choice(candidates, num_per_class, replace=False)
        chosen.extend(picked.tolist())
    return chosen
def main():
    """Fine-tune a LoRA ViT per (task, subset size, seed) and keep the last 5 epochs.

    For every task, subset size and seed from ``load_seeds()`` a ``ViTLoRA``
    model is trained with Adam and a linear learning-rate schedule without
    warm-up. The state dict after each of the last five epochs is saved as
    ``models/vit/{task}_{num_per_class}pc/{seed}/ckpt_epoch_{epoch}.pth``
    below the current working directory.
    """
    num_saved_ckpts = 5
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    seeds = load_seeds()

    def collate_fn(examples):
        """Stack the per-example dicts of a batch into one dict of tensors."""
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        labels = torch.tensor([example["label"] for example in examples])
        return {"pixel_values": pixel_values, "labels": labels}

    for task in ['mnist3', 'cifar10']:
        # Train models for 15 epochs for MNIST, 30 for CIFAR
        num_epochs = 15 if task == 'mnist3' else 30

        # NOTE(review): the repository only lists 10/20/50pc index files for
        # mnist3 -- confirm load_vit_data can serve 30/40/60 for that task.
        for num_per_class in [10, 20, 30, 40, 50, 60]:
            # Load the preprocessed data
            trainset, _ = load_vit_data(task, num_per_class)

            for seed in seeds:
                save_path = f'{os.getcwd()}/models/vit/{task}_{num_per_class}pc/{seed}/'
                os.makedirs(save_path, exist_ok=True)

                # Seed before the loader and the model are created
                torch.manual_seed(seed)
                train_loader = torch.utils.data.DataLoader(trainset, batch_size=32,
                                                           collate_fn=collate_fn, shuffle=True)

                # Load the LoRA model
                model = ViTLoRA(device=device)
                optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.005)
                lr_scheduler = get_linear_schedule_with_warmup(
                    optimizer=optimizer,
                    num_warmup_steps=0,
                    num_training_steps=(len(train_loader) * num_epochs),
                )

                model.train()
                for epoch in range(num_epochs):
                    for batch in tqdm.tqdm(train_loader):
                        inputs = batch['pixel_values'].to(device)
                        labels = batch['labels'].to(device)
                        outputs = model(inputs, labels=labels)
                        loss = outputs.loss.mean()
                        loss.backward()
                        optimizer.step()
                        lr_scheduler.step()
                        optimizer.zero_grad()

                    # ">=" keeps epochs num_epochs-5 .. num_epochs-1, i.e. five
                    # checkpoints; the s_test script iterates over
                    # range(num_epochs-5, num_epochs) and loads
                    # ckpt_epoch_{n}.pth, so both the range and the file name
                    # have to match that.
                    if epoch >= num_epochs - num_saved_ckpts:
                        torch.save(model.state_dict(),
                                   os.path.join(save_path, f'ckpt_epoch_{epoch}.pth'))
| 3441 88 | 7402 89 | 9749 90 | 5462 91 | 6229 92 | 3840 93 | 9489 94 | 8168 95 | 2602 96 | 4033 97 | 4743 98 | 7469 99 | 8425 100 | 6700 101 | 1935 102 | 2934 103 | 4477 104 | 1589 105 | 8442 106 | 2499 107 | 5727 108 | 2121 109 | 3427 110 | 3540 111 | 9006 112 | 5616 113 | 6398 114 | 8220 115 | 6956 116 | 3899 117 | 1757 118 | 8902 119 | 3885 120 | 6543 121 | 9421 122 | 9563 123 | 4598 124 | 2665 125 | 1880 126 | 8445 127 | 7862 128 | 8342 129 | 3828 130 | 1437 131 | 3691 132 | 240 133 | 4897 134 | 2099 135 | 5624 136 | 8489 137 | 1549 138 | 1435 139 | 7893 140 | 9126 141 | 4201 142 | 7001 143 | 7416 144 | 8595 145 | 1444 146 | 759 147 | 1711 148 | 9042 149 | 3904 150 | 3976 151 | 5311 152 | 9037 153 | 4945 154 | 2307 155 | 7512 156 | 554 157 | 7639 158 | 3141 159 | 4904 160 | 4761 161 | 4514 162 | 139 163 | 8430 164 | 8413 165 | 7932 166 | 6565 167 | 3958 168 | 5143 169 | 9445 170 | 7640 171 | 8000 172 | 7526 173 | 7686 174 | 1082 175 | 3288 176 | 3837 177 | 7220 178 | 4779 179 | 1220 180 | 1627 181 | 3268 182 | 1548 183 | 9146 184 | 9058 185 | 4051 186 | 4696 187 | 9922 188 | 9741 189 | 7157 190 | 7615 191 | 7446 192 | 4724 193 | 8211 194 | 4964 195 | 9260 196 | 1388 197 | 6506 198 | 4555 199 | 4507 200 | 6389 201 | 1116 202 | 1276 203 | 2590 204 | 1700 205 | 7662 206 | 8432 207 | 6332 208 | 5984 209 | 8192 210 | 3017 211 | 5330 212 | 5527 213 | 2663 214 | 99 215 | 8074 216 | 4924 217 | 7380 218 | 3879 219 | 2360 220 | 1755 221 | 8629 222 | 3377 223 | 6326 224 | 3241 225 | 5406 226 | 1299 227 | 4839 228 | 8207 229 | 5351 230 | 5103 231 | 5606 232 | 571 233 | 4781 234 | 6404 235 | 8737 236 | 9835 237 | 2734 238 | 2842 239 | 4183 240 | 4926 241 | 427 242 | 3346 243 | 930 244 | 7520 245 | 2708 246 | 7744 247 | 6682 248 | 3545 249 | 3790 250 | 4828 251 | 8280 252 | 3242 253 | 1533 254 | 8058 255 | 5026 256 | 7726 257 | 4645 258 | 1393 259 | 5757 260 | 3635 261 | 7651 262 | 647 263 | 4731 264 | 6572 265 | 4211 266 | 2699 267 | 9129 268 | 5389 269 | 5428 270 | 8637 
def main():
    """Precompute the s_test vectors of a CNN for its last five checkpoints.

    Command line arguments select the seed (by position in ``load_seeds()``),
    the task and the training subset size. For every checkpoint epoch in
    ``range(num_epochs-5, num_epochs)`` the model is restored,
    ``compute_s_test`` is run for each test sample, and a dict mapping the
    third element of the test sample to its averaged s_test is saved under
    ``tda_scores/cnn/if/{task}_{num_per_class}pc/{seed}/``.
    """
    parser = ArgumentParser()
    parser.add_argument('--seed_id', type=int, default=0)
    parser.add_argument('--task', type=str, default='mnist3', help='Either mnist3 or cifar10')
    parser.add_argument('--num_per_class', type=int, default=10, help='Number of samples per class that the model was trained on from {10,20,50}')
    args = parser.parse_args()

    # Data and run configuration
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    train_dataset = torch.load(f'{os.getcwd()}/../data/{args.task}/train_subset_{args.num_per_class}pc.pt')
    train_loader = DataLoader(train_dataset, batch_size=8)
    test_dataset = torch.load(f'{os.getcwd()}/../data/{args.task}/test_subset.pt')
    seed = load_seeds()[args.seed_id]
    num_epochs = 15 if 'mnist' in args.task else 30

    # Hyperparameters of s_test estimation
    s_test_num_samples = min(len(train_loader), 1000)
    s_test_damp = 5e-3
    s_test_scale = 1e4
    s_test_iterations = 1

    for num_ckpt in range(num_epochs - 5, num_epochs):
        s_tests = {}

        # Restore the trained model; 3 input channels select the RGB network
        is_rgb = train_dataset[0][0].shape[0] == 3
        model = NetRGB() if is_rgb else NetBW()
        ckpt = torch.load(f'{os.getcwd()}/../models/cnn/{args.task}_{args.num_per_class}pc/{seed}/ckpt_epoch_{num_ckpt}.pth')
        model.load_state_dict(ckpt['model_state_dict'])
        model.eval()

        # Create the output folder if it is missing
        save_path = f'{os.getcwd()}/../tda_scores/cnn/if/{args.task}_{args.num_per_class}pc/{seed}/'
        if not os.path.exists(save_path):
            os.makedirs(save_path)

        for z_test in test_dataset:
            accumulated = None
            for _ in range(s_test_iterations):
                estimate = compute_s_test(
                    n_gpu=1,
                    device=device,
                    model=model,
                    test_inputs=z_test,
                    train_data_loaders=[train_loader],
                    params_filter=None,
                    weight_decay=None,
                    weight_decay_ignores=None,
                    damp=s_test_damp,
                    scale=s_test_scale,
                    num_samples=s_test_num_samples,
                    verbose=False)

                # Element-wise sum of the estimates across runs
                if accumulated is None:
                    accumulated = estimate
                else:
                    accumulated = [total + part for total, part in zip(accumulated, estimate)]

            # Average over the runs and file under the test sample's id
            s_tests[z_test[2]] = [total / s_test_iterations for total in accumulated]

        # Save all s_tests of this checkpoint
        torch.save(s_tests, f'{save_path}/s_tests_ckpt_{num_ckpt}.pt')
mnist3 or cifar10') 16 | parser.add_argument('--num_per_class', type=int, default=10, help='Number of samples per class that the model was trained on from {10,20,50}') 17 | args = parser.parse_args() 18 | 19 | num_epochs = 15 if 'mnist' in args.task else 30 20 | ckpts = range(num_epochs-5, num_epochs) 21 | 22 | seeds = load_seeds() 23 | seed = seeds[args.seed_id] 24 | 25 | save_path = f"{os.getcwd()}/../tda_scores/cnn/ats/{args.task}_{args.num_per_class}pc/{seed}/" 26 | if not os.path.exists(save_path): 27 | os.makedirs(save_path) 28 | 29 | criterion = nn.CrossEntropyLoss(reduction='none') 30 | train_dataset = torch.load(f'{os.getcwd()}/../data/{args.task}/train_subset_{args.num_per_class}pc.pt') 31 | test_dataset = torch.load(f'{os.getcwd()}/../data/{args.task}/test_subset.pt') 32 | colnames = [f'z_test_{idx}' for _,_,idx in test_dataset] 33 | colnames.insert(0, 'train_idx') 34 | test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False) 35 | 36 | for num_ckpt in ckpts: 37 | df_ats = pd.DataFrame(columns=colnames) 38 | df_ats['train_idx'] = [idx for _,_,idx in train_dataset] 39 | for data, label, z_train_idx in train_dataset: 40 | # Load the model and get the initial loss values 41 | model = NetRGB() if train_dataset[0][0].shape[0]==3 else NetBW() 42 | ckpt = torch.load(f'{os.getcwd()}/../models/cnn/{args.task}_{args.num_per_class}pc/{seed}/ckpt_epoch_{num_ckpt}.pth') 43 | model.load_state_dict(ckpt['model_state_dict']) 44 | model.eval() 45 | 46 | test_loss, _, _ = test_model(model=model, 47 | test_loader=test_loader, 48 | criterion=criterion) 49 | 50 | # Set model to train mode 51 | model.train() 52 | optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.005) 53 | 54 | # Train for 1 step 55 | train_instance_loader = DataLoader([[data, label, z_train_idx]], batch_size=1, shuffle=False) 56 | model = train_model(model, 57 | train_loader=train_instance_loader, 58 | optimizer=optimizer, 59 | criterion=criterion, 60 | num_epochs=1) 61 | 62 | # 
def main():
    """Precompute the s_test vectors of a LoRA ViT for its last five checkpoints.

    Command line arguments select the seed (by position in ``load_seeds()``),
    the task and the training subset size. For every checkpoint epoch in
    ``range(num_epochs-5, num_epochs)`` the model state is restored,
    ``compute_s_test`` is run for each test sample (batch size 1), and a dict
    mapping the test sample's ``idx`` to its averaged s_test (a list of
    tensors) is saved as ``s_tests_ckpt_{epoch}.pt``.
    """
    parser = ArgumentParser()
    parser.add_argument('--seed_id', type=int, default=0)
    parser.add_argument('--task', type=str, default='mnist3', help='Either mnist3 or cifar10')
    parser.add_argument('--num_per_class', type=int, default=10, help='Number of samples per class that the model was trained on from {10,20,50}')
    args = parser.parse_args()

    # Load variables needed for the computation
    seeds = load_seeds()
    seed = seeds[args.seed_id]
    num_epochs = 15 if 'mnist' in args.task else 30
    ckpts = range(num_epochs - 5, num_epochs)
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Load datasets. The parser defines no --experiment option, so reading
    # args.experiment raises AttributeError; the training script calls
    # load_vit_data(task, num_per_class), which is mirrored here.
    train_dataset, test_dataset = load_vit_data(args.task, args.num_per_class)

    def collate_fn(examples):
        """Stack the per-example dicts of a batch into one dict of tensors."""
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        labels = torch.tensor([example["label"] for example in examples])
        idx = torch.tensor([example["idx"] for example in examples])
        return {"pixel_values": pixel_values, "labels": labels, "idx": idx}

    train_loader = DataLoader(train_dataset, batch_size=32, collate_fn=collate_fn, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=collate_fn, shuffle=False)

    # Hyperparameters of s_test estimation
    s_test_num_samples = min(len(train_loader), 1000)
    s_test_damp = 5e-3
    s_test_scale = 1e4
    s_test_iterations = 1

    # The output folder does not depend on the checkpoint, so set it up once.
    # NOTE(review): the CNN scripts write to tda_scores/cnn/if/..._{n}pc/;
    # this path has neither a "vit" folder nor the "pc" suffix -- confirm it
    # matches what the ViT influence script reads before changing it.
    save_path = f'{os.getcwd()}/../tda_scores/if/{args.task}_{args.num_per_class}/{seed}/'
    os.makedirs(save_path, exist_ok=True)

    for num_ckpt in ckpts:
        s_tests = {}

        # Load the model
        # NOTE(review): unlike the CNN script, model.eval() is not called
        # here -- confirm whether ViTLoRA needs it.
        model = ViTLoRA(device)
        state_dict = torch.load(f'{os.getcwd()}/../models/vit/{args.task}_{args.num_per_class}pc/{seed}/ckpt_epoch_{num_ckpt}.pth')
        model.load_state_dict(state_dict)

        for z_test in test_loader:
            s_test = None
            for _ in range(s_test_iterations):
                # The second return value is not used by this script
                _s_test, _ = compute_s_test(
                    n_gpu=1,
                    device=device,
                    model=model,
                    test_inputs=z_test,
                    train_data_loaders=[train_loader],
                    params_filter=None,
                    weight_decay=None,
                    weight_decay_ignores=None,
                    damp=s_test_damp,
                    scale=s_test_scale,
                    num_samples=s_test_num_samples,
                    verbose=False)

                # Sum the values across runs
                if s_test is None:
                    s_test = _s_test
                else:
                    s_test = [a + b for a, b in zip(s_test, _s_test)]

            # Do the averaging; the result is a list of tensors
            s_test = [a / s_test_iterations for a in s_test]
            s_tests[z_test['idx'].item()] = s_test

        # Save the s_tests of this checkpoint
        torch.save(s_tests, f'{save_path}/s_tests_ckpt_{num_ckpt}.pt')
# tda_cnn_scripts/run_cnn_if.py
"""Influence-function (IF) attribution scores for the late CNN checkpoints."""

import torch
import os
from torch.utils.data import DataLoader
from nn_influence_utils import compute_influences
from utils import (NetBW, NetRGB, load_seeds)
import pandas as pd
from argparse import ArgumentParser


def main():
    """Write one influence CSV per checkpoint for the last five epochs.

    Command-line arguments select the seed (``--seed_id``), the task
    (``--task``) and the training-subset size (``--num_per_class``).  For each
    checkpoint the precomputed s_test file and the model weights are loaded,
    ``compute_influences`` is called once per test sample, and its values are
    stored as column ``z_test_<idx>`` next to a ``train_idx`` column.
    """
    parser = ArgumentParser()
    parser.add_argument('--seed_id', type=int, default=0)
    parser.add_argument('--task', type=str, default='mnist3', help='Either mnist3 or cifar10')
    parser.add_argument('--num_per_class', type=int, default=10, help='Number of samples per class that the model was trained on from {10,20,50}')
    args = parser.parse_args()

    # Load datasets and variables needed for the computation
    num_epochs = 15 if 'mnist' in args.task else 30
    ckpts = range(num_epochs-5, num_epochs)  # the last five epoch indices
    train_dataset = torch.load(f'{os.getcwd()}/../data/{args.task}/train_subset_{args.num_per_class}pc.pt')
    test_dataset = torch.load(f'{os.getcwd()}/../data/{args.task}/test_subset.pt')
    # (The former `colnames` list was never read; building it cost a full,
    # unused pass over the test set, so it has been removed.)
    batch_train_data_loader = DataLoader(train_dataset, batch_size=8)
    instance_train_data_loader = DataLoader(train_dataset, batch_size=1)
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    seeds = load_seeds()
    seed = seeds[args.seed_id]

    # Hyperparameters of s_test estimation (forwarded to compute_influences
    # together with the precomputed s_test).
    s_test_num_samples = min(len(train_dataset), 1000)
    s_test_damp = 5e-3
    s_test_scale = 1e4
    s_test_iterations = 1

    # Directory holding the precomputed s_test files; it is the same directory
    # the attribution CSVs are written to and does not depend on the checkpoint.
    s_test_path = f'{os.getcwd()}/../tda_scores/cnn/if/{args.task}_{args.num_per_class}pc/{seed}/'

    for num_ckpt in ckpts:
        # Define save path and load precomputed s_test HVPs
        save_path = f"{os.getcwd()}/../tda_scores/cnn/if/{args.task}_{args.num_per_class}pc/{seed}/attribution_ckpt_{num_ckpt}.csv"
        precomputed_s_tests = torch.load(f'{s_test_path}/s_tests_ckpt_{num_ckpt}.pt')

        # Load trained model; three input channels select the RGB network.
        model = NetRGB() if train_dataset[0][0].shape[0] == 3 else NetBW()
        ckpt = torch.load(f'{os.getcwd()}/../models/cnn/{args.task}_{args.num_per_class}pc/{seed}/ckpt_epoch_{num_ckpt}.pth')
        model.load_state_dict(ckpt['model_state_dict'])
        model.eval()

        # Set up DataFrame for saving results
        df_if = pd.DataFrame()
        df_if['train_idx'] = [idx for _, _, idx in train_dataset]

        for z_test in test_dataset:
            z_test_idx = z_test[2]
            precomputed_s_test = precomputed_s_tests[z_test_idx]
            # influences is a dict {train_sample_index: influence}
            # NOTE(review): the column assignment below relies on the dict's
            # iteration order matching the order of `train_idx` -- confirm.
            influences = compute_influences(
                n_gpu=1,
                device=device,
                model=model,
                test_inputs=z_test,
                batch_train_data_loader=batch_train_data_loader,
                instance_train_data_loader=instance_train_data_loader,
                s_test_num_samples=s_test_num_samples,
                s_test_damp=s_test_damp,
                s_test_scale=s_test_scale,
                s_test_iterations=s_test_iterations,
                precomputed_s_test=precomputed_s_test,
            )

            # Save influences; materialise the dict view so pandas receives a
            # plain list rather than a `dict_values` object.
            df_if[f"z_test_{z_test_idx}"] = list(influences.values())
            # Written after every test sample so partial results survive an
            # interrupted run; the final file is the same either way.
            df_if.to_csv(save_path, index=False)


if __name__ == "__main__":
    main()
# tda_cnn_scripts/run_cnn_loo.py
"""Leave-one-out (LOO) attribution scores for the late CNN checkpoints."""

import torch
import torch.nn as nn
import torch.optim as optim
import os
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd
from utils import (NetBW, NetRGB, train_model, load_seeds, test_model)
from argparse import ArgumentParser


def main():
    """Write one LOO attribution CSV per checkpoint for the last five epochs.

    For every training sample the checkpoint is restored and evaluated on the
    test set, ``train_model`` is called with ``loo_idx`` set to that sample's
    index, and the returned model is evaluated again.  The difference
    (loss after retraining minus loss before) fills the sample's row.
    """
    cli = ArgumentParser()
    cli.add_argument('--seed_id', type=int, default=0)
    cli.add_argument('--task', type=str, default='mnist3', help='Either mnist3 or cifar10')
    cli.add_argument('--num_per_class', type=int, default=10, help='Number of samples per class that the model was trained on from {10,20,50}')
    args = cli.parse_args()

    # 15 epochs for mnist tasks, 30 otherwise; only the last five epoch
    # indices are processed.
    if 'mnist' in args.task:
        num_epochs = 15
    else:
        num_epochs = 30
    last_epochs = range(num_epochs-5, num_epochs)

    train_dataset = torch.load(f'{os.getcwd()}/../data/{args.task}/train_subset_{args.num_per_class}pc.pt')
    test_dataset = torch.load(f'{os.getcwd()}/../data/{args.task}/test_subset.pt')

    # First column holds the training index, then one column per test sample.
    colnames = ['train_idx'] + [f'z_test_{idx}' for _, _, idx in test_dataset]

    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
    seed = load_seeds()[args.seed_id]
    # reduction='none': the criterion returns one loss value per sample.
    criterion = nn.CrossEntropyLoss(reduction='none')

    # Output folder for the result CSVs
    save_path = f"{os.getcwd()}/../tda_scores/cnn/loo/{args.task}_{args.num_per_class}pc/{seed}/"
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    def restore(num_ckpt):
        """Build the network, load the weights of epoch `num_ckpt`, set eval mode."""
        net = NetRGB() if train_dataset[0][0].shape[0] == 3 else NetBW()
        ckpt = torch.load(f'{os.getcwd()}/../models/cnn/{args.task}_{args.num_per_class}pc/{seed}/ckpt_epoch_{num_ckpt}.pth')
        net.load_state_dict(ckpt['model_state_dict'])
        net.eval()
        return net

    for num_ckpt in last_epochs:
        # Result table of this checkpoint
        df_loo = pd.DataFrame(columns=colnames)
        df_loo['train_idx'] = [idx for _, _, idx in train_dataset]

        for _, _, z_train_idx in train_dataset:
            # Reference loss of the restored checkpoint
            model = restore(num_ckpt)
            test_loss, _, _ = test_model(model=model,
                                         test_loader=test_loader,
                                         criterion=criterion)

            # NOTE(review): the model handed to train_model carries the
            # checkpoint weights; unless train_model re-initialises them, the
            # retraining starts from weights that already saw the held-out
            # sample -- confirm this is intended.
            model.train()
            optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.005)  # same parameters as training

            # Reseed before building the shuffling loader so every retraining
            # run starts from the same RNG state.
            torch.manual_seed(seed)
            train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

            # Retrain with z_train_idx left out
            model = train_model(model,
                                train_loader=train_loader,
                                optimizer=optimizer,
                                criterion=criterion,
                                num_epochs=num_epochs,
                                loo_idx=z_train_idx)

            # Loss of the retrained model on the test set
            model.eval()
            loo_loss, _, _ = test_model(model=model,
                                        test_loader=test_loader,
                                        criterion=criterion)

            # Store the loss change in the row of this training sample
            matches = np.where(df_loo['train_idx'] == z_train_idx)
            row_idx = matches[0][0]
            df_loo.loc[row_idx, colnames[1]:] = loo_loss - test_loss

        # Save
        df_loo.to_csv(f"{save_path}/attribution_ckpt_{num_ckpt}.csv", index=False)


if __name__ == "__main__":
    main()
# run_correlation_analysis.py
"""Correlate the attribution methods with each other and plot the matrices."""

import torch
import numpy as np
import os
from utils import get_mu_and_sigma, load_attribution_types
from scipy.stats import spearmanr
import seaborn as sns
import matplotlib.pyplot as plt
plt.style.use('ggplot')


def save_corr_matrix(experiment, corr_matrix, save_name, labels=None):
    """Render `corr_matrix` as an annotated heatmap and save it.

    Args:
        experiment: Experiment name; the figure is written to
            ``<cwd>/results/<experiment>/``.
        corr_matrix: Square array of correlation coefficients (rounded to two
            decimals for display).
        save_name: File name of the figure, including its extension.
        labels: Tick labels for both axes.  Defaults to
            ``['LOO', 'ATS', 'IF', 'GD', 'GC']``.
            NOTE(review): the default assumes the rows of `corr_matrix` are in
            exactly this order -- confirm against `load_attribution_types`.
    """
    if labels is None:
        labels = ['LOO', 'ATS', 'IF', 'GD', 'GC']

    # Set up the figure and axes
    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        # Create a colormap
        cmap = sns.diverging_palette(h_neg=197,
                                     h_pos=24,
                                     as_cmap=True,
                                     center='light')

        # Plot the correlation matrix
        sns.heatmap(corr_matrix.round(2), vmin=-1, vmax=1, cmap=cmap, annot=True,
                    annot_kws={"fontsize": 14}, fmt=".2f",
                    square=True, cbar=False, ax=ax)

        # Set the tick labels
        ax.set_xticklabels(labels, rotation=45, horizontalalignment='right', fontsize=14)
        ax.set_yticklabels(labels, rotation=0, fontsize=14)

        fig.tight_layout()

        # Save the plot; create the target folder first so a fresh checkout
        # does not fail on a missing results directory.
        os.makedirs(f'{os.getcwd()}/results/{experiment}', exist_ok=True)
        fig.savefig(f'{os.getcwd()}/results/{experiment}/{save_name}', dpi=300)
    finally:
        # Figures created through pyplot stay registered until closed; this
        # function is called six times per experiment.
        plt.close(fig)


def main():
    """Compute Pearson and Spearman correlations between attribution types.

    For every experiment, the mean (`mu`), the standard deviation (`sigma`)
    and their ratio `mu / sigma` (stored as `ps`) are collected per attribution
    type, flattened to one vector each, correlated across types and saved as
    six heatmaps.
    """
    taus = load_attribution_types()
    num_ckpts = 5  # As TDA scores are most stable with max number of checkpoints, we set this to 5.
    experiments = ['cnn_mnist3_10pc', 'cnn_mnist3_20pc']  # Add experiments to this list that you want to analyse

    for experiment in experiments:
        task = experiment.split('_')[1]
        test_subset = torch.load(f'{os.getcwd()}/data/{task}/test_subset.pt')
        test_idx = [idx for _, _, idx in test_subset]

        mus = {}
        sigmas = {}
        ps = {}
        for tau in taus:
            # Mean, standard deviation and their ratio per attribution type
            mu, sigma = get_mu_and_sigma(tau=tau, num_ckpts=num_ckpts, experiment=experiment, test_idx=test_idx)
            mus[tau] = mu.values.flatten()
            sigmas[tau] = sigma.values.flatten()
            ps[tau] = mus[tau] / sigmas[tau]

        # np.stack needs a real sequence; a `dict_values` view is rejected by
        # current NumPy, so materialise each view as a list first.
        mu_rows = np.stack(list(mus.values()))
        sigma_rows = np.stack(list(sigmas.values()))
        p_rows = np.stack(list(ps.values()))

        # Compute Pearson correlation (one row per attribution type)
        pearson_corr_mu = np.corrcoef(mu_rows, rowvar=True)
        pearson_corr_sigma = np.corrcoef(sigma_rows, rowvar=True)
        pearson_corr_p = np.corrcoef(p_rows, rowvar=True)

        # Compute Spearman correlation
        spearman_corr_mu, _ = spearmanr(mu_rows, axis=1)
        spearman_corr_sigma, _ = spearmanr(sigma_rows, axis=1)
        spearman_corr_p, _ = spearmanr(p_rows, axis=1)

        # Save correlation matrices (same files and order as before)
        outputs = [
            ('spearman_corr_mu.pdf', spearman_corr_mu),
            ('spearman_corr_sigma.pdf', spearman_corr_sigma),
            ('spearman_corr_p.pdf', spearman_corr_p),
            ('pearson_corr_mu.pdf', pearson_corr_mu),
            ('pearson_corr_sigma.pdf', pearson_corr_sigma),
            ('pearson_corr_p.pdf', pearson_corr_p),
        ]
        for save_name, corr_matrix in outputs:
            save_corr_matrix(experiment=experiment,
                             corr_matrix=corr_matrix,
                             save_name=save_name)


if __name__ == "__main__":
    main()
save_corr_matrix(experiment=experiment, 85 | corr_matrix=pearson_corr_sigma, 86 | save_name='pearson_corr_sigma.pdf') 87 | 88 | save_corr_matrix(experiment=experiment, 89 | corr_matrix=pearson_corr_p, 90 | save_name='pearson_corr_p.pdf') 91 | 92 | 93 | if __name__=="__main__": 94 | main() -------------------------------------------------------------------------------- /data/cifar10/train_subset_40pc.txt: -------------------------------------------------------------------------------- 1 | 127 2 | 103 3 | 7313 4 | 4559 5 | 9121 6 | 5228 7 | 2660 8 | 7853 9 | 3143 10 | 2829 11 | 8616 12 | 7709 13 | 3570 14 | 8601 15 | 5588 16 | 3700 17 | 3810 18 | 6781 19 | 464 20 | 245 21 | 8279 22 | 4561 23 | 7569 24 | 6200 25 | 4553 26 | 2049 27 | 1607 28 | 7807 29 | 1468 30 | 8114 31 | 669 32 | 4637 33 | 77 34 | 5701 35 | 7496 36 | 6293 37 | 7627 38 | 3280 39 | 5525 40 | 9648 41 | 2153 42 | 2048 43 | 8583 44 | 4533 45 | 1004 46 | 5920 47 | 495 48 | 2356 49 | 1508 50 | 8676 51 | 1271 52 | 1671 53 | 1742 54 | 4539 55 | 4223 56 | 4720 57 | 1721 58 | 1223 59 | 3021 60 | 6737 61 | 4870 62 | 5466 63 | 8939 64 | 6950 65 | 214 66 | 6342 67 | 9930 68 | 9083 69 | 7320 70 | 2511 71 | 9101 72 | 5500 73 | 2354 74 | 5485 75 | 2921 76 | 4095 77 | 3630 78 | 3552 79 | 9291 80 | 1213 81 | 5173 82 | 6077 83 | 4470 84 | 5880 85 | 7434 86 | 4180 87 | 4748 88 | 3219 89 | 284 90 | 1932 91 | 9860 92 | 1010 93 | 7778 94 | 9561 95 | 5339 96 | 2966 97 | 2098 98 | 276 99 | 6437 100 | 4673 101 | 7096 102 | 1067 103 | 7994 104 | 6729 105 | 1973 106 | 8121 107 | 8163 108 | 406 109 | 3158 110 | 2812 111 | 5327 112 | 496 113 | 7987 114 | 1879 115 | 9890 116 | 1052 117 | 2368 118 | 3679 119 | 2095 120 | 9608 121 | 5556 122 | 9325 123 | 3010 124 | 3996 125 | 2584 126 | 4908 127 | 893 128 | 7445 129 | 2602 130 | 8573 131 | 3374 132 | 8368 133 | 4669 134 | 827 135 | 9863 136 | 7907 137 | 4668 138 | 5286 139 | 19 140 | 1204 141 | 1680 142 | 5434 143 | 4454 144 | 9920 145 | 8923 146 | 7423 147 | 6809 148 | 5149 149 | 
9861 150 | 5727 151 | 4082 152 | 7159 153 | 5436 154 | 994 155 | 6731 156 | 3724 157 | 2121 158 | 2149 159 | 2508 160 | 8569 161 | 8593 162 | 2255 163 | 9016 164 | 5422 165 | 6166 166 | 9941 167 | 5723 168 | 9609 169 | 2220 170 | 8388 171 | 8282 172 | 6251 173 | 1969 174 | 308 175 | 6767 176 | 2334 177 | 6417 178 | 645 179 | 9227 180 | 5375 181 | 7889 182 | 629 183 | 2784 184 | 9510 185 | 9535 186 | 1564 187 | 8097 188 | 1141 189 | 4850 190 | 4080 191 | 8094 192 | 4532 193 | 9178 194 | 7946 195 | 7199 196 | 8342 197 | 1372 198 | 7893 199 | 4164 200 | 3678 201 | 1231 202 | 2126 203 | 8830 204 | 8750 205 | 9754 206 | 3208 207 | 4324 208 | 9516 209 | 222 210 | 9760 211 | 6649 212 | 9378 213 | 5916 214 | 2853 215 | 1601 216 | 7826 217 | 9409 218 | 6256 219 | 6469 220 | 3856 221 | 7499 222 | 2932 223 | 9324 224 | 8065 225 | 686 226 | 7058 227 | 8731 228 | 3166 229 | 1913 230 | 6198 231 | 7362 232 | 3344 233 | 3365 234 | 3368 235 | 3547 236 | 4796 237 | 7639 238 | 6971 239 | 4091 240 | 507 241 | 9348 242 | 6671 243 | 7085 244 | 9289 245 | 550 246 | 817 247 | 1120 248 | 518 249 | 3221 250 | 3268 251 | 3924 252 | 155 253 | 7959 254 | 7292 255 | 7834 256 | 8901 257 | 345 258 | 2316 259 | 1045 260 | 5504 261 | 8697 262 | 190 263 | 5229 264 | 7265 265 | 4587 266 | 6423 267 | 810 268 | 6014 269 | 4078 270 | 7446 271 | 1738 272 | 1577 273 | 4501 274 | 2219 275 | 1097 276 | 5796 277 | 9394 278 | 3355 279 | 1548 280 | 2890 281 | 1615 282 | 1635 283 | 5778 284 | 5168 285 | 812 286 | 6388 287 | 6276 288 | 5872 289 | 6600 290 | 2234 291 | 87 292 | 3944 293 | 1149 294 | 9468 295 | 6883 296 | 8419 297 | 7610 298 | 6910 299 | 2796 300 | 3061 301 | 9587 302 | 4664 303 | 620 304 | 1130 305 | 220 306 | 2770 307 | 371 308 | 1647 309 | 9013 310 | 8949 311 | 6144 312 | 5438 313 | 6404 314 | 7611 315 | 1477 316 | 198 317 | 5794 318 | 2062 319 | 4660 320 | 7426 321 | 2567 322 | 8928 323 | 1303 324 | 737 325 | 3296 326 | 1484 327 | 2917 328 | 9715 329 | 4443 330 | 4239 331 | 509 332 | 6736 333 
| 8861 334 | 7403 335 | 5897 336 | 4573 337 | 1233 338 | 4344 339 | 5878 340 | 9858 341 | 9441 342 | 9074 343 | 5505 344 | 3098 345 | 6265 346 | 1674 347 | 6916 348 | 9911 349 | 580 350 | 7956 351 | 1541 352 | 117 353 | 1971 354 | 1645 355 | 4047 356 | 165 357 | 2778 358 | 4197 359 | 3069 360 | 7041 361 | 9424 362 | 569 363 | 6427 364 | 123 365 | 2984 366 | 7858 367 | 6604 368 | 2691 369 | 8924 370 | 3146 371 | 1254 372 | 3737 373 | 5708 374 | 84 375 | 6272 376 | 8404 377 | 3159 378 | 8768 379 | 3875 380 | 6707 381 | 931 382 | 7530 383 | 755 384 | 391 385 | 603 386 | 7911 387 | 2806 388 | 5800 389 | 9366 390 | 8982 391 | 2426 392 | 9624 393 | 2083 394 | 9137 395 | 618 396 | 1567 397 | 1257 398 | 6170 399 | 7206 400 | 4414 401 | -------------------------------------------------------------------------------- /tda_cnn_scripts/run_cnn_gd_and_gc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | from argparse import ArgumentParser 5 | import pandas as pd 6 | from tqdm import tqdm 7 | from utils import (NetBW, NetRGB, load_seeds, compute_gradient) 8 | 9 | 10 | def main(): 11 | parser = ArgumentParser() 12 | parser.add_argument('--seed_id', type=int, default=0) 13 | parser.add_argument('--task', type=str, default='mnist3', help='Either mnist3 or cifar10') 14 | parser.add_argument('--num_per_class', type=int, default=10, help='Number of samples per class that the model was trained on from {10,20,50}') 15 | args = parser.parse_args() 16 | 17 | # Load datasets and variables needed for the computation 18 | num_epochs = 15 if 'mnist' in args.task else 30 19 | ckpts = range(num_epochs-5, num_epochs) 20 | train_dataset = torch.load(f'{os.getcwd()}/../data/{args.task}/train_subset_{args.num_per_class}pc.pt') 21 | test_dataset = torch.load(f'{os.getcwd()}/../data/{args.task}/test_subset.pt') 22 | colnames = [f'z_test_{idx}' for _,_,idx in test_dataset] 23 | colnames.insert(0, 
# tda_cnn_scripts/run_cnn_gd_and_gc.py
"""Grad-Dot (GD) and Grad-Cos (GC) attribution scores for the late CNN checkpoints."""

import torch
import torch.nn as nn
import os
from argparse import ArgumentParser
import pandas as pd
from tqdm import tqdm
from utils import (NetBW, NetRGB, load_seeds, compute_gradient)


def main():
    """Write one Grad-Dot and one Grad-Cos CSV per checkpoint.

    For each of the last five checkpoint epochs and each test sample, the
    flattened gradient returned by ``compute_gradient`` for the test sample is
    compared with the flattened gradient of every training sample: the dot
    product goes into the GD table, the cosine similarity into the GC table.
    Each table has a ``train_idx`` column and one ``z_test_<idx>`` column per
    test sample.
    """
    parser = ArgumentParser()
    parser.add_argument('--seed_id', type=int, default=0)
    parser.add_argument('--task', type=str, default='mnist3', help='Either mnist3 or cifar10')
    parser.add_argument('--num_per_class', type=int, default=10, help='Number of samples per class that the model was trained on from {10,20,50}')
    args = parser.parse_args()

    # Load datasets and variables needed for the computation
    num_epochs = 15 if 'mnist' in args.task else 30
    ckpts = range(num_epochs-5, num_epochs)  # the last five epoch indices
    train_dataset = torch.load(f'{os.getcwd()}/../data/{args.task}/train_subset_{args.num_per_class}pc.pt')
    test_dataset = torch.load(f'{os.getcwd()}/../data/{args.task}/test_subset.pt')
    # (The former `colnames` list was never read and has been removed.)
    seeds = load_seeds()
    seed = seeds[args.seed_id]
    criterion = nn.CrossEntropyLoss(reduction='none')

    # Instance loaders do not depend on the checkpoint: build them once.
    # shuffle=False keeps the order aligned with the `train_idx` column.
    test_instance_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)
    train_instance_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=False)

    for num_ckpt in ckpts:
        # Load the trained model; three input channels select the RGB network.
        model = NetRGB() if train_dataset[0][0].shape[0] == 3 else NetBW()
        ckpt = torch.load(f'{os.getcwd()}/../models/cnn/{args.task}_{args.num_per_class}pc/{seed}/ckpt_epoch_{num_ckpt}.pth')
        model.load_state_dict(ckpt['model_state_dict'])
        model.eval()

        # Set up save paths and create their folders if needed
        save_path_gd = f"{os.getcwd()}/../tda_scores/cnn/gd/{args.task}_{args.num_per_class}pc/{seed}/attribution_ckpt_{num_ckpt}.csv"
        save_path_gc = f"{os.getcwd()}/../tda_scores/cnn/gc/{args.task}_{args.num_per_class}pc/{seed}/attribution_ckpt_{num_ckpt}.csv"
        for csv_path in (save_path_gd, save_path_gc):
            os.makedirs(os.path.dirname(csv_path), exist_ok=True)

        # Prepare dataframes for saving results
        df_gd = pd.DataFrame()
        df_gc = pd.DataFrame()
        df_gd['train_idx'] = [idx for _, _, idx in train_dataset]
        df_gc['train_idx'] = [idx for _, _, idx in train_dataset]

        for z_test in tqdm(test_instance_loader):
            z_test_idx = z_test[2].cpu().item()
            gd = []
            gc = []

            # Gradient for the test sample, flattened into one vector
            grad_z_test = compute_gradient(model=model,
                                           criterion=criterion,
                                           instance=z_test)
            flat_grad_z_test = torch.concat([layer_grad.flatten() for layer_grad in grad_z_test])

            for z_train in train_instance_loader:
                # Gradient for the training sample, flattened the same way
                grad_z_train = compute_gradient(model=model,
                                                criterion=criterion,
                                                instance=z_train)
                flat_grad_z_train = torch.concat([layer_grad.flatten() for layer_grad in grad_z_train])

                # Compute dot product (Grad-Dot)
                grad_dot = torch.dot(flat_grad_z_test, flat_grad_z_train)
                gd.append(grad_dot.item())

                # Compute cosine similarity (Grad-Cos)
                grad_cos = nn.functional.cosine_similarity(flat_grad_z_test, flat_grad_z_train, dim=0)
                gc.append(grad_cos.item())

            df_gd[f'z_test_{z_test_idx}'] = gd
            df_gc[f'z_test_{z_test_idx}'] = gc

        df_gd.to_csv(save_path_gd, index=False)
        df_gc.to_csv(save_path_gc, index=False)


if __name__ == "__main__":
    main()
# tda_vit_scripts/run_vit_ats.py
"""Additional-training-step (ATS) attribution scores for the late ViT checkpoints."""

import numpy as np
import torch
import os
from argparse import ArgumentParser
from utils import load_seeds, ViTLoRA, test_vit, load_vit_data
import pandas as pd
from torch.utils.data import DataLoader


def main():
    """Write one ATS attribution CSV per checkpoint for the last five epochs.

    For each checkpoint the test loss is recorded first.  Then, for every
    training instance, the checkpoint is reloaded, one Adam step is taken on
    that single instance, and the test loss is measured again.  The difference
    (loss after the step minus loss before) fills the instance's row.
    """
    cli = ArgumentParser()
    cli.add_argument('--seed_id', type=int, default=0)
    cli.add_argument('--task', type=str, default='mnist3', help='Either mnist3 or cifar10')
    cli.add_argument('--num_per_class', type=int, default=10, help='Number of samples per class that the model was trained on from {10,20,50}')
    args = cli.parse_args()

    # Run configuration
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    seed = load_seeds()[args.seed_id]
    num_epochs = 15 if 'mnist' in args.task else 30
    last_epochs = range(num_epochs-5, num_epochs)

    save_path = f"{os.getcwd()}/../tda_scores/vit/ats/{args.task}_{args.num_per_class}pc/{seed}/"
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    # Data
    trainset, testset = load_vit_data(args.task, args.num_per_class)

    def collate_fn(examples):
        """Stack a list of example dicts into one batch dict of tensors."""
        return {
            "pixel_values": torch.stack([example["pixel_values"] for example in examples]),
            "labels": torch.tensor([example["label"] for example in examples]),
            "idx": torch.tensor([example["idx"] for example in examples]),
        }

    # Seed, then build the (unshuffled) test loader
    torch.manual_seed(seed)
    test_loader = DataLoader(testset, batch_size=64, collate_fn=collate_fn, shuffle=False)

    # First column holds the training index, then one column per test instance.
    colnames = ['train_idx'] + [f"z_test_{instance['idx']}" for instance in testset]

    def load_checkpoint(num_ckpt):
        """Return a new ViTLoRA carrying the weights saved at epoch `num_ckpt`."""
        net = ViTLoRA(device=device)
        state_dict = torch.load(f'{os.getcwd()}/../models/vit/{args.task}_{args.num_per_class}pc/{seed}/ckpt_epoch_{num_ckpt}.pth')
        net.load_state_dict(state_dict)
        return net

    for num_ckpt in last_epochs:
        # Result table of this checkpoint
        df_ats = pd.DataFrame(columns=colnames)
        df_ats['train_idx'] = [instance['idx'] for instance in trainset]

        # Reference: test loss of the checkpoint as trained
        model = load_checkpoint(num_ckpt)
        _, full_loss, _ = test_vit(data_loader=test_loader,
                                   device=device,
                                   model=model)

        # One additional training step per training instance
        for instance in trainset:
            model = load_checkpoint(num_ckpt)

            optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.005)
            model.train()

            # A batch of one: add the batch dimension, wrap the label in a tensor
            inputs = instance['pixel_values'].unsqueeze(0).to(device)
            labels = torch.tensor([instance['label']]).to(device)
            outputs = model(inputs, labels=labels)
            step_loss = outputs.loss.mean()
            step_loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Test loss after the extra step
            model.eval()
            _, ats_loss, _ = test_vit(data_loader=test_loader,
                                      device=device,
                                      model=model)

            # Store the loss change in the row of this training instance
            delta_loss = np.array(ats_loss) - np.array(full_loss)
            matches = np.where(df_ats['train_idx'] == instance['idx'])
            row_idx = matches[0][0]
            df_ats.loc[row_idx, colnames[1]:] = delta_loss
            df_ats.to_csv(f"{save_path}/attribution_ckpt_{num_ckpt}.csv", index=False)


if __name__ == "__main__":
    main()
# tda_vit_scripts/run_vit_gd_and_gc.py
"""Grad-Dot and Grad-Cos attribution scores for the late ViT checkpoints."""

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from utils import (load_seeds, load_vit_data, ViTLoRA)
from nn_influence_utils_vit import compute_gradients
from argparse import ArgumentParser
import os
import pandas as pd


def main():
    """Write one dot-product and one cosine-similarity CSV per checkpoint.

    For each of the last five checkpoint epochs and each test instance, the
    flattened output of ``compute_gradients`` for the test instance is compared
    with that of every training instance.  Dot products go to the ``gd``
    folder, cosine similarities to the ``gc`` folder.
    """
    cli = ArgumentParser()
    cli.add_argument('--seed_id', type=int, default=0)
    cli.add_argument('--task', type=str, default='mnist3', help='Either mnist3 or cifar10')
    cli.add_argument('--num_per_class', type=int, default=10, help='Number of samples per class that the model was trained on from {10,20,50}')
    args = cli.parse_args()

    # Run configuration
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    seed = load_seeds()[args.seed_id]
    num_epochs = 15 if 'mnist' in args.task else 30
    last_epochs = range(num_epochs-5, num_epochs)

    # Data: both loaders yield single instances in dataset order
    train_dataset, test_dataset = load_vit_data(args.task, args.num_per_class)

    def collate_fn(examples):
        """Stack a list of example dicts into one batch dict of tensors."""
        return {
            "pixel_values": torch.stack([example["pixel_values"] for example in examples]),
            "labels": torch.tensor([example["label"] for example in examples]),
            "idx": torch.tensor([example["idx"] for example in examples]),
        }

    train_loader = DataLoader(train_dataset, batch_size=1, collate_fn=collate_fn, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=collate_fn, shuffle=False)

    # Output folders
    save_path_dp = f"{os.getcwd()}/../tda_scores/vit/gd/{args.task}_{args.num_per_class}pc/{seed}/"
    if not os.path.exists(save_path_dp):
        os.makedirs(save_path_dp)
    save_path_cs = f"{os.getcwd()}/../tda_scores/vit/gc/{args.task}_{args.num_per_class}pc/{seed}/"
    if not os.path.exists(save_path_cs):
        os.makedirs(save_path_cs)

    def new_table():
        """Return a DataFrame that only has the `train_idx` column filled in."""
        table = pd.DataFrame()
        table['train_idx'] = [instance['idx'] for instance in train_dataset]
        return table

    def flat_gradient(sample):
        """Concatenate the per-layer outputs of compute_gradients into one vector."""
        layer_grads = compute_gradients(device=device,
                                        model=model,
                                        inputs=sample,
                                        params_filter=None)
        return torch.concat([layer_grad.flatten() for layer_grad in layer_grads])

    for num_ckpt in last_epochs:
        # Result tables of this checkpoint
        df_dp = new_table()
        df_cos = new_table()

        # Trained model of this checkpoint (read by `flat_gradient` above)
        model = ViTLoRA(device=device)
        state_dict = torch.load(f'{os.getcwd()}/../models/vit/{args.task}_{args.num_per_class}pc/{seed}/ckpt_epoch_{num_ckpt}.pth')
        model.load_state_dict(state_dict)

        for z_test in test_loader:
            z_test_idx = z_test['idx'].item()
            dp_attribution = []
            cs_attribution = []

            flat_grad_z_test = flat_gradient(z_test)
            for z_train in train_loader:
                flat_grad_z_train = flat_gradient(z_train)

                # Dot product
                dp_attribution.append(torch.dot(flat_grad_z_test, flat_grad_z_train).item())

                # Cosine similarity
                grad_cos = nn.functional.cosine_similarity(flat_grad_z_test, flat_grad_z_train, dim=0)
                cs_attribution.append(grad_cos.item())

            df_dp[f'z_test_{z_test_idx}'] = dp_attribution
            df_cos[f'z_test_{z_test_idx}'] = cs_attribution
            df_dp.to_csv(os.path.join(save_path_dp, f"attribution_ckpt_{num_ckpt}.csv"), index=False)
            df_cos.to_csv(os.path.join(save_path_cs, f"attribution_ckpt_{num_ckpt}.csv"), index=False)


if __name__ == "__main__":
    main()
# tda_vit_scripts/run_vit_loo.py
"""Leave-one-out (LOO) attribution scores for the late ViT checkpoints."""

from transformers import get_linear_schedule_with_warmup
import numpy as np
import torch
import tqdm
import os
from argparse import ArgumentParser
from utils import load_seeds, ViTLoRA, test_vit, load_vit_data
import pandas as pd
from torch.utils.data import DataLoader


def main():
    """Write one LOO attribution CSV per checkpoint for the last five epochs.

    For each checkpoint the test loss of the stored model is recorded.  Then,
    for every training instance, a new ``ViTLoRA`` is trained for
    ``num_epochs`` epochs with that instance's loss zeroed out, and the test
    loss is measured again.  The difference (retrained minus stored) fills the
    instance's row.
    """
    cli = ArgumentParser()
    cli.add_argument('--seed_id', type=int, default=0)
    cli.add_argument('--task', type=str, default='mnist3', help='Either mnist3 or cifar10')
    cli.add_argument('--num_per_class', type=int, default=10, help='Number of samples per class that the model was trained on from {10,20,50}')
    args = cli.parse_args()

    # Run configuration
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    seed = load_seeds()[args.seed_id]
    num_epochs = 15 if 'mnist' in args.task else 30
    last_epochs = range(num_epochs-5, num_epochs)

    # Output folder
    save_path = f"{os.getcwd()}/../tda_scores/vit/loo/{args.task}_{args.num_per_class}pc/{seed}/"
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    # Data
    trainset, testset = load_vit_data(args.task, args.num_per_class)

    def collate_fn(examples):
        """Stack a list of example dicts into one batch dict of tensors."""
        return {
            "pixel_values": torch.stack([example["pixel_values"] for example in examples]),
            "labels": torch.tensor([example["label"] for example in examples]),
            "idx": torch.tensor([example["idx"] for example in examples]),
        }

    # The seed is set once here, before the shuffling train loader is built.
    torch.manual_seed(seed)
    train_loader = DataLoader(trainset, batch_size=32, collate_fn=collate_fn, shuffle=True)
    test_loader = DataLoader(testset, batch_size=64, collate_fn=collate_fn, shuffle=False)

    # First column holds the training index, then one column per test instance.
    colnames = ['train_idx'] + [f"z_test_{instance['idx']}" for instance in testset]

    for num_ckpt in last_epochs:
        df_loo = pd.DataFrame(columns=colnames)
        df_loo['train_idx'] = [instance['idx'] for instance in trainset]

        # Reference: test loss of the model trained on the whole dataset
        model = ViTLoRA(device=device)
        state_dict = torch.load(f'{os.getcwd()}/../models/vit/{args.task}_{args.num_per_class}pc/{seed}/ckpt_epoch_{num_ckpt}.pth')
        model.load_state_dict(state_dict)
        _, full_loss, _ = test_vit(data_loader=test_loader,
                                   device=device,
                                   model=model)

        # Retrain once per training instance, with that instance left out
        for instance in trainset:
            # New model: no checkpoint weights are loaded for the retraining
            model = ViTLoRA(device=device)

            optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.005)
            total_steps = len(train_loader) * num_epochs
            lr_scheduler = get_linear_schedule_with_warmup(
                optimizer=optimizer,
                num_warmup_steps=0,
                num_training_steps=total_steps,
            )
            model.train()

            for _ in range(num_epochs):
                for batch in tqdm.tqdm(train_loader):
                    inputs = batch['pixel_values'].to(device)
                    labels = batch['labels'].to(device)
                    loss = model(inputs, labels=labels).loss
                    # If the held-out instance is in this batch, zero its loss
                    # entry; the mean below still divides by the full batch size.
                    if instance['idx'] in batch['idx']:
                        idx_to_remove = torch.where(instance['idx'] == batch['idx'])[0].item()
                        loss[idx_to_remove] = 0
                    loss = loss.mean()
                    loss.backward()
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()

            # Test loss of the model trained without the instance
            model.eval()
            _, loo_loss, _ = test_vit(data_loader=test_loader,
                                      device=device,
                                      model=model)

            # Store the loss change in the row of this training instance
            delta_loss = np.array(loo_loss) - np.array(full_loss)
            matches = np.where(df_loo['train_idx'] == instance['idx'])
            row_idx = matches[0][0]
            df_loo.loc[row_idx, colnames[1]:] = delta_loss
            # Save
            df_loo.to_csv(f"{save_path}/attribution_ckpt_{num_ckpt}.csv", index=False)


if __name__ == "__main__":
    main()
batch['labels'] 81 | inputs = inputs.to(device) 82 | labels = labels.to(device) 83 | outputs = model(inputs, labels=labels) 84 | loss = outputs.loss 85 | # If the train instance to remove is in this batch, zero out the loss 86 | if instance['idx'] in batch['idx']: 87 | idx_to_remove = torch.where(instance['idx'] == batch['idx'])[0].item() 88 | loss[idx_to_remove] = 0 89 | loss = loss.mean() 90 | loss.backward() 91 | optimizer.step() 92 | lr_scheduler.step() 93 | optimizer.zero_grad() 94 | 95 | # Get loss and preds of the model trained on dataset with one removed 96 | model.eval() 97 | _, loo_loss, _ = test_vit(data_loader=test_loader, 98 | device=device, 99 | model=model) 100 | 101 | # Record the loss change 102 | delta_loss = np.array(loo_loss) - np.array(full_loss) 103 | row_idx = np.where(df_loo['train_idx']==instance['idx'])[0][0] 104 | df_loo.loc[row_idx, colnames[1]:] = delta_loss 105 | # Save 106 | df_loo.to_csv(f"{save_path}/attribution_ckpt_{num_ckpt}.csv", index=False) 107 | 108 | 109 | 110 | if __name__=="__main__": 111 | main() 112 | -------------------------------------------------------------------------------- /data/cifar10/test_subset.txt: -------------------------------------------------------------------------------- 1 | 3390 2 | 9142 3 | 9218 4 | 7045 5 | 7401 6 | 9901 7 | 866 8 | 3113 9 | 9087 10 | 9722 11 | 5185 12 | 5200 13 | 4106 14 | 332 15 | 3035 16 | 5228 17 | 4900 18 | 7933 19 | 6298 20 | 3674 21 | 2350 22 | 456 23 | 8754 24 | 9255 25 | 4248 26 | 4808 27 | 8337 28 | 4987 29 | 4754 30 | 6126 31 | 6021 32 | 1525 33 | 7377 34 | 1179 35 | 870 36 | 4876 37 | 8326 38 | 9093 39 | 7518 40 | 1979 41 | 3900 42 | 3065 43 | 1605 44 | 2798 45 | 6553 46 | 7174 47 | 5272 48 | 6647 49 | 6781 50 | 2598 51 | 9792 52 | 6743 53 | 5954 54 | 8234 55 | 5879 56 | 3749 57 | 6067 58 | 5140 59 | 1044 60 | 6453 61 | 3270 62 | 1366 63 | 3842 64 | 2052 65 | 9000 66 | 1142 67 | 6449 68 | 6070 69 | 8290 70 | 5325 71 | 6947 72 | 3974 73 | 6526 74 | 3519 75 | 8186 76 | 
9936 77 | 7261 78 | 3683 79 | 7922 80 | 9619 81 | 6758 82 | 3695 83 | 4813 84 | 8905 85 | 1600 86 | 6411 87 | 4566 88 | 5034 89 | 7193 90 | 2043 91 | 5338 92 | 2893 93 | 5104 94 | 859 95 | 721 96 | 4444 97 | 4735 98 | 358 99 | 6611 100 | 5766 101 | 6191 102 | 7803 103 | 4953 104 | 1010 105 | 9469 106 | 4792 107 | 6055 108 | 4426 109 | 524 110 | 6328 111 | 5392 112 | 7774 113 | 180 114 | 8851 115 | 4037 116 | 6663 117 | 6574 118 | 9822 119 | 7010 120 | 4453 121 | 3100 122 | 3171 123 | 5460 124 | 8961 125 | 9664 126 | 9577 127 | 3845 128 | 2614 129 | 6418 130 | 9046 131 | 1161 132 | 9271 133 | 689 134 | 7356 135 | 3321 136 | 3112 137 | 1694 138 | 9561 139 | 4056 140 | 9890 141 | 5036 142 | 1103 143 | 3897 144 | 7164 145 | 1718 146 | 4192 147 | 5255 148 | 376 149 | 2979 150 | 929 151 | 1722 152 | 4672 153 | 8090 154 | 6309 155 | 672 156 | 5740 157 | 292 158 | 963 159 | 7870 160 | 5676 161 | 4179 162 | 4130 163 | 7006 164 | 617 165 | 1669 166 | 146 167 | 8534 168 | 4465 169 | 2584 170 | 7629 171 | 7205 172 | 6560 173 | 5501 174 | 2772 175 | 5876 176 | 1455 177 | 6239 178 | 3373 179 | 2666 180 | 5607 181 | 3711 182 | 788 183 | 681 184 | 5949 185 | 8808 186 | 670 187 | 6801 188 | 1634 189 | 9068 190 | 8604 191 | 1211 192 | 3775 193 | 4279 194 | 2016 195 | 827 196 | 2275 197 | 4 198 | 7687 199 | 8781 200 | 7467 201 | 3663 202 | 2142 203 | 8850 204 | 7173 205 | 4831 206 | 2378 207 | 1811 208 | 5741 209 | 3612 210 | 4607 211 | 2086 212 | 9588 213 | 8945 214 | 4285 215 | 7133 216 | 9909 217 | 2477 218 | 7276 219 | 5347 220 | 4128 221 | 796 222 | 1141 223 | 8304 224 | 5805 225 | 7473 226 | 987 227 | 9436 228 | 9559 229 | 1925 230 | 871 231 | 5380 232 | 9811 233 | 513 234 | 9894 235 | 3730 236 | 9113 237 | 5678 238 | 7604 239 | 7413 240 | 3560 241 | 8910 242 | 1970 243 | 973 244 | 6915 245 | 5918 246 | 8438 247 | 7759 248 | 7225 249 | 1880 250 | 6100 251 | 9938 252 | 1732 253 | 1841 254 | 4656 255 | 364 256 | 3099 257 | 7597 258 | 1914 259 | 34 260 | 3163 261 | 5432 262 | 2744 
263 | 4796 264 | 6017 265 | 3312 266 | 47 267 | 4932 268 | 1221 269 | 5166 270 | 9791 271 | 2190 272 | 9898 273 | 7073 274 | 7805 275 | 2868 276 | 3770 277 | 8915 278 | 2244 279 | 3501 280 | 4026 281 | 6963 282 | 8380 283 | 3314 284 | 4255 285 | 8726 286 | 2765 287 | 3644 288 | 7639 289 | 7058 290 | 644 291 | 5391 292 | 9718 293 | 9147 294 | 3504 295 | 4261 296 | 1629 297 | 6034 298 | 2204 299 | 8814 300 | 9717 301 | 431 302 | 4160 303 | 9188 304 | 640 305 | 9586 306 | 2219 307 | 5828 308 | 5613 309 | 9300 310 | 550 311 | 1774 312 | 5820 313 | 2187 314 | 1571 315 | 5540 316 | 4917 317 | 5570 318 | 5447 319 | 9160 320 | 7368 321 | 5611 322 | 6677 323 | 5102 324 | 7244 325 | 6671 326 | 7545 327 | 7267 328 | 7294 329 | 3369 330 | 8084 331 | 1230 332 | 7794 333 | 4017 334 | 483 335 | 1948 336 | 5603 337 | 6257 338 | 4747 339 | 7972 340 | 4552 341 | 4699 342 | 5337 343 | 9264 344 | 4696 345 | 3808 346 | 2590 347 | 6683 348 | 3236 349 | 6993 350 | 9465 351 | 6612 352 | 3919 353 | 2658 354 | 892 355 | 2732 356 | 7003 357 | 9103 358 | 9672 359 | 5824 360 | 3241 361 | 2067 362 | 3917 363 | 2640 364 | 8733 365 | 3417 366 | 3101 367 | 5045 368 | 8865 369 | 6791 370 | 1951 371 | 6324 372 | 7658 373 | 8376 374 | 8088 375 | 8587 376 | 5114 377 | 1426 378 | 571 379 | 7507 380 | 2583 381 | 5346 382 | 5546 383 | 3337 384 | 8320 385 | 7837 386 | 1476 387 | 8124 388 | 826 389 | 795 390 | 7281 391 | 1708 392 | 1884 393 | 4853 394 | 3104 395 | 7293 396 | 3461 397 | 5403 398 | 5371 399 | 2776 400 | 9286 401 | 9524 402 | 5150 403 | 7201 404 | 2562 405 | 1063 406 | 4734 407 | 971 408 | 7383 409 | 4399 410 | 4448 411 | 1118 412 | 1777 413 | 2195 414 | 1558 415 | 8944 416 | 188 417 | 2699 418 | 4828 419 | 6362 420 | 6308 421 | 6108 422 | 2386 423 | 1401 424 | 479 425 | 3016 426 | 1448 427 | 1693 428 | 5510 429 | 5854 430 | 5237 431 | 5283 432 | 1375 433 | 465 434 | 7298 435 | 3636 436 | 4132 437 | 6422 438 | 9163 439 | 5018 440 | 595 441 | 3360 442 | 9116 443 | 7253 444 | 1503 445 | 9549 
import torch
import os
import pandas as pd
from argparse import ArgumentParser
from torch.utils.data import DataLoader
from nn_influence_utils_vit import compute_influences, compute_gradients
from utils import load_seeds, ViTLoRA, load_vit_data
from tqdm import tqdm


def main():
    """Compute influence-function attribution scores for the ViT checkpoints.

    Command-line arguments:
        --seed_id: index into the seeds returned by ``load_seeds()``.
        --task: task name, either ``mnist3`` or ``cifar10``.
        --num_per_class: number of samples per class the model was trained on.

    For each of the last five training checkpoints (15 epochs for tasks whose
    name contains ``mnist``, 30 otherwise) the script writes one CSV file with
    a ``train_idx`` column and one ``z_test_<idx>`` column per test sample.
    The CSV is re-written after every test sample, and an existing partial CSV
    is resumed instead of being recomputed.
    """
    parser = ArgumentParser()
    parser.add_argument('--seed_id', type=int, default=0)
    parser.add_argument('--task', type=str, default='mnist3', help='Either mnist3 or cifar10')
    parser.add_argument('--num_per_class', type=int, default=10, help='Number of samples per class that the model was trained on from {10,20,50}')
    args = parser.parse_args()

    # Load variables needed for the computation
    seeds = load_seeds()
    seed = seeds[args.seed_id]
    num_epochs = 15 if 'mnist' in args.task else 30
    ckpts = range(num_epochs-5, num_epochs)
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Load datasets
    train_dataset, test_dataset = load_vit_data(args.task, args.num_per_class)

    def collate_fn(examples):
        """Stack per-example dicts into one batch dict of tensors."""
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        labels = torch.tensor([example["label"] for example in examples])
        idx = torch.tensor([example["idx"] for example in examples])
        return {"pixel_values": pixel_values, "labels": labels, "idx": idx}

    train_loader = DataLoader(train_dataset, batch_size=32, collate_fn=collate_fn, shuffle=True)
    instance_train_data_loader = DataLoader(train_dataset, batch_size=1, collate_fn=collate_fn, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=collate_fn, shuffle=False)

    # Hyperparameters of s_test estimation
    s_test_num_samples = min(len(train_loader), 1000)
    s_test_damp = 5e-3
    s_test_scale = 1e4
    s_test_iterations = 1

    # NOTE(review): unlike save_path below, this directory has no
    # `_{num_per_class}pc` suffix -- confirm it matches where the s_test
    # script actually writes its outputs.
    s_test_path = f'{os.getcwd()}/../tda_scores/vit/if/{args.task}/{seed}/'

    for num_ckpt in ckpts:
        save_path = f'{os.getcwd()}/../tda_scores/vit/if/{args.task}_{args.num_per_class}pc/{seed}/attribution_ckpt_{num_ckpt}.csv'

        # Decide whether this checkpoint is finished, partially done, or new
        # BEFORE loading anything heavy, so finished checkpoints cost nothing.
        if os.path.exists(save_path):
            df_attribution = pd.read_csv(save_path, index_col=False)
            if df_attribution.shape == (len(train_dataset), len(test_dataset)+1):
                continue
            # Column names look like "z_test_<idx>"; parse the index with
            # int() rather than eval() so file content is never executed.
            finished_z_test = {int(colname.split('_')[-1]) for colname in df_attribution.columns[1:]}
        else:
            df_attribution = pd.DataFrame()
            df_attribution['train_idx'] = [instance['idx'] for instance in train_dataset]
            finished_z_test = set()
            # Make sure the first to_csv() does not fail on a missing folder.
            os.makedirs(os.path.dirname(save_path), exist_ok=True)

        # If s_test was precomputed, load it. This may break the available GPU memory, in that case do not provide precomputed s_tests.
        if os.path.exists(f'{s_test_path}/s_tests_ckpt_{num_ckpt}.pt'):
            precomputed_s_tests = torch.load(f'{s_test_path}/s_tests_ckpt_{num_ckpt}.pt')
        else:
            precomputed_s_tests = None

        # Load the model; map_location lets a checkpoint saved on GPU be
        # restored on a CPU-only machine as well.
        model = ViTLoRA(device=device)
        state_dict = torch.load(
            f'{os.getcwd()}/../models/vit/{args.task}_{args.num_per_class}pc/{seed}/ckpt_epoch_{num_ckpt}.pth',
            map_location=device)
        model.load_state_dict(state_dict)

        # Compute the z_train gradients once per checkpoint; they are reused
        # for every test sample below.
        grads_zj = {}
        for train_inputs in tqdm(instance_train_data_loader):
            grad_zj = compute_gradients(
                n_gpu=1,
                device=device,
                model=model,
                inputs=train_inputs,
                params_filter=None,
                weight_decay=None,
                weight_decay_ignores=None)
            grads_zj[train_inputs['idx'].item()] = grad_zj

        # Compute influences
        for z_test in test_loader:
            z_test_idx = z_test['idx'].item()
            if z_test_idx in finished_z_test:
                continue
            if precomputed_s_tests is not None:
                precomputed_s_test = precomputed_s_tests[z_test_idx]
            else:
                precomputed_s_test = None

            influences = compute_influences(n_gpu=1,
                                            device=device,
                                            model=model,
                                            test_inputs=z_test,
                                            batch_train_data_loader=train_loader,
                                            instance_train_data_loader=instance_train_data_loader,
                                            s_test_num_samples=s_test_num_samples,
                                            s_test_iterations=s_test_iterations,
                                            s_test_scale=s_test_scale,
                                            s_test_damp=s_test_damp,
                                            precomputed_s_test=precomputed_s_test,
                                            precomputed_grad_zjs=grads_zj)
            # save influences
            # presumably `influences` is a mapping with one entry per training
            # sample in loader order -- verify against compute_influences.
            df_attribution[f"z_test_{z_test_idx}"] = list(influences.values())
            df_attribution.to_csv(save_path, index=False)
            torch.cuda.empty_cache()


if __name__ == "__main__":
    main()
34338 48 | 19927 49 | 39657 50 | 35118 51 | 27570 52 | 3652 53 | 34423 54 | 10473 55 | 21592 56 | 12456 57 | 25963 58 | 20402 59 | 13919 60 | 39199 61 | 42739 62 | 3061 63 | 5923 64 | 15249 65 | 5388 66 | 30379 67 | 9396 68 | 33538 69 | 24175 70 | 28511 71 | 41396 72 | 17447 73 | 42471 74 | 42764 75 | 8835 76 | 46059 77 | 39830 78 | 44009 79 | 24601 80 | 27548 81 | 44860 82 | 20095 83 | 9425 84 | 8295 85 | 14626 86 | 32717 87 | 3007 88 | 48734 89 | 32489 90 | 29604 91 | 25510 92 | 39837 93 | 45528 94 | 26805 95 | 39065 96 | 3129 97 | 46278 98 | 36045 99 | 44547 100 | 19365 101 | 36345 102 | 24306 103 | 25113 104 | 14305 105 | 5074 106 | 32643 107 | 1826 108 | 23779 109 | 21733 110 | 23474 111 | 25834 112 | 9962 113 | 34410 114 | 49595 115 | 38465 116 | 39495 117 | 10134 118 | 45702 119 | 14644 120 | 23350 121 | 21454 122 | 36234 123 | 16878 124 | 12698 125 | 31585 126 | 23755 127 | 44717 128 | 6270 129 | 30345 130 | 36690 131 | 4987 132 | 30156 133 | 47928 134 | 36420 135 | 10282 136 | 42795 137 | 37389 138 | 7154 139 | 4400 140 | 14744 141 | 26100 142 | 38019 143 | 1904 144 | 29111 145 | 42982 146 | 49718 147 | 19548 148 | 42072 149 | 27234 150 | 17845 151 | 38174 152 | 31245 153 | 27709 154 | 29555 155 | 16598 156 | 25652 157 | 39118 158 | 17492 159 | 20761 160 | 6889 161 | 21181 162 | 10795 163 | 17534 164 | 35516 165 | 13144 166 | 17334 167 | 8973 168 | 19063 169 | 12946 170 | 35850 171 | 44185 172 | 16272 173 | 26059 174 | 40172 175 | 21675 176 | 44019 177 | 21328 178 | 17268 179 | 11663 180 | 3020 181 | 4636 182 | 31831 183 | 21384 184 | 49045 185 | 28818 186 | 5747 187 | 35086 188 | 43356 189 | 48970 190 | 17815 191 | 20844 192 | 20244 193 | 19351 194 | 44210 195 | 6489 196 | 835 197 | 4631 198 | 35567 199 | 46893 200 | 13999 201 | 34497 202 | 19611 203 | 45451 204 | 2860 205 | 46125 206 | 41551 207 | 42696 208 | 10926 209 | 11936 210 | 27091 211 | 46481 212 | 29470 213 | 48102 214 | 37742 215 | 48314 216 | 29087 217 | 37797 218 | 46686 219 | 46365 220 | 
6392 221 | 4931 222 | 283 223 | 44622 224 | 48290 225 | 23441 226 | 43747 227 | 13557 228 | 20138 229 | 36700 230 | 803 231 | 38244 232 | 36072 233 | 28611 234 | 4598 235 | 46516 236 | 18987 237 | 46877 238 | 11873 239 | 21940 240 | 40908 241 | 30631 242 | 46931 243 | 39528 244 | 2346 245 | 2062 246 | 42112 247 | 30232 248 | 38705 249 | 2133 250 | 13130 251 | 40295 252 | 11846 253 | 133 254 | 3013 255 | 17463 256 | 46946 257 | 5436 258 | 45054 259 | 35383 260 | 39835 261 | 30581 262 | 26196 263 | 14086 264 | 6497 265 | 6257 266 | 18823 267 | 7969 268 | 32239 269 | 23599 270 | 38196 271 | 26444 272 | 21371 273 | 41095 274 | 27187 275 | 3138 276 | 25454 277 | 43168 278 | 45112 279 | 24449 280 | 36135 281 | 20094 282 | 2263 283 | 35341 284 | 12394 285 | 6865 286 | 31073 287 | 45493 288 | 7682 289 | 29688 290 | 10891 291 | 15058 292 | 10073 293 | 12864 294 | 2539 295 | 39807 296 | 43272 297 | 589 298 | 48948 299 | 33559 300 | 2492 301 | 8283 302 | 8014 303 | 11832 304 | 38969 305 | 36747 306 | 46758 307 | 39707 308 | 5675 309 | 47827 310 | 15948 311 | 35985 312 | 35109 313 | 40110 314 | 22576 315 | 42220 316 | 31183 317 | 49553 318 | 15843 319 | 49393 320 | 18700 321 | 14718 322 | 20500 323 | 44458 324 | 18236 325 | 6403 326 | 26484 327 | 32323 328 | 23883 329 | 40609 330 | 6346 331 | 13641 332 | 19374 333 | 31762 334 | 40292 335 | 44236 336 | 2923 337 | 246 338 | 24138 339 | 48569 340 | 13886 341 | 9324 342 | 31332 343 | 34104 344 | 8358 345 | 20762 346 | 7222 347 | 47206 348 | 5273 349 | 25580 350 | 20182 351 | 45026 352 | 23862 353 | 46977 354 | 11293 355 | 4010 356 | 12190 357 | 33635 358 | 34146 359 | 29247 360 | 17785 361 | 9166 362 | 20339 363 | 42779 364 | 241 365 | 45502 366 | 36027 367 | 46801 368 | 31912 369 | 9060 370 | 17661 371 | 31503 372 | 17879 373 | 7749 374 | 35024 375 | 33779 376 | 47819 377 | 21687 378 | 21697 379 | 9182 380 | 7872 381 | 7107 382 | 19781 383 | 30389 384 | 10618 385 | 42496 386 | 6557 387 | 11028 388 | 22826 389 | 13330 390 | 33915 
391 | 30479 392 | 26554 393 | 9222 394 | 8246 395 | 4921 396 | 4982 397 | 20917 398 | 35026 399 | 39526 400 | 25138 401 | 29390 402 | 30711 403 | 11493 404 | 17392 405 | 23587 406 | 10670 407 | 47334 408 | 13161 409 | 29603 410 | 22911 411 | 2108 412 | 12051 413 | 20382 414 | 24651 415 | 5840 416 | 29757 417 | 9274 418 | 48625 419 | 33800 420 | 23432 421 | 7873 422 | 3163 423 | 10503 424 | 4486 425 | 31028 426 | 32641 427 | 15378 428 | 49988 429 | 20076 430 | 27032 431 | 22021 432 | 22893 433 | 25875 434 | 24323 435 | 40932 436 | 21858 437 | 15354 438 | 26871 439 | 5981 440 | 41720 441 | 49883 442 | 20776 443 | 33001 444 | 8894 445 | 32118 446 | 45134 447 | 46997 448 | 18732 449 | 1545 450 | 38966 451 | 37881 452 | 23024 453 | 35873 454 | 34280 455 | 24600 456 | 48419 457 | 47431 458 | 35201 459 | 31369 460 | 7550 461 | 16866 462 | 21386 463 | 39397 464 | 22622 465 | 17395 466 | 9611 467 | 12836 468 | 25160 469 | 19445 470 | 33162 471 | 28036 472 | 11514 473 | 28380 474 | 10589 475 | 13627 476 | 16010 477 | 18218 478 | 18745 479 | 14210 480 | 43659 481 | 11633 482 | 13230 483 | 41242 484 | 49346 485 | 43301 486 | 27202 487 | 47766 488 | 2322 489 | 41868 490 | 38783 491 | 46277 492 | 29754 493 | 22688 494 | 30172 495 | 23644 496 | 2804 497 | 940 498 | 41266 499 | 9413 500 | 22276 501 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A Bayesian Approach To Analysing Training Data Attribution In Deep Learning 2 | 3 | #### Elisa Nguyen, Minjoon Seo, Seong Joon Oh 4 | 5 | Training data attribution (TDA) techniques find influential training data for the model's prediction on the test data of interest. They approximate the impact of down- or up-weighting a particular training sample. While conceptually useful, they are hardly applicable to deep models in practice, particularly because of their sensitivity to different model initialisation. 
In this paper, we introduce a Bayesian perspective on the TDA task, where the learned model is treated as a Bayesian posterior and the TDA estimates as random variables. From this novel viewpoint, we observe that the influence of an individual training sample is often overshadowed by the noise stemming from model initialisation and SGD batch composition. Based on this observation, we argue that TDA can only be reliably used for explaining deep model predictions that are consistently influenced by certain training data, independent of other noise factors. Our experiments demonstrate the rarity of such noise-independent training-test data pairs but confirm their existence. We recommend that future researchers and practitioners trust TDA estimates only in such cases. Further, we find a disagreement between ground truth and estimated TDA distributions and encourage future work to study this gap. 6 | 7 | #### [Link to paper](https://arxiv.org/abs/2305.19765) 8 | 9 | ------------------------------ 10 | ## Reproducing the experiments 11 | 12 | ### Requirements 13 | 14 | The main dependencies are: 15 | 16 | - `python==3.10.4` 17 | - `torch==2.0.0` 18 | - `torchvision==0.15.0` 19 | - `transformers==4.28.1` 20 | - `datasets==2.12.0` 21 | - `numpy==1.23.5` 22 | - `pandas==1.5.2` 23 | - `scikit-learn==1.2.2` 24 | - `seaborn==0.12.2` 25 | 26 | A `req.txt` is provided that details the packages required for reproducing the experiments. To install the same conda environment, use `$ conda create --name <env> --file req.txt` 27 | We conducted the experiments using this environment on an NVIDIA 2080 Ti GPU. 28 | 29 | ### Data 30 | 31 | We subsample MNIST and CIFAR10, and provide the indices of the datasets used in our experiments in `data/mnist3` and `data/cifar10`. 32 | To reproduce the dataset, run `run_subset_generation.py`. 33 | 34 | ### Models 35 | 36 | To train the CNN models, run `run_cnn_training.py`. This trains two-layer CNNs. 
If you want to train three-layer CNNs, update this script with the respective classes. 37 | To finetune the ViT model with LoRA, run `run_vit_training.py`. 38 | 39 | These scripts train the respective model 10 times on the seeds specified in `random_seeds.pt`, which we also use in the paper. This corresponds to sampling a model $\theta$ from the posterior $p(\theta|\mathcal{D})$ using Deep Ensembling. 40 | The checkpoints after the last 5 epochs are saved to `models/`. This corresponds to sampling a model $\theta$ from the posterior $p(\theta|\mathcal{D})$ similar to stochastic weight averaging. 41 | 42 | This is done for all datasets. If you wish to run it on a specific one, change it directly in the script. 43 | 44 | ### Experiments 45 | In the paper, we conduct hypothesis testing of the signal-to-noise ratio in TDA scores and report the p-value as an indicator of the statistical significance of the estimated scores. Additionally, we inspect the Pearson and Spearman correlations of the TDA scores of different methods to find out how well they correspond to each other. Below are instructions on how to reproduce these analyses. 46 | 47 | #### Step 1: Computing the TDA scores 48 | We test 5 different TDA methods. We provide the scripts in the folders `tda_cnn_scripts` and `tda_vit_scripts` for computing the TDA scores across the ensemble of models for the CNN and ViT, respectively. 49 | 50 | Each of the scripts should be called with the following parameters: `python <script> --task <task> --num_per_class <num_per_class> --seed_id <seed_id>`. 51 | 52 | - `--task` specifies the task of the experiment, i.e. either `mnist3` or `cifar10`. 53 | - `--num_per_class` is an integer $\in$ {10, 20, 50} that refers to how many samples per class the model was trained on. 54 | - `--seed_id` is an integer that specifies the seed from the `random_seeds.pt` file. This parameter is used for parallel processing, in case multiple GPUs are available. 
55 | 56 | For computing influence functions, we use the code provided by the [FastIF repository](https://github.com/salesforce/fast-influence-functions). Make sure to compute the HVP `s_test` (see `run_cnn_s_test.py` and `run_vit_s_test.py`) before computing the influence functions. 57 | 58 | Please note that this step may take a while, depending on the size of the model. 59 | 60 | #### Step 2: Computing p-values 61 | After TDA scores are computed, we can analyse the reliability of the scores by the p-values. To compute p-values across the TDA scores computed for each train-test pair of each sample $\theta$, run `run_compute_pvalues.py`. This will generate CSV files of the p-values in a new `results/` folder. 62 | 63 | Specify the experiment to compute p-values for as `<model>_<task>_<num_per_class>pc`, e.g. `cnn_mnist3_10pc`. 64 | 65 | #### Step 3: Computing correlations 66 | We provide the script to compute the Pearson and Spearman correlation between the mean, standard deviation and p-values computed across the different samples $\theta$ in `run_correlation_analysis.py`. This will save the correlation matrices in the `results/` folder. 67 | 68 | ---------- 69 | ## Contact 70 | For any implementation problems or bugs, please contact Elisa Nguyen 71 | . 
72 | 73 | ## How to cite 74 | ``` 75 | @inproceedings{nguyen2023bayesiantda, 76 | title = {A Bayesian Perspective On Training Data Attribution}, 77 | author = {Nguyen, Elisa and Seo, Minjoon and Oh, Seong Joon}, 78 | year = {2023}, 79 | booktitle = {Conference on Neural Information Processing Systems}, 80 | } 81 | ``` 82 | -------------------------------------------------------------------------------- /req.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | _libgcc_mutex=0.1=conda_forge 5 | _openmp_mutex=4.5=2_gnu 6 | accelerate=0.18.0=pypi_0 7 | aiohttp=3.8.4=pypi_0 8 | aiosignal=1.3.1=pypi_0 9 | appdirs=1.4.4=pypi_0 10 | asttokens=2.1.0=pypi_0 11 | async-timeout=4.0.2=pypi_0 12 | attrs=23.1.0=pypi_0 13 | backcall=0.2.0=pypi_0 14 | backports=1.0=py_2 15 | backports-functools-lru-cache=1.6.4=pypi_0 16 | backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0 17 | blas=1.0=mkl 18 | bottleneck=1.3.5=py310ha9d4c09_0 19 | brotli=1.0.9=h5eee18b_7 20 | brotli-bin=1.0.9=h5eee18b_7 21 | brotlipy=0.7.0=py310h7f8727e_1002 22 | bzip2=1.0.8=h7b6447c_0 23 | ca-certificates=2023.01.10=h06a4308_0 24 | certifi=2022.12.7=py310h06a4308_0 25 | cffi=1.15.1=py310h74dc2b5_0 26 | charset-normalizer=2.0.4=pypi_0 27 | click=8.0.4=py310h06a4308_0 28 | colorama=0.4.6=pyhd8ed1ab_0 29 | contourpy=1.0.5=py310hdb19cb5_0 30 | cryptography=39.0.1=py310h9ce1e76_0 31 | cuda-cudart=11.7.99=0 32 | cuda-cupti=11.7.101=0 33 | cuda-libraries=11.7.1=0 34 | cuda-nvrtc=11.7.99=0 35 | cuda-nvtx=11.7.91=0 36 | cuda-runtime=11.7.1=0 37 | cycler=0.11.0=pypi_0 38 | dataclasses=0.8=pyh6d0b6a4_7 39 | datasets=2.12.0=pypi_0 40 | dbus=1.13.18=hb2f20db_0 41 | debugpy=1.6.0=py310hd8f1fbe_0 42 | decorator=5.1.1=pypi_0 43 | dill=0.3.6=pypi_0 44 | entrypoints=0.4=pypi_0 45 | executing=1.2.0=pypi_0 46 | expat=2.4.9=h6a678d5_0 47 | ffmpeg=4.3=hf484d3e_0 48 | 
fftw=3.3.9=h27cfd23_1 49 | filelock=3.9.0=py310h06a4308_0 50 | flit-core=3.8.0=py310h06a4308_0 51 | fontconfig=2.14.1=hef1e5e3_0 52 | fonttools=4.25.0=pypi_0 53 | freetype=2.11.0=h70c0345_0 54 | frozenlist=1.3.3=pypi_0 55 | fsspec=2023.4.0=pypi_0 56 | giflib=5.2.1=h5eee18b_3 57 | glib=2.69.1=h4ff587b_1 58 | gmp=6.2.1=h295c915_3 59 | gmpy2=2.1.2=py310heeb90bb_0 60 | gnutls=3.6.15=he1e5248_0 61 | gst-plugins-base=1.14.1=h6a678d5_1 62 | gstreamer=1.14.1=h5eee18b_1 63 | huggingface_hub=0.14.1=py_0 64 | icu=58.2=he6710b0_3 65 | idna=3.4=py310h06a4308_0 66 | importlib-metadata=6.0.0=py310h06a4308_0 67 | importlib_metadata=6.0.0=hd3eb1b0_0 68 | intel-openmp=2021.4.0=h06a4308_3561 69 | ipykernel=6.14.0=py310hfdc917e_0 70 | ipython=8.4.0=py310hff52083_0 71 | jedi=0.18.1=pypi_0 72 | jinja2=3.1.2=py310h06a4308_0 73 | joblib=1.1.1=py310h06a4308_0 74 | jpeg=9e=h5eee18b_1 75 | jupyter-client=7.4.7=pypi_0 76 | jupyter_client=7.4.7=pyhd8ed1ab_0 77 | jupyter_core=5.0.0=py310hff52083_0 78 | kiwisolver=1.4.4=py310h6a678d5_0 79 | lame=3.100=h7b6447c_0 80 | lcms2=2.12=h3be6417_0 81 | ld_impl_linux-64=2.38=h1181459_1 82 | libbrotlicommon=1.0.9=h5eee18b_7 83 | libbrotlidec=1.0.9=h5eee18b_7 84 | libbrotlienc=1.0.9=h5eee18b_7 85 | libcublas=11.10.3.66=0 86 | libcufft=10.7.2.124=h4fbf590_0 87 | libcufile=1.6.0.25=0 88 | libcurand=10.3.2.56=0 89 | libcusolver=11.4.0.1=0 90 | libcusparse=11.7.4.91=0 91 | libffi=3.3=he6710b0_2 92 | libgcc-ng=12.2.0=h65d4601_19 93 | libgfortran-ng=11.2.0=h00389a5_1 94 | libgfortran5=11.2.0=h1234567_1 95 | libgomp=12.2.0=h65d4601_19 96 | libiconv=1.16=h7f8727e_2 97 | libidn2=2.3.2=h7f8727e_0 98 | libnpp=11.7.4.75=0 99 | libnsl=2.0.0=h7f98852_0 100 | libnvjpeg=11.8.0.2=0 101 | libpng=1.6.37=hbc83047_0 102 | libprotobuf=3.20.1=h4ff587b_0 103 | libsodium=1.0.18=h36c2ea0_1 104 | libstdcxx-ng=11.2.0=h1234567_1 105 | libtasn1=4.19.0=h5eee18b_0 106 | libtiff=4.2.0=hecacb30_2 107 | libunistring=0.9.10=h27cfd23_0 108 | libuuid=1.41.5=h5eee18b_0 109 | 
libwebp=1.2.4=h11a3e52_1 110 | libwebp-base=1.2.4=h5eee18b_1 111 | libxcb=1.15=h7f8727e_0 112 | libxml2=2.9.14=h74e7548_0 113 | libzlib=1.2.12=h166bdaf_2 114 | lz4-c=1.9.4=h6a678d5_0 115 | markupsafe=2.1.1=py310h7f8727e_0 116 | matplotlib=3.7.1=py310h06a4308_1 117 | matplotlib-base=3.7.1=py310h1128e8f_1 118 | matplotlib-inline=0.1.6=pypi_0 119 | mkl=2021.4.0=h06a4308_640 120 | mkl-service=2.4.0=py310h7f8727e_0 121 | mkl_fft=1.3.1=py310hd6ae3a3_0 122 | mkl_random=1.2.2=py310h00e6091_0 123 | mpc=1.1.0=h10f8cd9_1 124 | mpfr=4.0.2=hb69a4c5_1 125 | mpmath=1.2.1=pypi_0 126 | multidict=6.0.4=pypi_0 127 | multiprocess=0.70.14=pypi_0 128 | munkres=1.1.4=pypi_0 129 | ncurses=6.3=h7f8727e_2 130 | nest-asyncio=1.5.6=pypi_0 131 | nettle=3.7.3=hbbd107a_1 132 | networkx=2.8.4=py310h06a4308_1 133 | numexpr=2.8.4=py310h8879344_0 134 | numpy=1.23.5=py310hd5efca6_0 135 | numpy-base=1.23.5=py310h8e6c178_0 136 | openh264=2.1.1=h4ff587b_0 137 | openssl=1.1.1t=h7f8727e_0 138 | packaging=22.0=py310h06a4308_0 139 | pandas=1.5.2=py310h1128e8f_0 140 | parso=0.8.3=pypi_0 141 | pcre=8.45=h295c915_0 142 | peft=0.2.0=pypi_0 143 | pexpect=4.8.0=pypi_0 144 | pickleshare=0.7.5=pypi_0 145 | pillow=9.2.0=py310hace64e9_1 146 | pip=23.0.1=py310h06a4308_0 147 | platformdirs=2.5.2=pypi_0 148 | pooch=1.4.0=pypi_0 149 | prompt-toolkit=3.0.32=pypi_0 150 | protobuf=3.20.1=py310h295c915_0 151 | psutil=5.9.4=py310h5764c6d_0 152 | ptyprocess=0.7.0=pypi_0 153 | pure-eval=0.2.2=pypi_0 154 | pure_eval=0.2.2=pyhd8ed1ab_0 155 | pyarrow=11.0.0=pypi_0 156 | pycparser=2.21=pypi_0 157 | pygments=2.13.0=pypi_0 158 | pyopenssl=23.0.0=py310h06a4308_0 159 | pyparsing=3.0.9=py310h06a4308_0 160 | pyqt=5.9.2=py310h295c915_6 161 | pysocks=1.7.1=py310h06a4308_0 162 | python=3.10.4=h12debd9_0 163 | python-dateutil=2.8.2=pypi_0 164 | python_abi=3.10=2_cp310 165 | pytorch=2.0.0=py3.10_cuda11.7_cudnn8.5.0_0 166 | pytorch-cuda=11.7=h778d358_3 167 | pytorch-mutex=1.0=cuda 168 | pytz=2022.7=py310h06a4308_0 169 | 
pyyaml=6.0=py310h5eee18b_1 170 | pyzmq=23.0.0=py310h330234f_0 171 | qt=5.9.7=h5867ecd_1 172 | readline=8.1.2=h7f8727e_1 173 | regex=2022.7.9=py310h5eee18b_0 174 | requests=2.28.1=py310h06a4308_1 175 | responses=0.18.0=pypi_0 176 | sacremoses=master=py_0 177 | scikit-learn=1.2.2=pypi_0 178 | scipy=1.10.0=py310hd5efca6_0 179 | seaborn=0.12.2=pypi_0 180 | setuptools=65.6.3=py310h06a4308_0 181 | sip=4.19.13=py310h295c915_0 182 | six=1.16.0=pypi_0 183 | sqlite=3.38.5=hc218d9a_0 184 | stack-data=0.6.1=pypi_0 185 | stack_data=0.6.1=pyhd8ed1ab_0 186 | sympy=1.11.1=py310h06a4308_0 187 | threadpoolctl=3.1.0=pypi_0 188 | tk=8.6.12=h1ccaba5_0 189 | tokenizers=0.11.4=py310h3dcd8bd_1 190 | torchaudio=2.0.0=py310_cu117 191 | torchtriton=2.0.0=py310 192 | torchvision=0.15.0=py310_cu117 193 | tornado=6.2=py310h5eee18b_0 194 | tqdm=4.64.1=pyhd8ed1ab_0 195 | traitlets=5.5.0=pypi_0 196 | transformers=4.28.1=pypi_0 197 | typing-extensions=4.4.0=py310h06a4308_0 198 | typing_extensions=4.4.0=py310h06a4308_0 199 | tzdata=2023c=h04d1e81_0 200 | urllib3=1.26.15=py310h06a4308_0 201 | wcwidth=0.2.5=pypi_0 202 | wheel=0.38.4=py310h06a4308_0 203 | xxhash=3.2.0=pypi_0 204 | xz=5.2.5=h7f8727e_1 205 | yaml=0.2.5=h7b6447c_0 206 | yarl=1.9.2=pypi_0 207 | zeromq=4.3.4=h9c3ff4c_1 208 | zipp=3.11.0=py310h06a4308_0 209 | zlib=1.2.12=h7f8727e_2 210 | zstd=1.5.2=ha4553b6_0 211 | -------------------------------------------------------------------------------- /data/cifar10/train_subset_60pc.txt: -------------------------------------------------------------------------------- 1 | 3139 2 | 124 3 | 13383 4 | 17602 5 | 24519 6 | 36110 7 | 47752 8 | 43264 9 | 22885 10 | 38864 11 | 16064 12 | 35312 13 | 35506 14 | 532 15 | 16639 16 | 23086 17 | 18653 18 | 9238 19 | 17427 20 | 43880 21 | 23697 22 | 8194 23 | 16539 24 | 11754 25 | 17715 26 | 45300 27 | 45622 28 | 48974 29 | 8009 30 | 28370 31 | 12168 32 | 15726 33 | 12526 34 | 49708 35 | 17697 36 | 40337 37 | 12337 38 | 33173 39 | 24213 40 | 11851 41 | 1802 42 | 
33468 43 | 33769 44 | 24312 45 | 9923 46 | 721 47 | 18724 48 | 21228 49 | 15211 50 | 48770 51 | 42484 52 | 36913 53 | 32536 54 | 41801 55 | 30563 56 | 35170 57 | 49305 58 | 20993 59 | 22100 60 | 48337 61 | 49911 62 | 495 63 | 25393 64 | 8003 65 | 24278 66 | 45235 67 | 1827 68 | 8762 69 | 9541 70 | 9843 71 | 29435 72 | 2551 73 | 12715 74 | 32266 75 | 36321 76 | 21584 77 | 40215 78 | 28543 79 | 16688 80 | 22415 81 | 13625 82 | 49398 83 | 2761 84 | 34507 85 | 40237 86 | 19709 87 | 29458 88 | 39746 89 | 2531 90 | 3143 91 | 3666 92 | 46755 93 | 17895 94 | 47278 95 | 1112 96 | 12011 97 | 28046 98 | 19590 99 | 29580 100 | 10375 101 | 49026 102 | 44723 103 | 39813 104 | 36794 105 | 1043 106 | 27707 107 | 13384 108 | 1587 109 | 12063 110 | 5819 111 | 12497 112 | 4704 113 | 38324 114 | 28285 115 | 27111 116 | 12670 117 | 47789 118 | 22252 119 | 18957 120 | 29186 121 | 24290 122 | 28002 123 | 42695 124 | 16294 125 | 36326 126 | 36168 127 | 27946 128 | 46994 129 | 35457 130 | 40240 131 | 45683 132 | 20137 133 | 21868 134 | 25703 135 | 22139 136 | 40219 137 | 36523 138 | 15508 139 | 3761 140 | 25813 141 | 46446 142 | 37627 143 | 40116 144 | 40769 145 | 8403 146 | 11983 147 | 43741 148 | 10561 149 | 6234 150 | 35522 151 | 41822 152 | 4725 153 | 11226 154 | 28186 155 | 33302 156 | 29111 157 | 5815 158 | 27167 159 | 36234 160 | 39607 161 | 43170 162 | 8987 163 | 13307 164 | 24046 165 | 43765 166 | 25941 167 | 4962 168 | 40146 169 | 16924 170 | 32018 171 | 27144 172 | 6280 173 | 42725 174 | 47384 175 | 42035 176 | 24241 177 | 47591 178 | 39182 179 | 1784 180 | 31689 181 | 18303 182 | 37592 183 | 31395 184 | 48709 185 | 28024 186 | 16428 187 | 21528 188 | 11336 189 | 5917 190 | 44499 191 | 13207 192 | 11966 193 | 32486 194 | 31102 195 | 45403 196 | 30865 197 | 32983 198 | 27084 199 | 26168 200 | 43853 201 | 35878 202 | 35070 203 | 19313 204 | 42697 205 | 8791 206 | 8216 207 | 29543 208 | 43297 209 | 4517 210 | 36838 211 | 49136 212 | 29007 213 | 11308 214 | 8127 215 | 40000 216 | 
46547 217 | 21593 218 | 12923 219 | 8884 220 | 36067 221 | 26115 222 | 22936 223 | 34442 224 | 9508 225 | 10755 226 | 35465 227 | 32161 228 | 10271 229 | 27601 230 | 14376 231 | 3059 232 | 6813 233 | 28622 234 | 2597 235 | 43865 236 | 44272 237 | 29806 238 | 4553 239 | 13534 240 | 2752 241 | 24272 242 | 29581 243 | 13486 244 | 11035 245 | 40727 246 | 44256 247 | 37895 248 | 41321 249 | 43116 250 | 34088 251 | 45346 252 | 7184 253 | 4608 254 | 4196 255 | 21057 256 | 5634 257 | 32308 258 | 21942 259 | 21262 260 | 21886 261 | 46522 262 | 43873 263 | 5537 264 | 48092 265 | 907 266 | 2641 267 | 41231 268 | 8178 269 | 35596 270 | 32377 271 | 16581 272 | 29198 273 | 49536 274 | 27542 275 | 12656 276 | 5296 277 | 27197 278 | 20005 279 | 37594 280 | 37343 281 | 9404 282 | 12380 283 | 31308 284 | 16208 285 | 39882 286 | 49397 287 | 40480 288 | 29181 289 | 25440 290 | 12752 291 | 10944 292 | 423 293 | 20125 294 | 1606 295 | 46016 296 | 42990 297 | 37339 298 | 22974 299 | 27480 300 | 3856 301 | 28801 302 | 49202 303 | 38422 304 | 37897 305 | 28366 306 | 29664 307 | 37286 308 | 44922 309 | 38128 310 | 21069 311 | 6522 312 | 48272 313 | 11334 314 | 28152 315 | 31719 316 | 28771 317 | 49350 318 | 5269 319 | 40570 320 | 34707 321 | 31482 322 | 30581 323 | 42457 324 | 37122 325 | 26692 326 | 42061 327 | 38220 328 | 42596 329 | 16025 330 | 36829 331 | 42353 332 | 4393 333 | 28198 334 | 18712 335 | 38196 336 | 11411 337 | 47241 338 | 29208 339 | 46797 340 | 23124 341 | 11605 342 | 7682 343 | 391 344 | 27244 345 | 13136 346 | 45409 347 | 39711 348 | 3261 349 | 21083 350 | 30385 351 | 28392 352 | 48464 353 | 5336 354 | 22485 355 | 26709 356 | 37100 357 | 32586 358 | 38667 359 | 38920 360 | 28104 361 | 33852 362 | 5284 363 | 6219 364 | 49177 365 | 40255 366 | 8831 367 | 1138 368 | 48641 369 | 32974 370 | 34743 371 | 39766 372 | 37994 373 | 1479 374 | 17696 375 | 15374 376 | 38381 377 | 38218 378 | 41201 379 | 10974 380 | 7995 381 | 42807 382 | 574 383 | 37845 384 | 21716 385 | 38625 386 
| 14934 387 | 2298 388 | 38348 389 | 39232 390 | 41757 391 | 17394 392 | 16779 393 | 45095 394 | 5245 395 | 27550 396 | 45094 397 | 36200 398 | 9545 399 | 45486 400 | 12286 401 | 44092 402 | 8707 403 | 20290 404 | 48558 405 | 19090 406 | 44054 407 | 48158 408 | 16027 409 | 47597 410 | 45243 411 | 16532 412 | 4743 413 | 47999 414 | 21374 415 | 4233 416 | 6143 417 | 3155 418 | 31425 419 | 32165 420 | 4412 421 | 20529 422 | 23250 423 | 33980 424 | 14108 425 | 49089 426 | 16364 427 | 4796 428 | 32923 429 | 27586 430 | 1965 431 | 8839 432 | 20918 433 | 20953 434 | 43105 435 | 31065 436 | 38449 437 | 30058 438 | 10488 439 | 26399 440 | 14938 441 | 41213 442 | 23261 443 | 1962 444 | 40362 445 | 15872 446 | 47518 447 | 38699 448 | 46443 449 | 13595 450 | 1496 451 | 9034 452 | 8162 453 | 11026 454 | 26626 455 | 174 456 | 29215 457 | 11739 458 | 23322 459 | 48243 460 | 37584 461 | 46337 462 | 5713 463 | 14706 464 | 18216 465 | 25676 466 | 48814 467 | 14674 468 | 15411 469 | 43752 470 | 9564 471 | 26907 472 | 47478 473 | 20151 474 | 25378 475 | 13371 476 | 26064 477 | 24460 478 | 47049 479 | 14373 480 | 30015 481 | 33193 482 | 48785 483 | 3616 484 | 47620 485 | 3221 486 | 47192 487 | 12256 488 | 43307 489 | 33494 490 | 45882 491 | 15672 492 | 13073 493 | 45903 494 | 38346 495 | 38074 496 | 12246 497 | 38624 498 | 19683 499 | 26365 500 | 35867 501 | 15686 502 | 19042 503 | 5840 504 | 42451 505 | 9875 506 | 13345 507 | 49467 508 | 31120 509 | 2240 510 | 22387 511 | 20101 512 | 18083 513 | 26303 514 | 13721 515 | 47920 516 | 11351 517 | 28144 518 | 4849 519 | 9388 520 | 15362 521 | 6760 522 | 23046 523 | 20277 524 | 37915 525 | 37720 526 | 48843 527 | 30763 528 | 18913 529 | 5090 530 | 43801 531 | 12135 532 | 19800 533 | 33006 534 | 22715 535 | 1656 536 | 38507 537 | 49315 538 | 24838 539 | 4535 540 | 44345 541 | 37806 542 | 34578 543 | 48986 544 | 11908 545 | 33780 546 | 9895 547 | 6867 548 | 8098 549 | 33589 550 | 38683 551 | 19682 552 | 45413 553 | 17355 554 | 32602 555 | 
8748 556 | 15931 557 | 12129 558 | 14666 559 | 8079 560 | 10243 561 | 77 562 | 43996 563 | 48136 564 | 22967 565 | 17763 566 | 35414 567 | 19318 568 | 45349 569 | 42226 570 | 42320 571 | 36049 572 | 28636 573 | 16529 574 | 22118 575 | 39121 576 | 7168 577 | 6188 578 | 49633 579 | 7914 580 | 782 581 | 10146 582 | 37199 583 | 30575 584 | 22622 585 | 41055 586 | 40607 587 | 39531 588 | 15678 589 | 8207 590 | 18299 591 | 10330 592 | 44320 593 | 9512 594 | 44820 595 | 44245 596 | 17082 597 | 45301 598 | 42784 599 | 11944 600 | 31481 601 | -------------------------------------------------------------------------------- /data/mnist3/test_subset.txt: -------------------------------------------------------------------------------- 1 | 2061 2 | 5259 3 | 2328 4 | 9951 5 | 8528 6 | 8978 7 | 5290 8 | 3322 9 | 7873 10 | 7702 11 | 6979 12 | 2378 13 | 6510 14 | 188 15 | 6097 16 | 5818 17 | 3310 18 | 5301 19 | 4208 20 | 7593 21 | 7317 22 | 3768 23 | 6383 24 | 4361 25 | 3359 26 | 5450 27 | 8614 28 | 6925 29 | 620 30 | 7452 31 | 4824 32 | 586 33 | 3804 34 | 2326 35 | 1841 36 | 6808 37 | 6711 38 | 4082 39 | 5980 40 | 3358 41 | 3052 42 | 7062 43 | 8691 44 | 4857 45 | 2631 46 | 5069 47 | 3640 48 | 6351 49 | 4788 50 | 4542 51 | 7907 52 | 704 53 | 2713 54 | 6959 55 | 8772 56 | 9753 57 | 5079 58 | 3368 59 | 6923 60 | 5244 61 | 3677 62 | 5993 63 | 8057 64 | 6921 65 | 4227 66 | 3942 67 | 1084 68 | 5519 69 | 1273 70 | 2087 71 | 4327 72 | 3305 73 | 7882 74 | 3179 75 | 2318 76 | 2096 77 | 3325 78 | 7889 79 | 7381 80 | 3706 81 | 2106 82 | 8481 83 | 8117 84 | 2996 85 | 7449 86 | 5043 87 | 4672 88 | 6398 89 | 2777 90 | 6818 91 | 9724 92 | 2641 93 | 2717 94 | 6287 95 | 5653 96 | 6114 97 | 7558 98 | 3013 99 | 5255 100 | 2370 101 | 642 102 | 5652 103 | 9105 104 | 8795 105 | 7555 106 | 1570 107 | 804 108 | 7168 109 | 6342 110 | 5603 111 | 5141 112 | 9952 113 | 1638 114 | 4675 115 | 5249 116 | 4186 117 | 8351 118 | 9239 119 | 8359 120 | 4048 121 | 6798 122 | 9688 123 | 1223 124 | 1203 125 | 5756 126 | 
8251 127 | 4780 128 | 1468 129 | 8764 130 | 8657 131 | 5435 132 | 4959 133 | 4473 134 | 6400 135 | 2703 136 | 5241 137 | 5251 138 | 3646 139 | 2627 140 | 7747 141 | 4090 142 | 7072 143 | 4432 144 | 3450 145 | 5243 146 | 3090 147 | 3417 148 | 6016 149 | 8055 150 | 1898 151 | 4720 152 | 5368 153 | 7225 154 | 3895 155 | 6181 156 | 3 157 | 4887 158 | 7175 159 | 997 160 | 28 161 | 9158 162 | 871 163 | 6211 164 | 4604 165 | 3315 166 | 8738 167 | 6470 168 | 6795 169 | 3242 170 | 6835 171 | 7037 172 | 1047 173 | 9436 174 | 6532 175 | 407 176 | 8634 177 | 5006 178 | 8414 179 | 4804 180 | 9220 181 | 1650 182 | 2714 183 | 6191 184 | 2461 185 | 7987 186 | 5031 187 | 7354 188 | 1416 189 | 9439 190 | 4617 191 | 8085 192 | 1297 193 | 1692 194 | 3959 195 | 2794 196 | 8808 197 | 1713 198 | 6633 199 | 2188 200 | 7898 201 | 8496 202 | 8137 203 | 9488 204 | 8951 205 | 6546 206 | 7855 207 | 6099 208 | 6621 209 | 904 210 | 5932 211 | 1220 212 | 5750 213 | 6129 214 | 9037 215 | 6286 216 | 3735 217 | 5673 218 | 4282 219 | 6753 220 | 1892 221 | 6072 222 | 9167 223 | 3540 224 | 9339 225 | 3265 226 | 5613 227 | 4281 228 | 9455 229 | 3624 230 | 9016 231 | 7269 232 | 7846 233 | 1590 234 | 713 235 | 9902 236 | 7671 237 | 2764 238 | 2568 239 | 3867 240 | 6131 241 | 564 242 | 3044 243 | 1271 244 | 3191 245 | 2937 246 | 7263 247 | 4527 248 | 8991 249 | 3512 250 | 1504 251 | 3574 252 | 1399 253 | 2449 254 | 8595 255 | 6015 256 | 6428 257 | 6597 258 | 327 259 | 8567 260 | 8127 261 | 7342 262 | 6833 263 | 7161 264 | 305 265 | 5965 266 | 4023 267 | 1197 268 | 3734 269 | 6478 270 | 3710 271 | 5139 272 | 9800 273 | 8513 274 | 3764 275 | 6770 276 | 4477 277 | 6062 278 | 3087 279 | 9060 280 | 3989 281 | 9621 282 | 5165 283 | 1723 284 | 7112 285 | 763 286 | 2638 287 | 2999 288 | 1995 289 | 8098 290 | 932 291 | 8325 292 | 4957 293 | 5463 294 | 8731 295 | 8016 296 | 1487 297 | 851 298 | 525 299 | 9542 300 | 8913 301 | 7019 302 | 4308 303 | 9348 304 | 7171 305 | 4655 306 | 4516 307 | 8187 308 | 2027 309 | 
7176 310 | 6202 311 | 6758 312 | 8324 313 | 9438 314 | 6743 315 | 9699 316 | 4168 317 | 5072 318 | 783 319 | 1448 320 | 3019 321 | 7899 322 | 9896 323 | 2676 324 | 2379 325 | 3922 326 | 5872 327 | 529 328 | 6397 329 | 5917 330 | 4025 331 | 6125 332 | 9969 333 | 949 334 | 9705 335 | 5112 336 | 9265 337 | 265 338 | 8745 339 | 5566 340 | 4032 341 | 6456 342 | 2239 343 | 5132 344 | 4006 345 | 276 346 | 74 347 | 3097 348 | 3455 349 | 5651 350 | 2599 351 | 5520 352 | 7567 353 | 6338 354 | 7000 355 | 1129 356 | 615 357 | 4153 358 | 3438 359 | 7183 360 | 6438 361 | 1527 362 | 5005 363 | 5970 364 | 251 365 | 2950 366 | 2661 367 | 4871 368 | 5953 369 | 4589 370 | 9836 371 | 1054 372 | 228 373 | 9140 374 | 419 375 | 4328 376 | 8905 377 | 3990 378 | 3747 379 | 1254 380 | 5896 381 | 2166 382 | 6430 383 | 1338 384 | 5630 385 | 4670 386 | 267 387 | 2421 388 | 9838 389 | 9799 390 | 7048 391 | 682 392 | 3789 393 | 1213 394 | 3562 395 | 8659 396 | 9725 397 | 4013 398 | 9946 399 | 354 400 | 8367 401 | 7344 402 | 4337 403 | 1136 404 | 8536 405 | 3003 406 | 4171 407 | 7253 408 | 6961 409 | 6751 410 | 7911 411 | 1434 412 | 1780 413 | 2734 414 | 1038 415 | 777 416 | 8427 417 | 8360 418 | 2277 419 | 489 420 | 2510 421 | 5943 422 | 3638 423 | 8340 424 | 3070 425 | 4595 426 | 1643 427 | 2725 428 | 5666 429 | 1994 430 | 3143 431 | 2803 432 | 4110 433 | 2612 434 | 3648 435 | 826 436 | 9005 437 | 5 438 | 7822 439 | 9249 440 | 6891 441 | 8078 442 | 5477 443 | 7738 444 | 4178 445 | 3471 446 | 506 447 | 7717 448 | 4774 449 | 6115 450 | 8239 451 | 1673 452 | 9222 453 | 5461 454 | 2261 455 | 4859 456 | 6783 457 | 1189 458 | 7839 459 | 6073 460 | 3852 461 | 4687 462 | 1691 463 | 8666 464 | 7325 465 | 9549 466 | 5331 467 | 1548 468 | 4005 469 | 190 470 | 348 471 | 7005 472 | 8995 473 | 2416 474 | 7900 475 | 107 476 | 8633 477 | 6819 478 | 2221 479 | 2674 480 | 37 481 | 2746 482 | 342 483 | 5889 484 | 3480 485 | 4729 486 | 5092 487 | 6533 488 | 2868 489 | 4147 490 | 8219 491 | 7783 492 | 2335 493 | 
6902 494 | 9674 495 | 7881 496 | 6969 497 | 4074 498 | 3651 499 | 5732 500 | 3777 501 | 9124 502 | 9061 503 | 4903 504 | 2524 505 | 2867 506 | 9313 507 | 8559 508 | 3353 509 | 977 510 | 6232 511 | 3919 512 | 8159 513 | 3354 514 | 809 515 | 1238 516 | 5553 517 | 7411 518 | 3231 519 | 2409 520 | 850 521 | 1179 522 | 6231 523 | 2258 524 | 2705 525 | 1211 526 | 4216 527 | 4717 528 | 1025 529 | 2753 530 | 8128 531 | 3314 532 | 8862 533 | 3930 534 | 5208 535 | 8113 536 | 3050 537 | 4069 538 | 8268 539 | 7890 540 | 504 541 | 3601 542 | 5506 543 | 9336 544 | 5614 545 | 2984 546 | 5323 547 | 7916 548 | 948 549 | 4212 550 | 2825 551 | 5258 552 | 3255 553 | 4653 554 | 2473 555 | 9017 556 | 7725 557 | 3320 558 | 3211 559 | 5984 560 | 2757 561 | 695 562 | 6634 563 | 9956 564 | 2273 565 | 3900 566 | 7270 567 | 1707 568 | 8887 569 | 8930 570 | 4869 571 | 768 572 | 5689 573 | 3765 574 | 1603 575 | 196 576 | 5457 577 | 9356 578 | 1715 579 | 1922 580 | 8493 581 | 2912 582 | 749 583 | 5113 584 | 2982 585 | 1528 586 | 224 587 | 3649 588 | 663 589 | 984 590 | 2704 591 | 1257 592 | 5014 593 | 2576 594 | 4104 595 | 416 596 | 7247 597 | 3605 598 | 6976 599 | 7323 600 | 4984 601 | 4384 602 | 4296 603 | 9521 604 | 8966 605 | 9209 606 | 1356 607 | 7762 608 | 8943 609 | 3339 610 | 775 611 | 4812 612 | 1049 613 | 8492 614 | 318 615 | 2459 616 | 4250 617 | 624 618 | 1915 619 | 2397 620 | 4402 621 | 3696 622 | 4180 623 | 7012 624 | 8950 625 | 3257 626 | 199 627 | 9803 628 | 7678 629 | 106 630 | 8094 631 | 7420 632 | 6713 633 | 3435 634 | 9018 635 | 1409 636 | 8794 637 | 728 638 | 2353 639 | 4503 640 | 9346 641 | 1656 642 | 3658 643 | 1076 644 | 3785 645 | 8295 646 | 298 647 | 7518 648 | 4100 649 | 8810 650 | 5954 651 | 8721 652 | 6513 653 | 8153 654 | 7994 655 | 4326 656 | 306 657 | 892 658 | 1345 659 | 3775 660 | 4415 661 | 9547 662 | 7469 663 | 7840 664 | 2158 665 | 9082 666 | 6635 667 | 5507 668 | 2098 669 | 1412 670 | 172 671 | 9535 672 | 3823 673 | 5515 674 | 8907 675 | 256 676 | 2904 677 | 
1174 678 | 516 679 | 119 680 | 6311 681 | 6305 682 | 4705 683 | 8253 684 | 3611 685 | 653 686 | 9839 687 | 8394 688 | 2433 689 | 4387 690 | 9716 691 | 7154 692 | 7486 693 | 5931 694 | 7587 695 | 2981 696 | 613 697 | 8704 698 | 1184 699 | 9832 700 | 7074 701 | 7115 702 | 6106 703 | 1585 704 | 2415 705 | 1696 706 | 6132 707 | 7184 708 | 7685 709 | 4482 710 | 6431 711 | 7848 712 | 1817 713 | 2396 714 | 7364 715 | 2432 716 | 3577 717 | 1514 718 | 6279 719 | 4367 720 | 3256 721 | 3337 722 | 5595 723 | 4289 724 | 9446 725 | 4339 726 | 7710 727 | 7739 728 | 9046 729 | 3258 730 | 3239 731 | 8544 732 | 77 733 | 5667 734 | 4427 735 | 4401 736 | 1771 737 | 2848 738 | 6051 739 | 5127 740 | 1802 741 | 6109 742 | 9867 743 | 7575 744 | 3628 745 | 6240 746 | 1957 747 | 9510 748 | 3436 749 | 1871 750 | 659 751 | 3057 752 | 7412 753 | 5895 754 | 980 755 | 5338 756 | 1811 757 | 690 758 | 4723 759 | 8165 760 | 6929 761 | 5436 762 | 2481 763 | 6466 764 | 1002 765 | 6123 766 | 8915 767 | 4920 768 | 3037 769 | 7457 770 | 8838 771 | 6412 772 | 5381 773 | 9847 774 | 5536 775 | 3207 776 | 1506 777 | 8310 778 | 3547 779 | 4248 780 | 8532 781 | 3176 782 | 6645 783 | 5089 784 | 9855 785 | 303 786 | 3021 787 | 4127 788 | 1991 789 | 221 790 | 7022 791 | 38 792 | 5767 793 | 6754 794 | 7151 795 | 9168 796 | 2959 797 | 9131 798 | 1890 799 | 7229 800 | 7930 801 | 2652 802 | 5605 803 | 249 804 | 6084 805 | 147 806 | 6384 807 | 6036 808 | 3384 809 | 8908 810 | 9246 811 | 4445 812 | 2082 813 | 8198 814 | 1932 815 | 3415 816 | 8569 817 | 547 818 | 8673 819 | 9700 820 | 331 821 | 2602 822 | 5267 823 | 6519 824 | 6064 825 | 9215 826 | 9937 827 | 3522 828 | 6534 829 | 7120 830 | 452 831 | 6156 832 | 9500 833 | 6562 834 | 6554 835 | 8139 836 | 9032 837 | 8381 838 | 6773 839 | 1551 840 | 4418 841 | 8168 842 | 6796 843 | 8746 844 | 7785 845 | 2390 846 | 278 847 | 285 848 | 1337 849 | 3375 850 | 8482 851 | 5717 852 | 4274 853 | 5983 854 | 4066 855 | 7564 856 | 3293 857 | 3513 858 | 6035 859 | 5189 860 | 4636 
def compute_gradients(
        device: torch.device,
        n_gpu: int,
        model: torch.nn.Module,
        inputs: Dict[str, torch.Tensor],
        params_filter: Optional[List[str]],
        weight_decay: Optional[float],
        weight_decay_ignores: Optional[List[str]]
) -> List[torch.FloatTensor]:
    """Return the gradient of the batch-mean loss w.r.t. the trainable parameters.

    ``inputs`` must provide ``'pixel_values'`` and ``'labels'``. Both are moved
    to ``device`` and passed to ``model(pixel_values, labels=labels)``; the
    returned object must expose a ``loss`` attribute, which is averaged with
    ``.mean()`` before differentiation.

    ``n_gpu``, ``params_filter``, ``weight_decay`` and ``weight_decay_ignores``
    are accepted but not used by this implementation.

    Returns:
        The tuple produced by ``torch.autograd.grad``: one tensor per parameter
        with ``requires_grad=True``, in ``model.named_parameters()`` order.
    """
    model.zero_grad()

    pixel_values, targets = inputs['pixel_values'], inputs['labels']
    pixel_values, targets = pixel_values.to(device), targets.to(device)

    mean_loss = model(pixel_values, labels=targets).loss.mean()

    trainable = [
        weight for _name, weight in model.named_parameters()
        if weight.requires_grad
    ]
    # create_graph is left at its default: only first-order gradients are needed.
    return torch.autograd.grad(outputs=mean_loss, inputs=trainable)
def compute_hessian_vector_products(
        device: torch.device,
        n_gpu: int,
        model: torch.nn.Module,
        inputs: Dict[str, torch.Tensor],
        vectors: torch.FloatTensor,
        params_filter: Optional[List[str]],
        weight_decay: Optional[float],
        weight_decay_ignores: Optional[List[str]]
) -> List[torch.FloatTensor]:
    """Return the Hessian-vector product of the batch-mean loss with ``vectors``.

    The loss is ``model(pixel_values, labels=labels).loss.mean()`` for the
    batch in ``inputs`` (keys ``'pixel_values'`` and ``'labels'``, both moved
    to ``device``). The product is obtained by differentiating the gradient a
    second time with ``vectors`` supplied as ``grad_outputs``.

    ``vectors`` must hold one tensor per trainable parameter, in
    ``model.named_parameters()`` order. ``n_gpu``, ``params_filter``,
    ``weight_decay`` and ``weight_decay_ignores`` are accepted but not used.

    Returns:
        The tuple produced by ``torch.autograd.grad``, one tensor per
        trainable parameter.
    """
    model.zero_grad()

    pixel_values, targets = inputs['pixel_values'], inputs['labels']
    pixel_values, targets = pixel_values.to(device), targets.to(device)

    mean_loss = model(pixel_values, labels=targets).loss.mean()

    trainable = [
        weight for _name, weight in model.named_parameters()
        if weight.requires_grad
    ]

    torch.cuda.empty_cache()
    # create_graph=True keeps the graph so the gradient itself can be
    # differentiated below.
    first_order = torch.autograd.grad(
        outputs=mean_loss,
        inputs=trainable,
        create_graph=True)

    torch.cuda.empty_cache()
    model.zero_grad()
    return torch.autograd.grad(
        outputs=first_order,
        inputs=trainable,
        grad_outputs=vectors)
it's hv^-1 121 | last_estimate = list(v).copy() 122 | cumulative_num_samples = 0 123 | with tqdm(total=num_samples) as pbar: 124 | for data_loader in train_data_loaders: 125 | for i, inputs in enumerate(data_loader): 126 | this_estimate = compute_hessian_vector_products( 127 | model=model, 128 | n_gpu=n_gpu, 129 | device=device, 130 | vectors=last_estimate, 131 | inputs=inputs, 132 | params_filter=params_filter, 133 | weight_decay=weight_decay, 134 | weight_decay_ignores=weight_decay_ignores) 135 | # Recursively caclulate h_estimate 136 | # https://github.com/dedeswim/pytorch_influence_functions/blob/master/pytorch_influence_functions/influence_functions/hvp_grad.py#L118 137 | with torch.no_grad(): 138 | new_estimate = [ 139 | a + (1 - damp) * b - c / scale 140 | for a, b, c in zip(v, last_estimate, this_estimate) 141 | ] 142 | 143 | pbar.update(1) 144 | if verbose is True: 145 | new_estimate_norm = new_estimate[0].norm().item() 146 | last_estimate_norm = last_estimate[0].norm().item() 147 | estimate_norm_diff = new_estimate_norm - last_estimate_norm 148 | pbar.set_description(f"{new_estimate_norm:.2f} | {estimate_norm_diff:.2f}") 149 | 150 | cumulative_num_samples += 1 151 | last_estimate = new_estimate 152 | torch.cuda.empty_cache() 153 | if num_samples is not None and i > num_samples: 154 | break 155 | 156 | # References: 157 | # https://github.com/kohpangwei/influence-release/blob/master/influence/genericNeuralNet.py#L475 158 | # Do this for each iteration of estimation 159 | # Since we use one estimation, we put this at the end 160 | inverse_hvp = [X / scale for X in last_estimate] 161 | 162 | # Sanity check 163 | # Note that in parallel settings, we should have `num_samples` 164 | # whereas in sequential settings we would have `num_samples + 2`. 165 | # This is caused by some loose stop condition. In parallel settings, 166 | # We only allocate `num_samples` data to reduce communication overhead. 167 | # Should probably make this more consistent sometime. 
def compute_influences(
        n_gpu: int,
        device: torch.device,
        model: torch.nn.Module,
        test_inputs: Dict[str, torch.Tensor],
        batch_train_data_loader: torch.utils.data.DataLoader,
        instance_train_data_loader: torch.utils.data.DataLoader,
        params_filter: Optional[List[str]] = None,
        weight_decay: Optional[float] = None,
        weight_decay_ignores: Optional[List[str]] = None,
        s_test_damp: float = 3e-5,
        s_test_scale: float = 1e4,
        s_test_num_samples: Optional[int] = None,
        s_test_iterations: int = 1,
        precomputed_s_test: Optional[List[torch.FloatTensor]] = None,
        precomputed_grad_zjs: Optional[Any] = None,
        train_indices_to_include: Optional[Union[np.ndarray, List[int]]] = None,
) -> Dict[int, float]:
    """Score each training instance by its influence on ``test_inputs``.

    For every item of ``instance_train_data_loader`` the score is
    ``-sum_k <grad_zj[k], s_test[k]>`` over all parameter tensors ``k``.

    Args:
        n_gpu, device, model: forwarded to ``compute_s_test`` and
            ``compute_gradients``.
        test_inputs: batch used to compute ``s_test`` when
            ``precomputed_s_test`` is not given.
        batch_train_data_loader: loader handed to ``compute_s_test``.
        instance_train_data_loader: loader whose items carry an ``'idx'``
            tensor holding a single value; that value is the key of the
            returned dict.
        params_filter, weight_decay, weight_decay_ignores: forwarded to the
            helpers. ``weight_decay_ignores`` defaults to
            ``["bias", "LayerNorm.weight"]``.
        s_test_damp, s_test_scale, s_test_num_samples: forwarded to
            ``compute_s_test``.
        s_test_iterations: number of ``compute_s_test`` runs to average.
        precomputed_s_test: if given, used as ``s_test`` as-is.
        precomputed_grad_zjs: either a path to a ``torch.save``d object or an
            object indexable by the instance ``idx``; each entry is the
            gradient for that training instance.
        train_indices_to_include: positions (enumeration order of
            ``instance_train_data_loader``) to score. Only applied when
            gradients are computed on the fly.
            NOTE(review): the filter is skipped when ``precomputed_grad_zjs``
            is given, as in the previous implementation - confirm intended.

    Returns:
        Dict mapping each instance ``idx`` to its influence score.

    Raises:
        ValueError: if ``s_test_iterations`` is smaller than 1.
        FileNotFoundError: if ``precomputed_grad_zjs`` is a path that does not
            exist (previously this surfaced as an unbound ``grad_zj``).
    """
    if s_test_iterations < 1:
        raise ValueError("`s_test_iterations` must >= 1")

    if weight_decay_ignores is None:
        # https://github.com/huggingface/transformers/blob/v3.0.2/src/transformers/trainer.py#L325
        weight_decay_ignores = [
            "bias",
            "LayerNorm.weight"]

    if precomputed_s_test is not None:
        s_test = precomputed_s_test
    else:
        s_test = None
        for _ in range(s_test_iterations):
            _s_test = compute_s_test(
                n_gpu=n_gpu,
                device=device,
                model=model,
                test_inputs=test_inputs,
                train_data_loaders=[batch_train_data_loader],
                params_filter=params_filter,
                weight_decay=weight_decay,
                weight_decay_ignores=weight_decay_ignores,
                damp=s_test_damp,
                scale=s_test_scale,
                num_samples=s_test_num_samples)

            # Sum the values across runs
            if s_test is None:
                s_test = _s_test
            else:
                s_test = [
                    a + b for a, b in zip(s_test, _s_test)
                ]
        # Do the averaging
        s_test = [a / s_test_iterations for a in s_test]
        torch.cuda.empty_cache()

    # Resolve the precomputed gradients once, instead of re-reading the file
    # from disk for every single training instance.
    grad_lookup = precomputed_grad_zjs
    if isinstance(precomputed_grad_zjs, (str, os.PathLike)):
        if not os.path.exists(precomputed_grad_zjs):
            raise FileNotFoundError(
                f"precomputed_grad_zjs path does not exist: {precomputed_grad_zjs}")
        grad_lookup = torch.load(precomputed_grad_zjs)

    # A set gives O(1) membership tests; list/ndarray lookups are O(m) each.
    included = None
    if train_indices_to_include is not None:
        included = set(np.asarray(train_indices_to_include).ravel().tolist())

    influences = {}
    for index, train_inputs in enumerate(tqdm(instance_train_data_loader)):
        if grad_lookup is not None:
            grad_zj = grad_lookup[train_inputs['idx'].item()]
        else:
            # Skip indices when a subset is specified to be included
            if included is not None and index not in included:
                continue

            grad_zj = compute_gradients(
                n_gpu=n_gpu,
                device=device,
                model=model,
                inputs=train_inputs,
                params_filter=params_filter,
                weight_decay=weight_decay,
                weight_decay_ignores=weight_decay_ignores)

        with torch.no_grad():
            # Stored gradients may live on the CPU (compute_grad_zs calls
            # .cpu() on them) while s_test is on `device`; align them first.
            influence = [
                - torch.sum(x.to(y.device) * y)
                for x, y in zip(grad_zj, s_test)]

        influences[train_inputs['idx'].item()] = sum(influence).item()
        torch.cuda.empty_cache()

    return influences
class NetBWThree(nn.Module):
    """CNN with three conv layers and two linear layers for 1-channel images.

    Each convolution is followed by GELU and 2x2 max pooling. The flattened
    feature size is fixed to ``128 * 4 * 4``, which is what a 28x28 input
    produces after the three conv/pool stages. The network emits 10 logits.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as before so that parameter
        # initialisation and state_dict ordering stay the same.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=2)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=2)
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(in_features=128*4*4, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=10)

    def forward(self, x, output_hidden_states=False):
        """Return the logits, or ``(logits, hidden)`` if ``output_hidden_states``.

        ``hidden`` is the GELU-activated output of ``fc1`` (256 features).
        """
        activation = nn.functional.gelu
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(activation(conv(x)))

        flat = x.view(-1, 128*4*4)
        hidden = activation(self.fc1(flat))
        logits = self.fc2(hidden)

        return (logits, hidden) if output_hidden_states else logits
def test_model(model, test_loader, criterion):
    """Evaluate a CNN classifier and return its loss, accuracy and predictions.

    The model is switched to eval mode and run without gradient tracking.

    Args:
        model: callable mapping a batch of inputs to class logits with the
            class dimension at axis 1.
        test_loader: iterable of ``(inputs, labels, indices)`` batches; the
            indices are ignored.
        criterion: loss callable ``criterion(outputs, labels)``. It may return
            per-sample losses or a single scalar per batch.

    Returns:
        ``(test_loss, accuracy, predictions)``. ``test_loss`` is the
        concatenation of the per-batch loss arrays, ``accuracy`` is a
        percentage in ``[0, 100]`` and ``predictions`` holds the predicted
        class per sample. If the loader yields nothing, both arrays are empty
        and ``accuracy`` is ``nan``.
    """
    model.eval()
    correct = 0
    total = 0
    test_loss = []
    predictions = []

    with torch.no_grad():
        for inputs, labels, _ in test_loader:
            outputs = model(inputs)
            # argmax over the logits equals argmax over their softmax, and
            # avoids the deprecated implicit-dim softmax call.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            predictions.append(predicted.cpu().numpy())
            correct += (predicted == labels).sum().item()
            # atleast_1d lets a scalar (batch-reduced) loss be concatenated
            # too; per-sample loss vectors pass through untouched.
            batch_loss = criterion(outputs, labels).cpu().numpy()
            test_loss.append(np.atleast_1d(batch_loss))

    if total == 0:
        # Nothing was evaluated: accuracy is undefined rather than a
        # ZeroDivisionError.
        return np.array([]), float('nan'), np.array([], dtype=np.int64)

    accuracy = 100 * correct / total

    # Concatenate the per-batch values into single numpy arrays
    predictions = np.concatenate(predictions)
    test_loss = np.concatenate(test_loss)

    return test_loss, accuracy, predictions


# The name starts with "test_", so mark it as a helper to keep pytest from
# collecting it as a test case when this module is imported into a test file.
test_model.__test__ = False
def test_vit(data_loader, device, model):
    """Run the ViT model over ``data_loader`` and return accuracy, loss and predictions.

    Args:
        data_loader: iterable of batches providing ``'pixel_values'`` and
            ``'labels'`` tensors.
        device: device the batch tensors are moved to before the forward pass.
        model: callable ``model(inputs, labels=labels)`` whose result exposes
            ``logits`` and ``loss``. The loss may be per-sample or a scalar.

    Returns:
        ``(accuracy, loss, preds)``: ``accuracy`` is the fraction of correct
        predictions in ``[0, 1]`` (``nan`` if no sample was seen), ``loss`` is
        a flat list of the loss values gathered from each batch and ``preds``
        is the list of predicted class ids.

    NOTE(review): the model's train/eval mode is left as the caller set it -
    confirm callers switch to eval mode when dropout must be disabled.
    """
    loss = []
    preds = []
    num_correct = 0
    total = 0
    for batch in tqdm(data_loader):
        inputs = batch['pixel_values'].to(device)
        labels = batch['labels'].to(device)
        with torch.no_grad():
            outputs = model(inputs, labels=labels)
        pred = outputs.logits.argmax(dim=-1)
        num_correct += (pred == labels).sum().item()
        total += len(labels)

        batch_loss = outputs.loss.cpu().numpy()
        if batch_loss.ndim == 0:
            # A scalar (batch-reduced) loss cannot be iterated; keep it as a
            # single entry instead of raising TypeError.
            batch_loss = batch_loss.reshape(1)
        loss.extend(list(batch_loss))
        preds.extend(list(pred.cpu().numpy()))

    # Accuracy over zero samples is undefined; avoid a ZeroDivisionError.
    accuracy = num_correct * 1. / total if total else float('nan')
    return accuracy, loss, preds


# The name starts with "test_", so mark it as a helper to keep pytest from
# collecting it as a test case when this module is imported into a test file.
test_vit.__test__ = False
def load_subset_indices(idx_filepath):
    """Read the integer indices stored one per line in ``idx_filepath``.

    Surrounding whitespace is stripped and blank lines are skipped, so a
    trailing empty line does not abort the load.

    Returns:
        List of ints in file order.

    Raises:
        ValueError: if a non-blank line is not a valid integer.
    """
    with open(idx_filepath, 'r') as f:
        # Iterate the file lazily; only keep lines that carry content.
        return [int(line) for line in map(str.strip, f) if line]
def get_mu_and_sigma(tau: str,
                     num_ckpts: int,
                     experiment: str,
                     test_idx: List[int]):
    """Return per-test-point mean and standard deviation of the attributions.

    The checkpoint-averaged attributions of type ``tau`` are loaded with
    ``load_expected_tda_swa`` and regrouped by ``get_expected_tau_per_z`` into
    one frame per test point (one column per seed). The mean and standard
    deviation are then taken across those columns.

    Returns:
        ``(all_means, all_stds)``: two DataFrames with one column per entry
        of ``test_idx``, in the given order.
    """
    per_seed = load_expected_tda_swa(num_ckpts=num_ckpts,
                                     experiment=experiment,
                                     tau=tau)
    per_test_point = get_expected_tau_per_z(expected_attributions=per_seed,
                                            test_idx=test_idx)

    all_means, all_stds = pd.DataFrame(), pd.DataFrame()
    for z_test_idx in test_idx:
        frame = per_test_point[f'z_test_{z_test_idx}']
        # axis=1 aggregates across the columns of the frame.
        all_means[z_test_idx] = frame.mean(axis=1)
        all_stds[z_test_idx] = frame.std(axis=1)

    return all_means, all_stds
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause

import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from utils import expected_gradient, expected_grad_grad
from typing import Dict, List, Union, Optional, Tuple, Iterator, Any


def count_parameters(model: torch.nn.Module) -> int:
    """Return the total number of scalar elements over all parameters of `model`."""
    return sum(p.numel() for p in model.parameters())


def _select_parameters(
        model: torch.nn.Module,
        params_filter: Optional[List[str]]) -> List[torch.nn.Parameter]:
    """Return the parameters of `model` whose names are not in `params_filter`."""
    excluded = [] if params_filter is None else params_filter
    return [
        param for name, param
        in model.named_parameters()
        if name not in excluded]


def _check_cumulative_num_samples(
        cumulative_num_samples: int,
        num_samples: Optional[int]) -> None:
    """Raise `ValueError` if the processed-sample count is unexpected.

    The estimation loops stop on `i > num_samples`, which lets
    `num_samples + 2` iterations through, so both `num_samples` and
    `num_samples + 2` are accepted. When `num_samples` is `None` no limit
    was requested and there is nothing to compare against, so the check is
    skipped (previously this case crashed on `None + 2`).
    """
    if num_samples is None:
        return

    if cumulative_num_samples not in (num_samples, num_samples + 2):
        raise ValueError(f"cumulative_num_samples={cumulative_num_samples} "
                         f"but num_samples={num_samples}: Untested Territory")


def get_loss_with_weight_decay(
        device: torch.device,
        n_gpu: int,
        model: torch.nn.Module,
        inputs: Dict[str, torch.Tensor],
        weight_decay: Optional[float],
        weight_decay_ignores: Optional[List[str]]) -> torch.Tensor:
    """Compute the cross-entropy loss of `model` on `inputs`, plus weight decay.

    Args:
        device: Unused here; kept so all helpers share the same signature.
        n_gpu: When greater than 1 the loss is reduced with `.mean()`.
        model: Module called on the first item of `inputs`.
        inputs: Despite the annotation, this is unpacked as a 3-item sequence
            `(model_input, label, _)`; the third item is ignored here.
        weight_decay: Coefficient of the squared-parameter penalty. `None`
            disables the penalty.
        weight_decay_ignores: Substrings; any parameter whose name contains
            one of them is left out of the penalty. `None` means no exclusions.

    Returns:
        The loss tensor (still attached to the autograd graph).
    """
    criterion = nn.CrossEntropyLoss()
    features, label, _ = inputs
    # A bare Python label is wrapped so the criterion receives a tensor.
    if not isinstance(label, torch.Tensor):
        label = torch.tensor([label])
    outputs = model(features)
    loss = criterion(outputs, label)

    if n_gpu > 1:
        # mean() to average on multi-gpu parallel training
        loss = loss.mean()

    # In PyTorch, weight-decay loss and gradients are calculated in
    # optimizers rather than in nn.Module, so we have to manually add
    # the penalty to the loss here.
    if weight_decay is not None:
        no_decay = weight_decay_ignores if weight_decay_ignores is not None else []

        weight_decay_loss = torch.cat([
            p.square().view(-1)
            for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ]).sum() * weight_decay
        loss = loss + weight_decay_loss

    return loss


def compute_gradients(
        device: torch.device,
        n_gpu: int,
        model: torch.nn.Module,
        inputs: Dict[str, torch.Tensor],
        params_filter: Optional[List[str]],
        weight_decay: Optional[float],
        weight_decay_ignores: Optional[List[str]]
) -> Tuple[torch.Tensor, ...]:
    """Return the gradients of the (weight-decayed) loss w.r.t. the model parameters.

    Parameters whose exact name appears in `params_filter` are excluded.
    The gradients are computed with `create_graph=True`, so they remain
    differentiable.
    """
    model.zero_grad()
    loss = get_loss_with_weight_decay(
        device=device, n_gpu=n_gpu,
        model=model, inputs=inputs,
        weight_decay=weight_decay,
        weight_decay_ignores=weight_decay_ignores)

    return torch.autograd.grad(
        outputs=loss,
        inputs=_select_parameters(model, params_filter),
        create_graph=True)


def compute_hessian_vector_products(
        device: torch.device,
        n_gpu: int,
        model: torch.nn.Module,
        inputs: Dict[str, torch.Tensor],
        vectors: torch.FloatTensor,
        params_filter: Optional[List[str]],
        weight_decay: Optional[float],
        weight_decay_ignores: Optional[List[str]]
) -> Tuple[torch.Tensor, ...]:
    """Differentiate the loss gradients again, weighted by `vectors`.

    The first `torch.autograd.grad` call produces differentiable gradients
    of the loss on `inputs`; the second differentiates those gradients with
    `grad_outputs=vectors`, one entry per selected parameter.
    """
    model.zero_grad()
    loss = get_loss_with_weight_decay(
        model=model, n_gpu=n_gpu,
        device=device, inputs=inputs,
        weight_decay=weight_decay,
        weight_decay_ignores=weight_decay_ignores)

    grad_tuple = torch.autograd.grad(
        outputs=loss,
        inputs=_select_parameters(model, params_filter),
        create_graph=True)

    model.zero_grad()
    grad_grad_tuple = torch.autograd.grad(
        outputs=grad_tuple,
        inputs=_select_parameters(model, params_filter),
        grad_outputs=vectors,
        only_inputs=True
    )

    return grad_grad_tuple


def compute_s_test(
        n_gpu: int,
        device: torch.device,
        model: torch.nn.Module,
        test_inputs: Dict[str, torch.Tensor],
        train_data_loaders: List[torch.utils.data.DataLoader],
        params_filter: Optional[List[str]],
        weight_decay: Optional[float],
        weight_decay_ignores: Optional[List[str]],
        damp: float,
        scale: float,
        num_samples: Optional[int] = None,
        verbose: bool = True,
) -> List[torch.FloatTensor]:
    """Recursively estimate s_test from the test-loss gradient.

    Starting from `v`, the gradient of the loss on `test_inputs`, every
    training batch updates the estimate as
    `v + (1 - damp) * last_estimate - hvp / scale`, where `hvp` comes from
    `compute_hessian_vector_products` with `last_estimate` as vectors. The
    final estimate is divided by `scale`.

    Args:
        num_samples: Loose limit on the batches used per loader; the loop
            breaks once the batch index exceeds it. `None` consumes each
            loader completely.
        verbose: Show the norm of the first tensor of the estimate, and its
            change, in the progress bar.

    Raises:
        ValueError: If `num_samples` is given and the number of processed
            batches is neither `num_samples` nor `num_samples + 2`.
    """
    v = compute_gradients(
        model=model,
        n_gpu=n_gpu,
        device=device,
        inputs=test_inputs,
        params_filter=params_filter,
        weight_decay=weight_decay,
        weight_decay_ignores=weight_decay_ignores)

    # Technically, it's hv^-1
    last_estimate = list(v)
    cumulative_num_samples = 0
    with tqdm(total=num_samples) as pbar:
        for data_loader in train_data_loaders:
            for i, inputs in enumerate(data_loader):
                this_estimate = compute_hessian_vector_products(
                    model=model,
                    n_gpu=n_gpu,
                    device=device,
                    vectors=last_estimate,
                    inputs=inputs,
                    params_filter=params_filter,
                    weight_decay=weight_decay,
                    weight_decay_ignores=weight_decay_ignores)
                # Recursively calculate h_estimate
                # https://github.com/dedeswim/pytorch_influence_functions/blob/master/pytorch_influence_functions/influence_functions/hvp_grad.py#L118
                with torch.no_grad():
                    new_estimate = [
                        a + (1 - damp) * b - c / scale
                        for a, b, c in zip(v, last_estimate, this_estimate)
                    ]

                pbar.update(1)
                if verbose:
                    new_estimate_norm = new_estimate[0].norm().item()
                    last_estimate_norm = last_estimate[0].norm().item()
                    estimate_norm_diff = new_estimate_norm - last_estimate_norm
                    pbar.set_description(f"{new_estimate_norm:.2f} | {estimate_norm_diff:.2f}")

                cumulative_num_samples += 1
                last_estimate = new_estimate
                # Loose stop condition: lets `num_samples + 2` batches through.
                if num_samples is not None and i > num_samples:
                    break

    # References:
    # https://github.com/kohpangwei/influence-release/blob/master/influence/genericNeuralNet.py#L475
    # Do this for each iteration of estimation
    # Since we use one estimation, we put this at the end
    inverse_hvp = [X / scale for X in last_estimate]

    # Sanity check
    # Note that in parallel settings, we should have `num_samples`
    # whereas in sequential settings we would have `num_samples + 2`.
    # This is caused by some loose stop condition. In parallel settings,
    # We only allocate `num_samples` data to reduce communication overhead.
    # Should probably make this more consistent sometime.
    _check_cumulative_num_samples(cumulative_num_samples, num_samples)

    return inverse_hvp


def compute_grad_zs(
        n_gpu: int,
        device: torch.device,
        model: torch.nn.Module,
        data_loader: torch.utils.data.DataLoader,
        params_filter: Optional[List[str]] = None,
        weight_decay: Optional[float] = None,
        weight_decay_ignores: Optional[List[str]] = None,
) -> List[List[torch.FloatTensor]]:
    """Return, for every batch of `data_loader`, the loss gradients moved to CPU.

    When `weight_decay_ignores` is `None`, parameters whose names contain
    "bias" or "LayerNorm.weight" are excluded from the weight-decay term.
    """
    if weight_decay_ignores is None:
        weight_decay_ignores = [
            "bias",
            "LayerNorm.weight"]

    grad_zs = []
    for inputs in data_loader:
        grad_z = compute_gradients(
            n_gpu=n_gpu, device=device,
            model=model, inputs=inputs,
            params_filter=params_filter,
            weight_decay=weight_decay,
            weight_decay_ignores=weight_decay_ignores)
        with torch.no_grad():
            grad_zs.append([X.cpu() for X in grad_z])

    return grad_zs


def compute_influences(
        n_gpu: int,
        device: torch.device,
        model: torch.nn.Module,
        test_inputs: Dict[str, torch.Tensor],
        batch_train_data_loader: torch.utils.data.DataLoader,
        instance_train_data_loader: torch.utils.data.DataLoader,
        params_filter: Optional[List[str]] = None,
        weight_decay: Optional[float] = None,
        weight_decay_ignores: Optional[List[str]] = None,
        s_test_damp: float = 3e-5,
        s_test_scale: float = 1e4,
        s_test_num_samples: Optional[int] = None,
        s_test_iterations: int = 1,
        precomputed_s_test: Optional[List[torch.FloatTensor]] = None,
        train_indices_to_include: Optional[Union[np.ndarray, List[int]]] = None,
) -> Dict[Any, float]:
    """Compute the influence of each training instance on `test_inputs`.

    s_test is either `precomputed_s_test` or the average of
    `s_test_iterations` runs of `compute_s_test`. For every item of
    `instance_train_data_loader` the influence is the negated sum, over
    parameters, of `grad_z * s_test`.

    Args:
        train_indices_to_include: When given, instances whose index (third
            item of the batch) is not contained in it are skipped.

    Returns:
        A dict mapping the third item of each processed batch to its
        influence score.

    Raises:
        ValueError: If `s_test_iterations < 1`.
    """
    if s_test_iterations < 1:
        raise ValueError("`s_test_iterations` must >= 1")

    if weight_decay_ignores is None:
        # https://github.com/huggingface/transformers/blob/v3.0.2/src/transformers/trainer.py#L325
        weight_decay_ignores = [
            "bias",
            "LayerNorm.weight"]

    if precomputed_s_test is not None:
        s_test = precomputed_s_test
    else:
        s_test = None
        for _ in range(s_test_iterations):
            _s_test = compute_s_test(
                n_gpu=n_gpu,
                device=device,
                model=model,
                test_inputs=test_inputs,
                train_data_loaders=[batch_train_data_loader],
                params_filter=params_filter,
                weight_decay=weight_decay,
                weight_decay_ignores=weight_decay_ignores,
                damp=s_test_damp,
                scale=s_test_scale,
                num_samples=s_test_num_samples)

            # Sum the values across runs
            if s_test is None:
                s_test = _s_test
            else:
                s_test = [
                    a + b for a, b in zip(s_test, _s_test)
                ]
        # Do the averaging
        s_test = [a / s_test_iterations for a in s_test]

    influences = {}
    for train_inputs in tqdm(instance_train_data_loader):
        # NOTE(review): the key is used exactly as yielded by the loader;
        # `compute_influences_from_expected_grads` calls `.item()` on it
        # instead. Left unchanged because callers may rely on the current
        # key type — confirm before unifying.
        index = train_inputs[2]
        # Skip indices when a subset is specified to be included
        if (train_indices_to_include is not None) and (
                index not in train_indices_to_include):
            continue

        grad_z = compute_gradients(
            n_gpu=n_gpu,
            device=device,
            model=model,
            inputs=train_inputs,
            params_filter=params_filter,
            weight_decay=weight_decay,
            weight_decay_ignores=weight_decay_ignores)

        with torch.no_grad():
            influence = [
                - torch.sum(x * y)
                for x, y in zip(grad_z, s_test)]

        influences[index] = sum(influence).item()

    return influences


def compute_s_test_from_expected_grads(
        n_gpu: int,
        device: torch.device,
        num_ckpts: int,
        test_inputs: Dict[str, torch.Tensor],
        train_data_loaders: List[torch.utils.data.DataLoader],
        params_filter: Optional[List[str]],
        weight_decay: Optional[float],
        weight_decay_ignores: Optional[List[str]],
        damp: float,
        scale: float,
        num_samples: Optional[int] = None,
        verbose: bool = True,
) -> Tuple[List[torch.FloatTensor], List[float]]:
    """Estimate s_test using `expected_gradient` / `expected_grad_grad`.

    Mirrors `compute_s_test`, but the test gradient comes from
    `expected_gradient(num_ckpts, test_inputs)` and the per-batch term from
    `expected_grad_grad(num_ckpts, inputs)`. `n_gpu`, `device`,
    `params_filter`, `weight_decay` and `weight_decay_ignores` are accepted
    for signature parity but not used here.

    Returns:
        `(inverse_hvp, estimate_hist)`: the final estimate divided by
        `scale`, and the norm of the first tensor of every intermediate
        estimate.

    Raises:
        ValueError: If `num_samples` is given and the number of processed
            samples is neither `num_samples` nor `num_samples + 2`.
    """
    v = expected_gradient(num_ckpts=num_ckpts,
                          input=test_inputs)

    # for saving the history of HVP estimation to see if it converges
    estimate_hist = []
    # Technically, it's hv^-1
    last_estimate = list(v)
    cumulative_num_samples = 0
    with tqdm(total=num_samples) as pbar:
        for data_loader in train_data_loaders:
            for i, inputs in enumerate(data_loader):
                this_estimate = expected_grad_grad(
                    num_ckpts=num_ckpts,
                    input=inputs)
                # Recursively calculate h_estimate
                # https://github.com/dedeswim/pytorch_influence_functions/blob/master/pytorch_influence_functions/influence_functions/hvp_grad.py#L118
                with torch.no_grad():
                    batch_size = this_estimate.shape[0]
                    # NOTE(review): `last_estimate` is only advanced once per
                    # batch (below), so every per-sample update starts from
                    # the same `last_estimate` and only the final sample's
                    # `new_estimate` is carried forward — confirm intended.
                    for idx in range(batch_size):
                        new_estimate = [
                            a + (1 - damp) * b - c / scale
                            for a, b, c in zip(v, last_estimate, this_estimate[idx])
                        ]
                        estimate_hist.append(new_estimate[0].norm().item())

                pbar.update(batch_size)
                if verbose:
                    new_estimate_norm = new_estimate[0].norm().item()
                    last_estimate_norm = last_estimate[0].norm().item()
                    estimate_norm_diff = new_estimate_norm - last_estimate_norm
                    pbar.set_description(f"{new_estimate_norm:.2f} | {estimate_norm_diff:.2f}")

                cumulative_num_samples += batch_size
                last_estimate = new_estimate
                # Loose stop condition: lets `num_samples + 2` batches through.
                if num_samples is not None and i > num_samples:
                    break

    # References:
    # https://github.com/kohpangwei/influence-release/blob/master/influence/genericNeuralNet.py#L475
    # Do this for each iteration of estimation
    # Since we use one estimation, we put this at the end
    inverse_hvp = [X / scale for X in last_estimate]

    # Sanity check
    # Note that in parallel settings, we should have `num_samples`
    # whereas in sequential settings we would have `num_samples + 2`.
    # This is caused by some loose stop condition. In parallel settings,
    # We only allocate `num_samples` data to reduce communication overhead.
    # Should probably make this more consistent sometime.
    _check_cumulative_num_samples(cumulative_num_samples, num_samples)

    return inverse_hvp, estimate_hist


def compute_influences_from_expected_grads(
        n_gpu: int,
        device: torch.device,
        num_ckpts: int,
        test_inputs: Dict[str, torch.Tensor],
        batch_train_data_loader: torch.utils.data.DataLoader,
        instance_train_data_loader: torch.utils.data.DataLoader,
        params_filter: Optional[List[str]] = None,
        weight_decay: Optional[float] = None,
        weight_decay_ignores: Optional[List[str]] = None,
        s_test_damp: float = 3e-5,
        s_test_scale: float = 1e4,
        s_test_num_samples: Optional[int] = None,
        s_test_iterations: int = 1,
        precomputed_s_test: Optional[List[torch.FloatTensor]] = None,
        train_indices_to_include: Optional[Union[np.ndarray, List[int]]] = None,
) -> Dict[int, float]:
    """Compute influences on `test_inputs` from expected gradients.

    s_test is either `precomputed_s_test` or the average of
    `s_test_iterations` runs of `compute_s_test_from_expected_grads`. For
    every item of `instance_train_data_loader` the influence is the negated
    sum, over parameters, of `expected_gradient(...) * s_test`.

    Args:
        train_indices_to_include: When given, instances whose index (third
            item of the batch, converted with `.item()`) is not contained
            in it are skipped.

    Returns:
        A dict mapping each processed instance index to its influence score.

    Raises:
        ValueError: If `s_test_iterations < 1`.
    """
    if s_test_iterations < 1:
        raise ValueError("`s_test_iterations` must >= 1")

    if weight_decay_ignores is None:
        # https://github.com/huggingface/transformers/blob/v3.0.2/src/transformers/trainer.py#L325
        weight_decay_ignores = [
            "bias",
            "LayerNorm.weight"]

    if precomputed_s_test is not None:
        s_test = precomputed_s_test
    else:
        s_test = None
        for _ in range(s_test_iterations):
            # The helper returns `(inverse_hvp, estimate_hist)`; only the
            # estimate itself takes part in the averaging.
            _s_test, _ = compute_s_test_from_expected_grads(
                n_gpu=n_gpu,
                device=device,
                num_ckpts=num_ckpts,
                test_inputs=test_inputs,
                train_data_loaders=[batch_train_data_loader],
                params_filter=params_filter,
                weight_decay=weight_decay,
                weight_decay_ignores=weight_decay_ignores,
                damp=s_test_damp,
                scale=s_test_scale,
                num_samples=s_test_num_samples)

            # Sum the values across runs
            if s_test is None:
                s_test = _s_test
            else:
                s_test = [
                    a + b for a, b in zip(s_test, _s_test)
                ]
        # Do the averaging
        s_test = [a / s_test_iterations for a in s_test]

    influences = {}
    for train_inputs in tqdm(instance_train_data_loader):
        index = train_inputs[2].item()
        # Skip indices when a subset is specified to be included
        if (train_indices_to_include is not None) and (
                index not in train_indices_to_include):
            continue

        grad_z = expected_gradient(num_ckpts=num_ckpts,
                                   input=train_inputs)

        with torch.no_grad():
            influence = [
                - torch.sum(x * y)
                for x, y in zip(grad_z, s_test)]

        influences[index] = sum(influence).item()

    return influences