├── README.md ├── feature_viz ├── direct_feature_visualization.ipynb └── images │ ├── boat1.jpg │ ├── dog1.jpg │ └── f35.jpg ├── influence_functions ├── hessian.py ├── invert_hessian.py ├── make_grads.py ├── make_influences.py └── visualizing influence functions.ipynb ├── results ├── full_data.csv ├── mass_extractor.py ├── plots │ ├── constraint.pdf │ ├── ft_blocks_8.pdf │ ├── less_epochs.pdf │ └── less_images.pdf ├── tables │ └── summary_table.csv └── visualize_results.ipynb ├── test.py ├── tools ├── batch.py ├── constants.py ├── custom_datasets.py ├── delete_big_files.py ├── helpers.py └── transforms.py └── train.py /README.md: -------------------------------------------------------------------------------- 1 | # Adversarially-Trained Deep Nets Transfer Better 2 | 3 | This GitHub repo contains the code used for the "Adversarially-Trained Deep Nets Transfer Better" paper 4 | 5 | ## How to use 6 | 7 | 1. Download all source models into models directory and save with the appropriate name: 8 | - https://www.dropbox.com/s/knf4uimlqsi1yz8/imagenet_l2_3_0.pt?dl=0; save as 'imagenet_l2_3_0.pt' 9 | - https://www.dropbox.com/s/axfuary2w1cnyrg/imagenet_linf_4.pt?dl=0; save as 'imagenet_linf_4.pt' 10 | - https://www.dropbox.com/s/yxn15a9zklz3s8q/imagenet_linf_8.pt?dl=0; save as 'imagenet_linf_8.pt' 11 | 12 | 2. Install all dependencies: 13 | 14 | - pip install robustness 15 | - conda install pytorch==1.2.0 torchvision==0.4.0 cudatoolkit=10.0 -c pytorch 16 | 17 | 3. If you want to train a single model, run the train.py file. For more info on the run-time inputs, run "python train.py --help". 18 | 19 | 4. If you want to replicate the entire experiment run the tools/batch.py file. Keep in mind that this might take a considerable amount of time since we fine-tune over 14 thousand models. 20 | 21 | 5. Find the logs including the validation accuracy in the results/logs folder. Use the results/mass_extractor.py file to extract all your logs into a nice csv format. 
# influence_functions/hessian.py
# Build the Hessian of the final linear layer's loss over a random subsample of
# the target training set, then save both it and its pseudo-inverse, for
# influence-function analysis.
from robustness import train
from robustness.model_utils import make_and_restore_model
from robustness.datasets import ImageNet, CIFAR, CINIC
from torch import autograd
from torchvision import datasets
import torch as ch
import numpy as np
import argparse
import random
from scipy import linalg
import os
# note: currently only supported for CIFAR10 and epsilon=[0,3]
import sys
sys.path.insert(1, '..')
from tools.helpers import flatten_grad, make_mask, set_seeds, eval_hessian
from tools.helpers import get_runtime_inputs_for_influence_functions
import tools.custom_datasets as custom_datasets

args = get_runtime_inputs_for_influence_functions()
eps = args.e          # source-model training epsilon (0 = naturally trained)
num_images = args.n   # size of the training subsample
batch_size = args.b
seed = args.s
ub = args.ub          # number of unfrozen blocks during fine-tuning
ds = args.ds          # target dataset name ('cifar10' or 'svhn')

# LOAD MODEL: the fine-tuned checkpoint whose final layer we analyze.
eps_to_model = {3: f'l2_{eps}_imagenet_to_{ds}_{ub}_ub_{num_images}_images.pt',
                0: f'nat_imagenet_to_{ds}_{ub}_ub_{num_images}_images.pt'}
source_model_path = '../models/' + eps_to_model[eps]

model, _ = make_and_restore_model(arch='resnet50_m', dataset=ImageNet('/tmp'),
                                  resume_path=source_model_path, parallel=False)
model.eval()
criterion = ch.nn.CrossEntropyLoss()

# MAKE DATASET (224x224 so the ImageNet-pretrained backbone sees its native size)
size = (224, 224)
if ds == 'cifar10':
    data_set = datasets.CIFAR10(root='/tmp', train=True, download=True,
                                transform=custom_datasets.TEST_TRANSFORMS_DEFAULT(size))
elif ds == 'svhn':
    data_set = datasets.SVHN(root='/tmp', split='train', download=True,
                             transform=custom_datasets.TEST_TRANSFORMS_DEFAULT(size))

# SET THE SEEDS
# BUGFIX: previously hard-coded to {'seed': 1000000}, silently ignoring the -s
# argument (make_grads.py honors it).  The default of -s is 1000000, so this
# change is backward compatible at the defaults.
set_seeds({'seed': seed})

# MAKE THE MASK (random subsample of num_images training points)
mask_sampler = make_mask(num_images, data_set)

# MAKE THE LOADER
num_workers = 10
loader = ch.utils.data.DataLoader(data_set, sampler=mask_sampler, batch_size=batch_size,
                                  shuffle=False, num_workers=num_workers, pin_memory=True)

# MAKE THE HESSIAN: sum of per-batch Hessians of the loss w.r.t. the final
# linear layer's parameters.
H = None
for i, data in enumerate(loader):
    print((i + 1) * batch_size)  # progress: images processed so far (was hard-coded *100)
    image, label = data
    output, final_inp = model(image.cuda())
    output = output.cpu()
    loss = criterion(output, label).cpu().double()
    loss_grad = autograd.grad(loss, model.model.fc.parameters(), create_graph=True)
    loss_grad = (loss_grad[0].cpu().double(), loss_grad[1].cpu().double())
    H_i = eval_hessian(loss_grad, model.model.fc)
    H = H_i if H is None else H + H_i
    ch.cuda.empty_cache()

if 'hessians' not in os.listdir(): os.mkdir('hessians')
np.save(f'hessians/H_{ds}_target_{ub}_ub_{eps}_eps_{num_images}_images_{batch_size}_batch_size', H)

# INVERT THE HESSIAN (pseudo-inverse; rcond trims small singular values).
# NOTE(review): scipy.linalg.pinv2 was removed in SciPy >= 1.8; migrate to
# linalg.pinv when the pinned environment is upgraded.
H_inv = linalg.pinv2(H, rcond=5e-2)
if 'h_inverses' not in os.listdir(): os.mkdir('h_inverses')
np.save(f'h_inverses/H_inv_{ds}_target_{ub}_ub_{eps}_eps_{num_images}_images_{batch_size}_batch_size', H_inv)
# influence_functions/invert_hessian.py
# Standalone utility: load a previously computed final-layer Hessian and save
# its pseudo-inverse.  (hessian.py already inverts inline; this script exists
# to re-invert a saved Hessian, e.g. with a different rcond.)
from robustness import train
from robustness.model_utils import make_and_restore_model
from robustness.datasets import ImageNet, CIFAR, CINIC
from torch import autograd
from torchvision import datasets
import torch as ch
import numpy as np
import argparse
import random
from scipy import linalg
import os
# note: currently only supported for CIFAR10 and epsilon=[0,3]
import sys
sys.path.insert(1, '..')
from tools.helpers import flatten_grad, make_mask, set_seeds, eval_hessian
import tools.custom_datasets as custom_datasets


def _parse_args():
    """Read the run-time flags identifying which saved Hessian to invert."""
    eps_list = [0, 3]
    parser = argparse.ArgumentParser(add_help=True)
    parser.add_argument('-e', required=False, default=0,
                        help='epsilon used to train the source dataset', type=int, choices=eps_list)
    parser.add_argument('-n', required=False, default=100,
                        help='number of images used to make hessian', type=int)
    parser.add_argument('-b', required=False, default=50,
                        help='batch_size', type=int)
    parser.add_argument('-s', required=False, default=1000000,
                        help='seed', type=int)
    parser.add_argument('-ub', required=False, default=1000000,
                        help='number of unfrozen blocks', type=int)
    return parser.parse_args()


args = _parse_args()
eps, num_images, batch_size = args.e, args.n, args.b
seed, ub = args.s, args.ub

# Load the saved Hessian.
# NOTE(review): hessian.py saves under 'hessians/H_{ds}_target_{ub}_ub_...';
# the filename here has no dataset prefix — confirm which convention is live.
H = np.load(f'hessians/H_{ub}_ub_{eps}_eps_{num_images}_images_{batch_size}_batch_size.npy')

# Pseudo-invert; the tiny rcond keeps nearly all singular values
# (cf. rcond=5e-2 used inline in hessian.py).
H_inv = linalg.pinv2(H, rcond=1e-20)
if 'h_inverses' not in os.listdir(): os.mkdir('h_inverses')
np.save(f'h_inverses/H_inv_{ub}_ub_{eps}_eps_{num_images}_images_{batch_size}_batch_size', H_inv)
# influence_functions/make_grads.py
# Compute the per-example gradient of the final linear layer's loss (one
# flattened row per image), saved in chunks, for influence-function analysis.
import os
from robustness.model_utils import make_and_restore_model
from robustness.datasets import ImageNet, CIFAR, CINIC
from torch import autograd
from torchvision import datasets
import torch as ch
import numpy as np
import argparse
# note: currently only supported for CIFAR10 and epsilon=[0,3]
import sys
sys.path.insert(1, '..')
from tools.helpers import flatten_grad, make_mask, set_seeds
from tools.helpers import get_runtime_inputs_for_influence_functions
import tools.custom_datasets as custom_datasets

args = get_runtime_inputs_for_influence_functions()
eps = args.e            # source-model training epsilon (0 = naturally trained)
num_images = args.n     # size of the training subsample
ub = args.ub            # number of unfrozen blocks during fine-tuning
data_type = args.t      # 'train' or 'test'
seed = args.s
ds = args.ds            # target dataset name ('cifar10' or 'svhn')
isTrain = (data_type == 'train')

# LOAD MODEL
# NOTE(review): checkpoint names are hard-coded to cifar10 targets even when
# ds == 'svhn' (hessian.py interpolates {ds} here) — confirm intended.
eps_to_model = {3: f'l2_{eps}_imagenet_to_cifar10_{ub}_ub_{num_images}_images.pt',
                0: f'nat_imagenet_to_cifar10_{ub}_ub_{num_images}_images.pt'}
source_model_path = '../models/' + eps_to_model[eps]

model, _ = make_and_restore_model(arch='resnet50_m', dataset=ImageNet('/tmp'),
                                  resume_path=source_model_path, parallel=False)
model.eval()
criterion = ch.nn.CrossEntropyLoss()

# MAKE DATASET (224x224 so the ImageNet-pretrained backbone sees its native size)
size = (224, 224)
if ds == 'cifar10':
    data_set = datasets.CIFAR10(root='/tmp', train=isTrain, download=True,
                                transform=custom_datasets.TEST_TRANSFORMS_DEFAULT(size))
elif ds == 'svhn':
    split = 'train' if isTrain else 'test'
    data_set = datasets.SVHN(root='/tmp', split=split, download=True,
                             transform=custom_datasets.TEST_TRANSFORMS_DEFAULT(size))

set_seeds({'seed': seed})

# For 'train', restrict to the seeded random subsample and persist the chosen
# indices so other scripts can line up with them; 'test' uses the whole split.
subset = data_set
if isTrain:
    mask_sampler = make_mask(num_images, data_set)
    subset = ch.utils.data.Subset(data_set, mask_sampler.indices)
    # SAVE MASK
    np.save(f'mask/{ds}_{num_images}_num_images_{seed}_seed', mask_sampler.indices)

# MAKE THE GRADIENTS (batch size 1 so every row is a single example's gradient)
batch_size = 1
num_workers = 1
loader = ch.utils.data.DataLoader(subset, batch_size=batch_size,
                                  shuffle=False, num_workers=num_workers, pin_memory=True)

# (removed a duplicate mid-file `import os`; it is already imported above)
base1 = f'{data_type}_grad'
base2 = f'{ds}_{eps}_eps_{ub}_ub_{num_images}_images'
if base1 not in os.listdir():
    os.mkdir(base1)
if base2 not in os.listdir(base1):
    os.mkdir(base1 + '/' + base2)

# Gradients are written in chunks of 1000 rows, named by the chunk's last index.
# NOTE(review): for 'train', a final partial chunk (num_images not a multiple
# of 1000) is never written — confirm num_images is always a multiple of 1000.
for i, data in enumerate(loader):
    image, label = data
    output, final_inp = model(image.cuda())
    loss = criterion(output.cpu(), label)
    loss_grad = autograd.grad(loss.double(), model.model.fc.parameters(), create_graph=False)
    loss_grad = flatten_grad(loss_grad)
    ch.cuda.empty_cache()
    if i % 1000 == 0:
        grad = loss_grad                       # start a new chunk
    else:
        grad = np.vstack((grad, loss_grad))    # append row to the current chunk
    if ((i + 1) % 1000) == 0:
        np.save(base1 + '/' + base2 + '/' + f'{i}_end_idx', np.array(grad))
if data_type == 'test': np.save(base1 + '/' + base2 + '/' + f'{num_images-1}_end_idx', np.array(grad))
# --- influence_functions/make_influences.py ---
# Combine saved per-example gradients with the inverse Hessian to produce the
# influence matrix: influences[test_i, train_j] = g_test_i @ H^-1 @ g_train_j.
import os
import numpy as np
import sys
sys.path.insert(1, '..')
from tools.helpers import get_runtime_inputs_for_influence_functions, load_gradients


args = get_runtime_inputs_for_influence_functions()

eps = args.e
num_images = args.n
ub = args.ub
seed = args.s
ds = args.ds
b = args.b
data_type = args.t
isTrain = (data_type == 'train')

# Train-vs-train influences reuse the train gradients on both sides.
train_gradients = load_gradients(ds, eps, ub, num_images, train_or_test='train')
test_gradients = train_gradients
if isTrain == False:
    test_gradients = load_gradients(ds, eps, ub, num_images, train_or_test='test')

H_inv = np.load(f'h_inverses/H_inv_{ds}_target_{ub}_ub_{eps}_eps_{num_images}_images_{b}_batch_size.npy')

influences = test_gradients @ H_inv @ train_gradients.T

np.save(f'influences/{ds}_target_{ub}_ub_{eps}_eps_{num_images}_images_{b}_batch_size_{data_type}', influences)

# --- results/mass_extractor.py ---
# Walk results/logs/<run folder>/, read each cox log store, tag the rows with
# the hyper-parameters encoded in the folder name, and accumulate them into one
# DataFrame (the trailing lines chdir back out and write it to CSV).
import os
from cox.store import Store
from cox.readers import CollectionReader
import pandas as pd
import re
import argparse

parser = argparse.ArgumentParser(add_help=True)
parser.add_argument('-o', required=False, default='log_summary.csv', help='output file', type=str)

args = parser.parse_args()

os.chdir('logs')

df = pd.DataFrame()
folder_list = list(os.listdir())

for folder in folder_list:
    os.chdir(folder)
    try:
        reader = CollectionReader(os.getcwd())
        new_df = reader.df('logs')
    # BUGFIX: was a bare `except:`; folders without a readable cox store are
    # skipped, but KeyboardInterrupt/SystemExit are no longer swallowed.
    except Exception:
        pass
    else:
        # Run hyper-parameters are recorded only in the folder name; parse them out.
        new_df['source_eps'] = re.findall(r'source_eps_(\d*\.*\d*)_', folder)[0]
        new_df['target_ds'] = re.findall(r'target_dataset_([a-z,0-9]+)', folder)[0]
        new_df['num_training_images'] = re.findall(r'num_training_images_([a-z,0-9,-]+)', folder)[0]
        new_df['unfrozen_blocks'] = re.findall(r'unfrozen_blocks_([a-z,0-9]+)', folder)[0]
        new_df['seed'] = re.findall(r'seed_([a-z,0-9]+)', folder)[0]

        # NOTE(review): DataFrame.append is removed in pandas >= 2.0; migrate
        # to pd.concat when the pinned environment is upgraded.
        df = df.append(new_df, sort=False)

        reader.close()
    os.chdir("..")
40 | 41 | os.chdir("..") 42 | df.to_csv(args.o) 43 | -------------------------------------------------------------------------------- /results/plots/constraint.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/utrerf/robust_transfer_learning/ff18fcaf6463c388d3a8cb972806eb61c842ea65/results/plots/constraint.pdf -------------------------------------------------------------------------------- /results/plots/ft_blocks_8.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/utrerf/robust_transfer_learning/ff18fcaf6463c388d3a8cb972806eb61c842ea65/results/plots/ft_blocks_8.pdf -------------------------------------------------------------------------------- /results/plots/less_epochs.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/utrerf/robust_transfer_learning/ff18fcaf6463c388d3a8cb972806eb61c842ea65/results/plots/less_epochs.pdf -------------------------------------------------------------------------------- /results/plots/less_images.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/utrerf/robust_transfer_learning/ff18fcaf6463c388d3a8cb972806eb61c842ea65/results/plots/less_images.pdf -------------------------------------------------------------------------------- /results/tables/summary_table.csv: -------------------------------------------------------------------------------- 1 | num_training_images,target_ds,0,3,4,8 2 | 100,cifar10,48.49299850463867,51.890998458862306,49.86199836730957,49.028998947143556 3 | 100,cifar100,7.970999765396118,8.537999773025513,8.236999607086181,8.842999744415284 4 | 100,fmnist,66.11299858093261,71.7989990234375,71.23899841308594,72.69199676513672 5 | 100,kmnist,34.29499893188476,51.76299858093262,50.119998931884766,51.05299873352051 6 | 
100,mnist,68.69499816894532,83.83199768066406,83.03399810791015,82.8069969177246 7 | 100,svhn,23.502611541748045,37.89989166259765,34.751458930969235,35.754455184936525 8 | 200,cifar10,57.44099884033203,64.0539981842041,62.82999877929687,60.109998321533205 9 | 200,cifar100,13.053999710083009,13.258999633789063,13.713999557495118,12.793999767303466 10 | 200,fmnist,73.52899856567383,78.50699920654297,77.6709976196289,78.3329978942871 11 | 200,kmnist,44.77199859619141,62.973998641967775,60.35799903869629,62.835998916625975 12 | 200,mnist,81.58299713134765,90.7209976196289,90.44099807739258,90.49099731445312 13 | 200,svhn,29.36578025817871,50.8854476928711,46.82237167358399,47.77005119323731 14 | 400,cifar10,66.83999710083008,71.64899826049805,71.92899780273437,68.81199722290039 15 | 400,cifar100,18.57099962234497,21.735999488830565,21.31599998474121,19.789999580383302 16 | 400,fmnist,78.79799880981446,82.09099731445312,81.40399780273438,81.57999725341797 17 | 400,kmnist,52.45499877929687,71.65099716186523,70.17599792480469,71.88699722290039 18 | 400,mnist,88.09499740600586,94.27599792480468,93.90499877929688,93.97899856567383 19 | 400,svhn,37.29717178344727,61.96680870056152,58.989319229125975,61.022970581054686 20 | 800,cifar10,74.19899826049804,78.67899780273437,78.99199829101562,75.947998046875 21 | 800,cifar100,27.57299938201904,33.68299903869629,32.551998710632326,30.172999000549318 22 | 800,fmnist,82.73299636840821,85.2089973449707,84.5989974975586,84.58799743652344 23 | 800,kmnist,63.00499801635742,79.6709976196289,77.72999801635743,80.67999725341797 24 | 800,mnist,91.81399765014649,96.0479965209961,95.57899780273438,95.9229965209961 25 | 800,svhn,49.167562103271486,71.1919937133789,68.83566360473633,70.19283905029297 26 | 1600,cifar10,79.73999710083008,83.75799789428712,83.46999740600586,81.10299682617188 27 | 1600,cifar100,37.87799911499023,45.642998123168944,44.48899917602539,40.818999481201175 28 | 
1600,fmnist,85.39099884033203,87.58199768066406,87.08699798583984,86.90599822998047 29 | 1600,kmnist,74.06099853515624,84.94599761962891,83.66499710083008,85.51699829101562 30 | 1600,mnist,94.46799774169922,97.1349967956543,96.99999771118163,97.1609977722168 31 | 1600,svhn,62.139288330078124,77.83766098022461,75.2377830505371,76.76513366699218 32 | 3200,cifar10,83.76699905395508,87.56299743652343,87.34399871826172,85.07099609375 33 | 3200,cifar100,51.459999084472656,56.90899887084961,55.789998626708986,51.34299850463867 34 | 3200,fmnist,87.97199859619141,89.4349983215332,89.11099853515626,88.88899765014648 35 | 3200,kmnist,81.41999816894531,89.67899703979492,88.26199722290039,89.1389991760254 36 | 3200,mnist,96.39899826049805,98.12599868774414,97.93299865722656,98.10299682617188 37 | 3200,svhn,74.31123046875,85.54663314819337,81.57075729370118,84.72802658081055 38 | 6400,cifar10,87.27599792480468,90.91499710083008,90.2349983215332,88.5969970703125 39 | 6400,cifar100,62.255998611450195,66.09799728393554,65.30899810791016,60.74999885559082 40 | 6400,fmnist,89.62999877929687,90.72599792480469,90.43999786376953,90.10799789428711 41 | 6400,kmnist,86.70499649047852,93.61199645996093,91.6989974975586,92.8589988708496 42 | 6400,mnist,97.3369972229004,98.73499755859375,98.49699783325195,98.57499771118164 43 | 6400,svhn,82.63214416503907,90.45943298339844,88.17147827148438,89.53941116333007 44 | 12800,cifar10,91.94599761962891,93.19999694824219,93.10999755859375,91.597998046875 45 | 12800,cifar100,69.91599731445312,71.61399688720704,71.4219985961914,67.15599822998047 46 | 12800,fmnist,91.55799865722656,92.08799743652344,91.96599578857422,91.68999786376953 47 | 12800,kmnist,94.31999816894532,96.20599822998047,95.72799835205078,95.8959976196289 48 | 12800,mnist,98.5239974975586,99.10599670410156,98.97599639892579,99.03199768066406 49 | 12800,svhn,88.69544982910156,92.86186065673829,91.68023834228515,91.81468811035157 50 | 
# test.py — evaluate a fine-tuned checkpoint on a target dataset, optionally
# under a PGD attack (when the test epsilon is > 0).
import argparse
import tools.custom_datasets as custom_datasets
from tools.custom_datasets import ImageNetTransfer, name_to_dataset
import tools.constants as constants
from cox.utils import Parameters
import cox.store
import os
from robustness.model_utils import make_and_restore_model
import robustness.train as train

def get_input_args():
    """Parse the command-line flags controlling the (adversarial) evaluation."""
    eps_list = list(constants.eps_to_filename.keys())
    target_dataset_list = list(custom_datasets.name_to_dataset.keys())
    norm_list = ['inf', '2']

    parser = argparse.ArgumentParser(add_help=True)
    # BUGFIX: default was the int 0 on a type=str flag; None is the explicit
    # "no checkpoint" value make_and_restore_model expects.
    parser.add_argument('-p', required=False, default=None, help='resume path for the model', type=str)
    # BUGFIX: help text was a copy-paste of -p's ("resume path for the model").
    parser.add_argument('-out', required=False, default=f'{os.getcwd()}/test_results', help='output directory for the test results', type=str)
    parser.add_argument('-ds', required=False, default='cifar10', help='name of the target dataset', type=str, choices=target_dataset_list)
    parser.add_argument('-e', required=False, default=0, help='test epsilon', type=float)
    parser.add_argument('-pgd', required=False, default=128, help='number of pgd steps for testing', type=int)
    parser.add_argument('-lp', required=False, default='2', help='norm of the lp constraint', type=str, choices=norm_list)

    return parser.parse_args()

input_args = get_input_args()

dataset = ImageNetTransfer(name=input_args.ds,
                           data_path=f'{constants.base_data_path}{input_args.ds}',
                           num_transf_classes=name_to_dataset[input_args.ds]['num_classes'],
                           downscale=False, **name_to_dataset[input_args.ds])

_, test_loader = dataset.make_loaders(8, 128)

if not os.path.exists(input_args.out): os.mkdir(input_args.out)
store = cox.store.Store(f'{input_args.out}/dataset_{input_args.ds}_eps_{input_args.e}_PGD_{input_args.pgd}_norm_{input_args.lp}')

model, _ = make_and_restore_model(arch='resnet50', dataset=dataset, parallel=False,
                                  resume_path=input_args.p, pytorch_pretrained=False)

test_args = {
    'adv_eval': input_args.e > 0,      # only run the attack when eps > 0
    'use_best': True,
    'random_restarts': False,
    'out_dir': "train_out",
    'constraint': input_args.lp,       # lp-norm of the PGD constraint ('2' or 'inf')
    'eps': input_args.e,               # attack budget in that norm
    'attack_lr': 2.5 * (input_args.e / input_args.pgd),  # standard PGD step-size rule
    'attack_steps': input_args.pgd,    # Number of PGD steps
}

test_args = Parameters(test_args)

train.eval_model(test_args, model, test_loader, store)
# tools/batch.py — launch the full fine-tuning grid (target dataset x unfrozen
# blocks x sample size x seed x epsilon), running up to `concurrent_commands`
# train.py processes at a time, one GPU each.
import os
import time
import subprocess
import shlex
from delete_big_files import deleteBigFilesFor1000experiment
# (removed a duplicate `import os`)

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

train_dir = '/home/ubuntu/robust_transfer_learning/'

# sample size -> (number of seeds, number of epochs, log frequency); -1 = all images
sample_size_to_number_of_seeds_epochs_and_log_freq = {
    400  : (3, 150, 5),
    1600 : (2, 150, 5),
    6400 : (1, 150, 5),
    25600: (1, 150, 5),
    -1   : (1, 150, 5),
}

target_ds_list = ['food101']
eps_levels = [0, 0.05, 0.25, 1]
num_unfrozen_blocks_list = [3]


polling_delay_seconds = 1
concurrent_commands = 4   # one concurrent process per GPU (CUDA_VISIBLE_DEVICES = i % 4)
commands_to_run = []

def poll_process(process):
    # Sleep between polls so the wait loop below doesn't spin at full CPU.
    time.sleep(polling_delay_seconds)
    return process.poll()

# Build every train.py invocation up front.
for t in target_ds_list:
    for ub in num_unfrozen_blocks_list:
        for n, (num_seeds, ne, li) in sample_size_to_number_of_seeds_epochs_and_log_freq.items():
            seed_list = [20000000 + 100000 * i for i in range(num_seeds)]
            for s in seed_list:
                for e in eps_levels:
                    command = f'python train.py -e {e} -t {t} -ub {ub} -n {n} -s {s} -ne {ne} -li {li} -d True'
                    commands_to_run.append(command)

# Run in batches of `concurrent_commands`; wait for the whole batch, then trim
# oversized log files before starting the next batch.
for start_idx in range(0, len(commands_to_run), concurrent_commands):
    os.chdir(train_dir)
    processes = []
    rng = range(start_idx, min(len(commands_to_run), start_idx + concurrent_commands))
    print(rng)
    for i in rng:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(i % 4)
        processes.append(subprocess.Popen(shlex.split(commands_to_run[i])))
        print(f'Starting command: {commands_to_run[i]}')

    for process in processes:
        while poll_process(process) is None:
            pass

    deleteBigFilesFor1000experiment()
# --- tools/constants.py ---
# Shared constants for locating pre-trained source models and data on disk.

base_model_path = 'models/'

# Training epsilon of the ImageNet source model -> checkpoint filename under models/.
eps_to_filename = {
    -1: None,  # Do not load a model that has been trained from scratch
    0: 'nat',
    0.05: 'imagenet_l2_0_05.pt',
    0.25: 'imagenet_l2_0_25.pt',
    1: 'imagenet_l2_1_0.pt',
    3: 'imagenet_l2_3_0.pt',
    4: 'imagenet_linf_4.pt',
    8: 'imagenet_linf_8.pt'
}

# TODO: FILL OUT YOUR DATA PATH BELOW
base_data_path = '/scratch/data/'

# --- tools/custom_datasets.py ---
import tools.transforms as transforms
import tools.constants as constants
import os
from robustness import imagenet_models, cifar_models
from robustness.datasets import DataSet, CIFAR
import torch as ch
from torchvision import datasets


IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

class ImageNetTransfer(DataSet):
    """A target dataset wrapped as a robustness DataSet, transformed to the
    224x224 input size the ImageNet-pretrained backbone expects."""

    def __init__(self, data_path, num_transf_classes=1000, **kwargs):
        self.num_classes = num_transf_classes
        imagenet_size = 224
        # Grayscale targets (mnist-style) use a 3-channel black-and-white transform.
        transform_type_to_transform = {
            'default': (transforms.TRAIN_TRANSFORMS_DEFAULT(imagenet_size),
                        transforms.TEST_TRANSFORMS_DEFAULT(imagenet_size)),
            'black_n_white': (transforms.BLACK_N_WHITE(imagenet_size),
                              transforms.BLACK_N_WHITE(imagenet_size))
        }
        if kwargs['downscale']:
            # Optionally downscale first (e.g. to 32x32), then resize up to 224.
            transform_type_to_transform = {
                'default': (transforms.TRAIN_TRANSFORMS_DOWNSCALE(kwargs['downscale_size'], imagenet_size),
                            transforms.TEST_TRANSFORMS_DOWNSCALE(kwargs['downscale_size'], imagenet_size)),
                'black_n_white': (transforms.BLACK_N_WHITE_DOWNSCALE(kwargs['downscale_size'], imagenet_size),
                                  transforms.BLACK_N_WHITE_DOWNSCALE(kwargs['downscale_size'], imagenet_size))
            }
        ds_kwargs = {
            'num_classes': kwargs['num_classes'],
            'mean': ch.tensor(kwargs['mean']),
            'std': ch.tensor(kwargs['std']),
            'custom_class': kwargs['custom_class'],
            'label_mapping': None,
            'transform_train': transform_type_to_transform[kwargs['transform_type']][0],
            'transform_test': transform_type_to_transform[kwargs['transform_type']][1]
        }
        self.name = kwargs['name']
        super(ImageNetTransfer, self).__init__(kwargs['name'], data_path, **ds_kwargs)

    def get_model(self, arch, pretrained=False):
        # Always built with the 1000-way ImageNet head; callers swap in a new
        # final layer for the target dataset (see tools/helpers.py).
        return imagenet_models.__dict__[arch](num_classes=1000, pretrained=pretrained)
CIFAR_MEAN = [0.4914, 0.4822, 0.4465]
CIFAR_STD = [0.2023, 0.1994, 0.2010]

# this class is used when we're training from scratch instead of from a pre-trained imagenet model
class CIFAR(DataSet):
    """CIFAR-10/100 as a robustness DataSet at its native 32x32 resolution."""

    def __init__(self, num_classes, data_path=None, **kwargs):
        self.name = f'cifar{num_classes}'

        num_classes_to_custom_class = {
            10: datasets.CIFAR10,
            100: datasets.CIFAR100
        }

        ds_kwargs = {
            'num_classes': num_classes,
            'mean': ch.tensor(CIFAR_MEAN),
            'std': ch.tensor(CIFAR_STD),
            'custom_class': num_classes_to_custom_class[num_classes],
            'label_mapping': None,
            'transform_train': transforms.TRAIN_TRANSFORMS_DEFAULT(32),
            'transform_test': transforms.TEST_TRANSFORMS_DEFAULT(32)
        }
        super(CIFAR, self).__init__(f'cifar{num_classes}', data_path, **ds_kwargs)

    def get_model(self, arch, pretrained=False):
        if pretrained:
            raise ValueError('CIFAR100 does not support pytorch_pretrained=True')
        # BUGFIX: was `num_classes=num_classes` — a bare name not in scope in
        # this method (NameError at call time).  DataSet.__init__ stores the
        # ds_kwargs on the instance, so read it from there.
        return cifar_models.__dict__[arch](num_classes=self.num_classes)

# Per-target-dataset settings consumed by ImageNetTransfer / make_dataset.
name_to_dataset = {
    'caltech101_stylized': {'num_classes':101, 'custom_class':None, 'transform_type':'default',
                            'mean':IMAGENET_MEAN, 'std':IMAGENET_STD},

    'food_stylized': {'num_classes':101, 'custom_class':None, 'transform_type':'default',
                      'mean':[0.5493, 0.4450, 0.3435], 'std':[0.2730, 0.2759, 0.2800]},

    'caltech101': {'num_classes':101, 'custom_class':None, 'transform_type':'default',
                   'mean':IMAGENET_MEAN, 'std':IMAGENET_STD},

    'food': {'num_classes':101, 'custom_class':None, 'transform_type':'default',
             'mean':[0.5493, 0.4450, 0.3435], 'std':[0.2730, 0.2759, 0.2800]},

    'cifar10': {'num_classes':10, 'custom_class':datasets.CIFAR10, 'transform_type':'default',
                'mean':CIFAR_MEAN, 'std':CIFAR_STD},

    'cifar100': {'num_classes':100, 'custom_class':datasets.CIFAR100, 'transform_type':'default',
                 'mean':CIFAR_MEAN, 'std':CIFAR_STD},

    'svhn': {'num_classes':10, 'custom_class':datasets.SVHN, 'transform_type':'default',
             'mean':[0.4377, 0.4438,0.4728], 'std':[0.1980,0.2010,0.1970]},

    # TODO: Get mean and std for fmnist
    'fmnist': {'num_classes':10, 'custom_class':datasets.FashionMNIST, 'transform_type':'black_n_white',
               'mean':[0.1801,0.1801,0.1801], 'std':[0.3421,0.3421,0.3421]},

    'kmnist': {'num_classes':10, 'custom_class':datasets.KMNIST, 'transform_type':'black_n_white',
               'mean':[0.1801,0.1801,0.1801], 'std':[0.3421,0.3421,0.3421]},

    'mnist': {'num_classes':10, 'custom_class':datasets.MNIST, 'transform_type':'black_n_white',
              'mean':[0.1307,0.1307,0.1307], 'std':[0.3081,0.3081,0.3081]}
}

def make_dataset(args):
    """Build the target dataset named in args from the settings table above."""
    return ImageNetTransfer(name=args['target_dataset_name'],
                            data_path=f'{constants.base_data_path}{args["target_dataset_name"]}',
                            downscale=args['downscale'], downscale_size=args['downscale_size'],
                            **name_to_dataset[args['target_dataset_name']])
# tools/delete_big_files.py — prune oversized files (checkpoints etc.) from the
# experiment log folders so repeated grid runs don't fill the disk.
import os

def deleteBigFiles(max_size):
    """Recursively delete every file at or below the current working directory
    whose size exceeds max_size bytes."""
    for f in [f for f in os.listdir() if os.path.isfile(f)]:
        if os.stat(f).st_size > max_size:
            os.remove(f)
    for d in [d for d in os.listdir() if os.path.isdir(d)]:
        os.chdir(d)
        deleteBigFiles(max_size)
        os.chdir("..")

def deleteBigFilesFor1000experiment():
    """Trim files over ~10 MB from every run folder under results/logs.

    NOTE: leaves the CWD inside results/logs; batch.py chdirs back to the
    train directory before each round, so standalone callers must restore
    the CWD themselves.
    """
    max_size = 10000000
    # (removed a no-op `os.chdir(os.getcwd())` that preceded this line)
    os.chdir("results/logs")
    for folder in os.listdir():
        os.chdir(folder)
        deleteBigFiles(max_size)
        os.chdir("..")
-1 means train from scratch', type=float, choices=eps_list) 34 | parser.add_argument('-t', required=False, default='cifar10', help='name of the target dataset', type=str, choices=target_dataset_list) 35 | parser.add_argument('-ub', required=False, default=0, help='number of unfrozen blocks. -1 means all blocks unfrozen', type=int, choices=unfrozen_blocks_list) 36 | parser.add_argument('-b', required=False, default=128, help='batch size', type=int) 37 | parser.add_argument('-n', required=False, default=-1, help='number of images used in the target dataset', type=int) 38 | parser.add_argument('-w', required=False, default=10, help='number of workers used for parallel computations', type=int) 39 | parser.add_argument('-ne', required=False, default=30, help='number of epochs', type=int) 40 | parser.add_argument('-li', required=False, default=5, help='how often to log iterations', type=int) 41 | parser.add_argument('-s', required=False, default=1000000, help='random seed', type=int) 42 | parser.add_argument('-d', required=False, default=False, help='downscale to lower res?', type=bool) 43 | parser.add_argument('-ds', required=False, default=32, help='downscaled resolution', type=int) 44 | parser.add_argument('-lr', required=False, default=0.1, help='learning rate', type=float) 45 | # parser.add_argument('-lp', required=False, default=False, help='use the 32x32 low pass?', type=bool) 46 | 47 | args = parser.parse_args() 48 | 49 | var_dict = { 50 | 'source_eps' : args.e, 51 | 'target_dataset_name' : args.t, 52 | 'unfrozen_blocks' : args.ub, 53 | 'batch_size' : args.b, 54 | 'seed' : args.s, 55 | 'num_workers' : args.w, 56 | 'num_training_images' : args.n, 57 | 'num_epochs' : args.ne, 58 | 'log_iters' : args.li, 59 | 'downscale' : args.d, 60 | 'downscale_size' : args.ds, 61 | 'learning_rate' : args.lr, 62 | # 'low_pass' : args.lp 63 | } 64 | 65 | return var_dict 66 | 67 | 68 | def set_seeds(var_dict): 69 | seed = var_dict['seed'] 70 | random.seed(seed) 71 | 
np.random.seed(seed) 72 | ch.manual_seed(seed) 73 | 74 | 75 | def change_linear_layer_out_features(model, var_dict, dataset, is_Transfer, num_in_features=2048): 76 | num_out_features = dataset.num_classes 77 | model.model.fc = nn.Linear(in_features=num_in_features, out_features=num_out_features) 78 | #if is_Transfer: 79 | # model.model.fc = nn.Linear(in_features=num_in_features, out_features=num_out_features) 80 | #else: 81 | # model.fc = nn.Linear(in_features=num_in_features, out_features=num_out_features) 82 | return model 83 | 84 | 85 | def load_model(var_dict, is_Transfer, pretrained, dataset): 86 | if is_Transfer: 87 | if var_dict['source_eps'] > 0: 88 | resume_path = os.path.abspath('models/'+constants.eps_to_filename[var_dict['source_eps']]) 89 | model, _ = make_and_restore_model(arch='resnet50', dataset=dataset, parallel=False, resume_path=resume_path, pytorch_pretrained=False) 90 | else: 91 | model, _ = make_and_restore_model(arch='resnet50', dataset=dataset, parallel=False, resume_path=None, pytorch_pretrained=True) 92 | else: 93 | model = dataset.get_model(arch='resnet50', pretrained=False) 94 | model = AttackerModel(model, dataset) 95 | return model 96 | 97 | 98 | def re_init_and_freeze_blocks(model, var_dict): 99 | 100 | unfrozen_blocks_to_layer_name_list = { 101 | 0: ['fc'], 102 | 1: ['fc', '4.2'], 103 | 3: ['fc', '4.2', '4.1', '4.0'], 104 | 9: ['fc', '4.2', '4.1', '4.0', '3.5', '3.4', '3.3', '3.2', '3.1', '3.0'] 105 | } 106 | 107 | layer_name_list = unfrozen_blocks_to_layer_name_list[var_dict['unfrozen_blocks']] 108 | 109 | for name, param in model.named_parameters(): 110 | # if the name of the parameter is not one of the unfrozen blocks then freeze it 111 | if not any([layer_name in name for layer_name in layer_name_list]): 112 | param.requires_grad = False 113 | else: 114 | if 'fc.weight' in name: nn.init.normal_(param) 115 | elif 'fc.bias' in name: nn.init.constant_(param, 0.0) 116 | return model 117 | 118 | def make_out_store(var_dict): 119 | 120 
| out_dir = (os.getcwd()+ '/results/logs/' 121 | + f'source_eps_{var_dict["source_eps"]}_' 122 | + f'target_dataset_{var_dict["target_dataset_name"]}_' 123 | + f'num_training_images_{var_dict["num_training_images"]}_' 124 | + f'unfrozen_blocks_{var_dict["unfrozen_blocks"]}_' 125 | + f'seed_{var_dict["seed"]}_' 126 | + f'downscaled_{var_dict["downscale"]}_') 127 | out_store = cox.store.Store(out_dir) 128 | return out_store 129 | 130 | 131 | def make_train_args(var_dict): 132 | 133 | train_kwargs = { 134 | 'out_dir' : "train_out", 135 | 'adv_train' : 0, 136 | 'epochs' : var_dict['num_epochs'], 137 | 'step_lr' : var_dict['num_epochs']//3, 138 | 'log_iters' : var_dict['log_iters'], 139 | 'learning_rate': var_dict['learning_rate'] 140 | } 141 | 142 | train_args = Parameters(train_kwargs) 143 | train_args = defaults.check_and_fill_args(train_args, defaults.TRAINING_ARGS, CIFAR) 144 | 145 | return train_args 146 | 147 | 148 | def print_details(model, var_dict, train_args): 149 | 150 | for name, param in model.named_parameters(): 151 | print("{}: {}".format(name, param.requires_grad)) 152 | 153 | print('Input parameters: ') 154 | pprint.pprint(var_dict) 155 | 156 | print('Transfer learning training parameters: ') 157 | pprint.pprint(train_args) 158 | 159 | 160 | def eval_hessian(loss_grad, model): 161 | cnt = 0 162 | for g in loss_grad: 163 | g_vector = g.contiguous().view(-1) if cnt == 0 else ch.cat([g_vector, g.contiguous().view(-1)]) 164 | cnt = 1 165 | l = g_vector.size(0) 166 | hessian = ch.zeros(l, l).cpu().double() 167 | for idx in range(l): 168 | grad2rd = autograd.grad(g_vector[idx], model.parameters(), retain_graph=True) 169 | grad2rd = (grad2rd[0].cpu().double(), grad2rd[1].cpu().double()) 170 | cnt = 0 171 | for g in grad2rd: 172 | g2 = g.contiguous().view(-1) if cnt == 0 else ch.cat([g2, g.contiguous().view(-1)]) 173 | cnt = 1 174 | hessian[idx] = g2 175 | return hessian.cpu().double().data.numpy() 176 | 177 | 178 | def flatten_grad(grad): 179 | cnt = 0 180 | 
for g in grad: 181 | g_vector = g.contiguous().view(-1) if cnt == 0 else ch.cat([g_vector, g.contiguous().view(-1)]) 182 | cnt = 1 183 | return g_vector.cpu().double().data.numpy() 184 | 185 | def get_runtime_inputs_for_influence_functions(): 186 | # PARSE INPUTS 187 | parser = argparse.ArgumentParser(add_help=True) 188 | 189 | parser.add_argument('-e', required=False, default=0, 190 | help='epsilon used to train the source dataset', type=int, choices=[0, 3]) 191 | parser.add_argument('-n', required=False, default=3200, 192 | help='number of images used to make hessian', type=int) 193 | parser.add_argument('-ub', required=False, default=3, 194 | help='number of unfrozen blocks', type=int) 195 | parser.add_argument('-s', required=False, default=1000000, 196 | help='seed', type=int) 197 | ds_list = ['cifar10', 'svhn'] 198 | parser.add_argument('-ds', required=True, choices=ds_list, 199 | help='target dataset', type=str) 200 | parser.add_argument('-b', required=False, default=1, 201 | help='batch_size used to generate hessian (not needed for make_gradient.py)', type=int) 202 | parser.add_argument('-t', required=False, default='train', 203 | help='train or test? 
(not needed for hessian.py)', type=str, choices=['train', 'test']) 204 | 205 | args = parser.parse_args() 206 | 207 | return args 208 | 209 | def load_gradients(ds, eps, ub=3, num_images=3200, train_or_test='train'): 210 | os.chdir(f'{train_or_test}_grad/{ds}_{eps}_eps_{ub}_ub_{num_images}_images') 211 | files = os.listdir() 212 | for i, f in enumerate(files): 213 | if i == 0: gradients = np.load(f) 214 | else: gradients = np.vstack((gradients, np.load(f))) 215 | 216 | os.chdir('../..') 217 | return gradients 218 | 219 | -------------------------------------------------------------------------------- /tools/transforms.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | 3 | # DOWNSCALE RESOLUTION TRANSFORMS 4 | 5 | TRAIN_TRANSFORMS_DOWNSCALE = lambda downscale, upscale: transforms.Compose([ 6 | transforms.Resize(downscale), 7 | transforms.Resize(upscale), 8 | transforms.RandomCrop(upscale, padding=4), 9 | transforms.RandomHorizontalFlip(), 10 | transforms.ColorJitter(.25,.25,.25), 11 | transforms.RandomRotation(2), 12 | transforms.ToTensor(), 13 | ]) 14 | 15 | TEST_TRANSFORMS_DOWNSCALE = lambda downscale, upscale:transforms.Compose([ 16 | transforms.Resize(downscale), 17 | transforms.Resize(upscale), 18 | transforms.CenterCrop(upscale), 19 | transforms.ToTensor() 20 | ]) 21 | 22 | 23 | # DEFAULT TRANSFORMS 24 | 25 | TRAIN_TRANSFORMS_DEFAULT = lambda size: transforms.Compose([ 26 | transforms.Resize(size), 27 | transforms.RandomCrop(size, padding=4), 28 | transforms.RandomHorizontalFlip(), 29 | transforms.ColorJitter(.25,.25,.25), 30 | transforms.RandomRotation(2), 31 | transforms.ToTensor(), 32 | ]) 33 | 34 | TEST_TRANSFORMS_DEFAULT = lambda size:transforms.Compose([ 35 | transforms.Resize(size), 36 | transforms.CenterCrop(size), 37 | transforms.ToTensor() 38 | ]) 39 | 40 | # BLACK AND WHITE TRANSFORMS 41 | 42 | BLACK_N_WHITE_DOWNSCALE = lambda downscale, size: transforms.Compose([ 43 | 
transforms.Resize(downscale), 44 | transforms.Resize(size), 45 | transforms.Grayscale(num_output_channels=3), 46 | transforms.ToTensor(), 47 | ]) 48 | 49 | BLACK_N_WHITE = lambda size: transforms.Compose([ 50 | transforms.Resize(size), 51 | transforms.Grayscale(num_output_channels=3), 52 | transforms.ToTensor(), 53 | ]) 54 | 55 | 56 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import tools.helpers as helpers 2 | from tools.custom_datasets import name_to_dataset, make_dataset 3 | from robustness import train 4 | import cox.store 5 | 6 | 7 | var_dict = helpers.get_runtime_inputs() 8 | 9 | helpers.set_seeds(var_dict) 10 | 11 | # do we transfer? 12 | is_Transfer = False 13 | if var_dict['source_eps'] > -1: is_Transfer = True 14 | 15 | # do we grab an imagenet pretrained model? 16 | pretrained = False 17 | if var_dict['source_eps'] == 0: pretrained = True 18 | 19 | # get dataset class 20 | dataset = make_dataset(var_dict) 21 | 22 | model = helpers.load_model(var_dict, is_Transfer, pretrained, dataset) 23 | 24 | model = helpers.change_linear_layer_out_features(model, var_dict, dataset, is_Transfer) 25 | 26 | if is_Transfer: model = helpers.re_init_and_freeze_blocks(model, var_dict) 27 | 28 | subset = var_dict['num_training_images'] 29 | if var_dict['num_training_images'] == -1: subset = None 30 | train_loader, test_loader = dataset.make_loaders(workers=var_dict['num_workers'], batch_size=var_dict['batch_size'], 31 | subset=subset, subset_seed=var_dict['seed']) 32 | 33 | out_store = helpers.make_out_store(var_dict) 34 | train_args = helpers.make_train_args(var_dict) 35 | 36 | helpers.print_details(model, var_dict, train_args) 37 | 38 | train.train_model(train_args, model, (train_loader, test_loader), store=out_store) 39 | pass 40 | --------------------------------------------------------------------------------