├── LICENSE ├── PACS ├── __init__.py ├── evaulate_domain_experiments.py ├── fine_tune_alexnet.py ├── fine_tune_alexnet_domain_classifier.py ├── fine_tune_alexnet_find_da.py ├── fine_tune_alexnet_with_all_da.py ├── fine_tune_alexnet_with_chosen_da.py ├── jobs │ ├── __init__.py │ ├── find_da.sh │ ├── find_da1.sh │ ├── find_da234.sh │ ├── find_da_one_batch.sh │ ├── fine_tune_alexnet.sh │ ├── fine_tune_alexnet_domain_classifier.sh │ ├── fine_tune_alexnet_jitter_norm.sh │ ├── with_all_da.sh │ └── with_chosen_da.sh ├── model_alexnet_caffe.py ├── model_alexnet_caffe_domain.py ├── pacs_data_loader.py ├── pacs_data_loader_data_augmentation.py ├── pacs_data_loader_grey.py ├── pacs_data_loader_norm.py ├── pacs_example.png ├── pacs_examples.png ├── pretrained │ └── __init__.py └── samples_for_paper.py ├── README.md ├── colored_mnist ├── __init__.py ├── choose_da_with_domain_classifier.py ├── domain_classifier.py ├── evaulate_domain_experiments.py ├── jobs_da.sh ├── main_random_baseline.py ├── main_with_all_da.py ├── main_with_chosen_da.py ├── with_all_da.sh └── with_chosen_da.sh ├── rotated_MNIST ├── __init__.py ├── augmentations │ ├── __init__.py │ ├── experiment_augmentations_test_0.py │ ├── experiment_augmentations_test_0_all_da.py │ ├── experiment_augmentations_test_0_individual_classes.py │ ├── experiment_augmentations_test_30.py │ ├── experiment_augmentations_test_30_all_da.py │ ├── experiment_augmentations_test_60.py │ ├── experiment_augmentations_test_60_all_da.py │ ├── experiment_augmentations_test_90.py │ ├── experiment_augmentations_test_90_all_da.py │ ├── jobs │ │ ├── __init__.py │ │ ├── test_0_all_da.sh │ │ ├── test_0_flip.sh │ │ ├── test_0_rotate.sh │ │ ├── test_30_all_da.sh │ │ ├── test_30_flip.sh │ │ ├── test_30_rotate.sh │ │ ├── test_60_all_da.sh │ │ ├── test_60_flip.sh │ │ ├── test_60_rotate.sh │ │ ├── test_90_all_da.sh │ │ ├── test_90_flip.sh │ │ └── test_90_rotate.sh │ └── model_baseline.py ├── choose_da_with_domain_classifier │ ├── __init__.py │ ├── 
evaulate_domain_experiments.py │ ├── experiment_test_0.py │ ├── experiment_test_0_only_rotation.py │ └── jobs │ │ ├── __init__.py │ │ ├── test_0.sh │ │ └── test_0_rot_only.sh ├── domain_classifier │ ├── __init__.py │ ├── domain_classifier.py │ ├── domain_classifier_none.sh │ └── domain_classifier_rotate.sh ├── mnist_loader.py ├── mnist_loader_da.py ├── mnist_loader_shifted_label_distribution_all_da.py ├── mnist_loader_shifted_label_distribution_rotate.py └── mnist_loader_shifted_label_distribution_rotate_da.py └── synthetic_data ├── __init__.py ├── domain_gen_sem.py ├── jobs ├── ICP.sh ├── IRM.sh └── __init__.py ├── main_domain_gen.py ├── plot.py ├── run_erm.py ├── run_erm_domain.py ├── run_erm_with_data_augmentation_random_noise.py └── toy_data_comparison.png /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 AMLAB 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /PACS/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AMLab-Amsterdam/DataAugmentationInterventions/78ce67174db487e9697b9a842e69818305bb41ef/PACS/__init__.py -------------------------------------------------------------------------------- /PACS/evaulate_domain_experiments.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import re 4 | import numpy as np 5 | 6 | os.chdir('./') 7 | 8 | brightness_train = [] 9 | brightness_val = [] 10 | contrast_train = [] 11 | contrast_val = [] 12 | saturation_train = [] 13 | saturation_val = [] 14 | hue_train = [] 15 | hue_val = [] 16 | rotation_train = [] 17 | rotation_val = [] 18 | translate_train = [] 19 | translate_val = [] 20 | scale_train = [] 21 | scale_val = [] 22 | shear_train = [] 23 | shear_val = [] 24 | vflip_train = [] 25 | vflip_val = [] 26 | hflip_train = [] 27 | hflip_val = [] 28 | 29 | for file in glob.glob("*.txt"): 30 | if 'photo' in file: 31 | if 'brightness' in file: 32 | with open(file) as f: 33 | content = f.readlines() 34 | train, val = re.findall(r"[-+]?\d*\.\d+|\d+", content[0]) 35 | train, val = float(train), float(val) 36 | brightness_train.append(train) 37 | brightness_val.append(val) 38 | 39 | if 'contrast' in file: 40 | with open(file) as f: 41 | content = f.readlines() 42 | train, val = re.findall(r"[-+]?\d*\.\d+|\d+", content[0]) 43 | train, val = float(train), float(val) 44 | contrast_train.append(train) 45 | contrast_val.append(val) 46 | 47 | if 'saturation' in file: 48 | with open(file) as f: 
49 | content = f.readlines() 50 | train, val = re.findall(r"[-+]?\d*\.\d+|\d+", content[0]) 51 | train, val = float(train), float(val) 52 | saturation_train.append(train) 53 | saturation_val.append(val) 54 | 55 | if 'hue' in file: 56 | with open(file) as f: 57 | content = f.readlines() 58 | train, val = re.findall(r"[-+]?\d*\.\d+|\d+", content[0]) 59 | train, val = float(train), float(val) 60 | hue_train.append(train) 61 | hue_val.append(val) 62 | 63 | if 'rotation' in file: 64 | with open(file) as f: 65 | content = f.readlines() 66 | train, val = re.findall(r"[-+]?\d*\.\d+|\d+", content[0]) 67 | train, val = float(train), float(val) 68 | rotation_train.append(train) 69 | rotation_val.append(val) 70 | 71 | if 'translate' in file: 72 | with open(file) as f: 73 | content = f.readlines() 74 | train, val = re.findall(r"[-+]?\d*\.\d+|\d+", content[0]) 75 | train, val = float(train), float(val) 76 | translate_train.append(train) 77 | translate_val.append(val) 78 | 79 | if 'scale' in file: 80 | with open(file) as f: 81 | content = f.readlines() 82 | train, val = re.findall(r"[-+]?\d*\.\d+|\d+", content[0]) 83 | train, val = float(train), float(val) 84 | scale_train.append(train) 85 | scale_val.append(val) 86 | 87 | if 'shear' in file: 88 | with open(file) as f: 89 | content = f.readlines() 90 | train, val = re.findall(r"[-+]?\d*\.\d+|\d+", content[0]) 91 | train, val = float(train), float(val) 92 | shear_train.append(train) 93 | shear_val.append(val) 94 | 95 | if 'vflip' in file: 96 | with open(file) as f: 97 | content = f.readlines() 98 | train, val = re.findall(r"[-+]?\d*\.\d+|\d+", content[0]) 99 | train, val = float(train), float(val) 100 | vflip_train.append(train) 101 | vflip_val.append(val) 102 | 103 | if 'hflip' in file: 104 | with open(file) as f: 105 | content = f.readlines() 106 | train, val = re.findall(r"[-+]?\d*\.\d+|\d+", content[0]) 107 | train, val = float(train), float(val) 108 | hflip_train.append(train) 109 | hflip_val.append(val) 110 | 111 | 
brightness_train_std = np.std(np.array(brightness_train)) 112 | brightness_val_std = np.std(np.array(brightness_val)) 113 | contrast_train_std = np.std(np.array(contrast_train)) 114 | contrast_val_std = np.std(np.array(contrast_val)) 115 | saturation_train_std = np.std(np.array(saturation_train)) 116 | saturation_val_std = np.std(np.array(saturation_val)) 117 | hue_train_std = np.std(np.array(hue_train)) 118 | hue_val_std = np.std(np.array(hue_val)) 119 | rotation_train_std = np.std(np.array(rotation_train)) 120 | rotation_val_std = np.std(np.array(rotation_val)) 121 | translate_train_std = np.std(np.array(translate_train)) 122 | translate_val_std = np.std(np.array(translate_val)) 123 | scale_train_std = np.std(np.array(scale_train)) 124 | scale_val_std = np.std(np.array(scale_val)) 125 | shear_train_std = np.std(np.array(shear_train)) 126 | shear_val_std = np.std(np.array(shear_val)) 127 | vflip_train_std = np.std(np.array(vflip_train)) 128 | vflip_val_std = np.std(np.array(vflip_val)) 129 | hflip_train_std = np.std(np.array(hflip_train)) 130 | hflip_val_std = np.std(np.array(hflip_val)) 131 | 132 | 133 | brightness_train = np.mean(np.array(brightness_train)) 134 | brightness_val = np.mean(np.array(brightness_val)) 135 | contrast_train = np.mean(np.array(contrast_train)) 136 | contrast_val = np.mean(np.array(contrast_val)) 137 | saturation_train = np.mean(np.array(saturation_train)) 138 | saturation_val = np.mean(np.array(saturation_val)) 139 | hue_train = np.mean(np.array(hue_train)) 140 | hue_val = np.mean(np.array(hue_val)) 141 | rotation_train = np.mean(np.array(rotation_train)) 142 | rotation_val = np.mean(np.array(rotation_val)) 143 | translate_train = np.mean(np.array(translate_train)) 144 | translate_val = np.mean(np.array(translate_val)) 145 | scale_train = np.mean(np.array(scale_train)) 146 | scale_val = np.mean(np.array(scale_val)) 147 | shear_train = np.mean(np.array(shear_train)) 148 | shear_val = np.mean(np.array(shear_val)) 149 | vflip_train = 
np.mean(np.array(vflip_train)) 150 | vflip_val = np.mean(np.array(vflip_val)) 151 | hflip_train = np.mean(np.array(hflip_train)) 152 | hflip_val = np.mean(np.array(hflip_val)) 153 | 154 | print('mean') 155 | print(brightness_train) 156 | print(brightness_val) 157 | print(contrast_train) 158 | print(contrast_val) 159 | print(saturation_train) 160 | print(saturation_val) 161 | print(hue_train) 162 | print(hue_val) 163 | print(rotation_train) 164 | print(rotation_val) 165 | print(translate_train) 166 | print(translate_val) 167 | print(scale_train) 168 | print(scale_val) 169 | print(shear_train) 170 | print(shear_val) 171 | print(vflip_train) 172 | print(vflip_val) 173 | print(hflip_train) 174 | print(hflip_val) 175 | print('std') 176 | print(brightness_train_std) 177 | print(brightness_val_std) 178 | print(contrast_train_std) 179 | print(contrast_val_std) 180 | print(saturation_train_std) 181 | print(saturation_val_std) 182 | print(hue_train_std) 183 | print(hue_val_std) 184 | print(rotation_train_std) 185 | print(rotation_val_std) 186 | print(translate_train_std) 187 | print(translate_val_std) 188 | print(scale_train_std) 189 | print(scale_val_std) 190 | print(shear_train_std) 191 | print(shear_val_std) 192 | print(vflip_train_std) 193 | print(vflip_val_std) 194 | print(hflip_train_std) 195 | print(hflip_val_std) -------------------------------------------------------------------------------- /PACS/fine_tune_alexnet.py: -------------------------------------------------------------------------------- 1 | # Scores we want to reproduce 2 | # Art painting: 63.3, Cartoon: 63.13, Photo 87.7, Sketch 54.07, mean: 67.05 3 | 4 | import sys 5 | import wandb 6 | 7 | sys.path.insert(0, "../../") 8 | 9 | import argparse 10 | import torch 11 | import torch.optim as optim 12 | import torch.utils.data as data_utils 13 | 14 | import numpy as np 15 | 16 | from paper_experiments.PACS.model_alexnet_caffe import caffenet 17 | from paper_experiments.PACS.pacs_data_loader import PacsData 18 
| 19 | 20 | def train(args, model, device, train_loader, optimizer, epoch): 21 | model.train() 22 | loss_batch = 0 23 | 24 | for batch_idx, (data, target, _) in enumerate(train_loader): 25 | data, target = data.to(device), target.to(device) 26 | _, target = target.max(dim=1) 27 | 28 | optimizer.zero_grad() 29 | output = model(data) 30 | loss = torch.nn.CrossEntropyLoss(reduction='mean')(output, target) 31 | loss.backward() 32 | optimizer.step() 33 | 34 | loss_batch += loss 35 | 36 | return loss_batch 37 | 38 | 39 | def test(args, model, device, test_loader, set_name): 40 | model.eval() 41 | test_loss = 0 42 | correct = 0 43 | with torch.no_grad(): 44 | for data, target, _ in test_loader: 45 | data, target = data.to(device), target.to(device) 46 | _, target = target.max(dim=1) 47 | 48 | output = model(data) 49 | test_loss += torch.nn.CrossEntropyLoss(reduction='mean')(output, target) # sum up batch loss 50 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 51 | correct += pred.eq(target.view_as(pred)).sum().item() 52 | 53 | # test_loss /= len(test_loader.dataset) 54 | 55 | return test_loss, 100. 
* correct / len(test_loader.dataset) 56 | 57 | 58 | def main(): 59 | # Training settings 60 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 61 | parser.add_argument('--no-cuda', action='store_true', default=False, 62 | help='disables CUDA training') 63 | parser.add_argument('--seed', type=int, default=0, 64 | help='random seed (default: 1)') 65 | parser.add_argument('--batch-size', type=int, default=128, 66 | help='input batch size for training (default: 64)') 67 | parser.add_argument('--epochs', type=int, default=200, 68 | help='number of epochs to train (default: 10)') 69 | parser.add_argument('--lr', type=float, default=0.001, 70 | help='learning rate (default: 0.01)') 71 | parser.add_argument('--test_domain', type=list, default=['sketch'], 72 | help='domain used during test') 73 | parser.add_argument('--all_domains', type=list, default=['art_painting', 'cartoon', 'photo', 'sketch'], 74 | help='domain used during train') 75 | 76 | parser.add_argument('-dd', '--data_dir', type=str, default='./data', help='Directory to download data to and load data from') 77 | parser.add_argument('-wd', '--wandb_dir', type=str, default='./', help='(OVERRIDDEN BY ENV_VAR for sweep) Directory to download data to and load data from') 78 | 79 | args = parser.parse_args() 80 | args.test_domain = [''.join(args.test_domain)] 81 | 82 | # Default config is above, Overridden by ENV_VARIABLES!!! or command line 83 | # Sweep interacts weirdly with some things... 
84 | wandb.init(project="CaffeNetValLoss_" + args.test_domain[0], config=args) 85 | 86 | config = wandb.config 87 | 88 | # wandb.config.seed = args.seed 89 | # wandb.config.lr = args.lr 90 | # wandb.config.test_domain = args.test_domain 91 | 92 | print(config) 93 | print("Data from:", config.data_dir) 94 | print("Logging to:", config.wandb_dir) 95 | 96 | use_cuda = not config.no_cuda and torch.cuda.is_available() 97 | 98 | # Set seed 99 | torch.manual_seed(config.seed) 100 | torch.backends.cudnn.benchmark = False 101 | np.random.seed(config.seed) 102 | 103 | device = torch.device("cuda") 104 | 105 | model_name = 'caffenet_val_loss_seed_' + str(config.seed) + '_test_domain_' + config.test_domain[0] 106 | 107 | kwargs = {'num_workers': 8, 'pin_memory': True} if use_cuda else {} 108 | 109 | train_domain = [n for n in config.all_domains if n != config.test_domain[0]] 110 | print(train_domain, config.test_domain) 111 | 112 | # Load supervised training 113 | train_loader = data_utils.DataLoader( 114 | PacsData('./kfold/', domain_list=train_domain, mode='train'), 115 | batch_size=config.batch_size, 116 | shuffle=True, **kwargs) 117 | val_loader = data_utils.DataLoader( 118 | PacsData('./kfold/', domain_list=train_domain, mode='val'), 119 | batch_size=config.batch_size, 120 | shuffle=False, **kwargs) 121 | 122 | model = caffenet(7).to(device) 123 | 124 | optimizer = optim.SGD(model.parameters(), weight_decay=.0005, momentum=.9, nesterov=True, lr=config.lr) 125 | step_size = int(config.epochs * .8) 126 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size) 127 | 128 | best_val_loss = 1000000000 129 | 130 | for epoch in range(1, config.epochs + 1): 131 | train_loss = train(config, model, device, train_loader, optimizer, epoch) 132 | _, train_acc = test(config, model, device, train_loader, set_name='Train') 133 | val_loss, val_acc = test(config, model, device, val_loader, set_name='Val') 134 | 135 | wandb.log({'train_loss': train_loss, 'val_loss': val_loss, 
'train_acc': train_acc, 'val_acc': val_acc}) 136 | 137 | scheduler.step() 138 | 139 | # Save best 140 | if val_loss <= best_val_loss: 141 | best_val_loss = val_loss 142 | 143 | torch.save(model, model_name + '.model') 144 | torch.save(args, model_name + '.config') 145 | 146 | # Test 147 | test_loader = data_utils.DataLoader( 148 | PacsData('./kfold/', domain_list=config.test_domain, mode='test'), 149 | batch_size=config.batch_size, 150 | shuffle=False, **kwargs) 151 | model = torch.load(model_name + '.model').to(device) 152 | _, test_acc = test(config, model, device, test_loader, set_name='Test') 153 | 154 | with open(model_name + '.txt', "w") as text_file: 155 | text_file.write("Test Acc: " + str(test_acc)) 156 | 157 | 158 | if __name__ == '__main__': 159 | main() -------------------------------------------------------------------------------- /PACS/fine_tune_alexnet_domain_classifier.py: -------------------------------------------------------------------------------- 1 | # Experiments in the paper have 200 epochs and no normalization 2 | 3 | import sys 4 | import wandb 5 | 6 | sys.path.insert(0, "../../") 7 | 8 | import argparse 9 | import torch 10 | import torch.optim as optim 11 | import torch.utils.data as data_utils 12 | import torchvision.transforms as transforms 13 | 14 | import numpy as np 15 | 16 | from paper_experiments.PACS.model_alexnet_caffe import caffenet 17 | from paper_experiments.PACS.pacs_data_loader_data_augmentation import PacsDataDataAug 18 | from paper_experiments.PACS.pacs_data_loader_norm import PacsData 19 | 20 | 21 | def train(args, model, device, train_loader, optimizer, epoch): 22 | model.train() 23 | loss_batch = 0 24 | 25 | for batch_idx, (data, _, domain) in enumerate(train_loader): 26 | data, domain = data.to(device), domain.to(device) 27 | _, domain = domain.max(dim=1) 28 | 29 | optimizer.zero_grad() 30 | output = model(data) 31 | loss = torch.nn.CrossEntropyLoss(reduction='mean')(output, domain) 32 | loss.backward() 33 | 
optimizer.step() 34 | 35 | loss_batch += loss 36 | 37 | return loss_batch 38 | 39 | 40 | def test(args, model, device, test_loader, set_name): 41 | model.eval() 42 | test_loss = 0 43 | correct = 0 44 | with torch.no_grad(): 45 | for data, _, domain in test_loader: 46 | data, domain = data.to(device), domain.to(device) 47 | _, domain = domain.max(dim=1) 48 | 49 | output = model(data) 50 | test_loss += torch.nn.CrossEntropyLoss(reduction='mean')(output, domain) # sum up batch loss 51 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 52 | correct += pred.eq(domain.view_as(pred)).sum().item() 53 | 54 | # test_loss /= len(test_loader.dataset) 55 | 56 | return test_loss, 100. * correct / len(test_loader.dataset) 57 | 58 | 59 | def main(): 60 | # Training settings 61 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 62 | parser.add_argument('--no-cuda', action='store_true', default=False, 63 | help='disables CUDA training') 64 | parser.add_argument('--seed', type=int, default=0, 65 | help='random seed (default: 1)') 66 | parser.add_argument('--batch-size', type=int, default=128, 67 | help='input batch size for training (default: 64)') 68 | parser.add_argument('--epochs', type=int, default=200, 69 | help='number of epochs to train (default: 10)') 70 | parser.add_argument('--lr', type=float, default=0.001, 71 | help='learning rate (default: 0.01)') 72 | parser.add_argument('--test_domain', type=list, default=['photo'], 73 | help='domain used during test') 74 | parser.add_argument('--all_domains', type=list, default=['art_painting', 'cartoon', 'photo', 'sketch'], 75 | help='domain used during train') 76 | parser.add_argument('--da', type=str, default='jitter', choices=['none', 'jitter'], 77 | help='domain used during train') 78 | 79 | parser.add_argument('-dd', '--data_dir', type=str, default='./data', help='Directory to download data to and load data from') 80 | parser.add_argument('-wd', '--wandb_dir', type=str, 
default='./', help='(OVERRIDDEN BY ENV_VAR for sweep) Directory to download data to and load data from') 81 | 82 | args = parser.parse_args() 83 | args.test_domain = [''.join(args.test_domain)] 84 | 85 | # Default config is above, Overridden by ENV_VARIABLES!!! or command line 86 | # Sweep interacts weirdly with some things... 87 | wandb.init(project="CaffeDomain", config=args) 88 | 89 | # wandb.config.seed = args.seed 90 | # wandb.config.lr = args.lr 91 | # wandb.config.test_domain = args.test_domain 92 | 93 | config = wandb.config 94 | 95 | print(config) 96 | print("Data from:", config.data_dir) 97 | print("Logging to:", config.wandb_dir) 98 | 99 | use_cuda = not config.no_cuda and torch.cuda.is_available() 100 | 101 | # Set seed 102 | torch.manual_seed(config.seed) 103 | torch.backends.cudnn.benchmark = False 104 | np.random.seed(config.seed) 105 | 106 | device = torch.device("cuda") 107 | 108 | model_name = 'caffenet_val_acc_domain_classifier_da_' + config.da + '_seed_' + str(config.seed) 109 | 110 | kwargs = {'num_workers': 8, 'pin_memory': True} if use_cuda else {} 111 | 112 | # train_domain = [n for n in config.all_domains if n != config.test_domain[0]] 113 | # print(train_domain, config.test_domain) 114 | 115 | if config.da == 'none': 116 | transforms_pacs_train = transforms.Compose([ 117 | transforms.ToTensor(), 118 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 119 | ]) 120 | 121 | else: 122 | transforms_pacs_train = transforms.Compose([ 123 | transforms.RandomHorizontalFlip(p=0.5), 124 | transforms.ColorJitter(brightness=0.1, contrast=10.0, saturation=10.0, hue=0.5), 125 | transforms.ToTensor(), 126 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 127 | ]) 128 | 129 | # transforms_pacs_test = transforms.Compose([ 130 | # transforms.ToTensor(), 131 | # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 132 | # Hardcoded in the data loader 133 | 134 | # Load supervised training 135 | train_loader 
= data_utils.DataLoader( 136 | PacsDataDataAug('./kfold/', domain_list=config.all_domains, mode='train', transform=transforms_pacs_train), 137 | batch_size=config.batch_size, 138 | shuffle=True, **kwargs) 139 | val_loader = data_utils.DataLoader( 140 | PacsData('./kfold/', domain_list=config.all_domains, mode='val'), 141 | batch_size=config.batch_size, 142 | shuffle=False, **kwargs) 143 | 144 | model = caffenet(4).to(device) 145 | 146 | optimizer = optim.SGD(model.parameters(), weight_decay=.0005, momentum=.9, nesterov=True, lr=config.lr) 147 | step_size = int(config.epochs * .8) 148 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size) 149 | 150 | best_val_acc = 0.0 151 | 152 | for epoch in range(1, config.epochs + 1): 153 | # print('\n Epoch: ' + str(epoch)) 154 | train_loss = train(config, model, device, train_loader, optimizer, epoch) 155 | _, train_acc = test(config, model, device, train_loader, set_name='Train') 156 | val_loss, val_acc = test(config, model, device, val_loader, set_name='Val') 157 | 158 | wandb.log({'train_loss': train_loss, 'val_loss': val_loss, 'train_acc': train_acc, 'val_acc': val_acc}) 159 | 160 | scheduler.step() 161 | 162 | # Save best 163 | if val_acc >= best_val_acc: 164 | best_val_acc = val_acc 165 | 166 | torch.save(model, model_name + '.model') 167 | torch.save(args, model_name + '.config') 168 | 169 | # Test 170 | test_loader = data_utils.DataLoader( 171 | PacsData('./kfold/', domain_list=config.all_domains, mode='test'), 172 | batch_size=config.batch_size, 173 | shuffle=False, **kwargs) 174 | model = torch.load(model_name + '.model').to(device) 175 | _, test_acc = test(config, model, device, test_loader, set_name='Test') 176 | 177 | with open(model_name + '.txt', "w") as text_file: 178 | text_file.write("Test Acc: " + str(test_acc)) 179 | 180 | wandb.run.summary["test_accuracy"] = test_acc 181 | 182 | with open(model_name + '.txt', "w") as text_file: 183 | text_file.write("Test Acc: " + str(test_acc)) 184 | 185 
| 186 | if __name__ == '__main__': 187 | main() -------------------------------------------------------------------------------- /PACS/jobs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AMLab-Amsterdam/DataAugmentationInterventions/78ce67174db487e9697b9a842e69818305bb41ef/PACS/jobs/__init__.py -------------------------------------------------------------------------------- /PACS/jobs/find_da.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd ~/AdditionalDriveA/DeployedProjects/TwoLatentSpacesVAE/paper_experiments/PACS 3 | echo "Starting" 4 | 5 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain art_painting --da brightness 6 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain art_painting --da contrast 7 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain art_painting --da saturation 8 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain art_painting --da hue 9 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain art_painting --da rotation 10 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain art_painting --da translate 11 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain art_painting --da scale 12 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain art_painting --da shear 13 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain art_painting --da vflip 14 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain art_painting --da hflip 15 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain art_painting --da none 16 | 17 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain cartoon --da brightness 18 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain cartoon --da contrast 19 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain cartoon --da saturation 20 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain 
cartoon --da hue 21 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain cartoon --da rotation 22 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain cartoon --da translate 23 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain cartoon --da scale 24 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain cartoon --da shear 25 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain cartoon --da vflip 26 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain cartoon --da hflip 27 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain cartoon --da none 28 | 29 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain photo --da brightness 30 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain photo --da contrast 31 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain photo --da saturation 32 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain photo --da hue 33 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain photo --da rotation 34 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain photo --da translate 35 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain photo --da scale 36 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain photo --da shear 37 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain photo --da vflip 38 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain photo --da hflip 39 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain photo --da none 40 | 41 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain sketch --da brightness 42 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain sketch --da contrast 43 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain sketch --da saturation 44 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain sketch --da hue 45 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain sketch --da rotation 46 | python fine_tune_alexnet_find_da.py --seed 0 
--test_domain sketch --da translate 47 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain sketch --da scale 48 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain sketch --da shear 49 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain sketch --da vflip 50 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain sketch --da hflip 51 | python fine_tune_alexnet_find_da.py --seed 0 --test_domain sketch --da none 52 | 53 | 54 | echo "Done" 55 | -------------------------------------------------------------------------------- /PACS/jobs/find_da1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd "${PACS_DIR:-$HOME/AdditionalDriveA/DeployedProjects/TwoLatentSpacesVAE/paper_experiments/PACS}" || exit 1  # override with PACS_DIR; stop if the directory is missing 3 | echo "Starting" 4 | 5 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain art_painting --da brightness 6 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain art_painting --da contrast 7 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain art_painting --da saturation 8 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain art_painting --da hue 9 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain art_painting --da rotation 10 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain art_painting --da translate 11 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain art_painting --da scale 12 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain art_painting --da shear 13 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain art_painting --da vflip 14 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain art_painting --da hflip 15 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain art_painting --da none 16 | 17 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain cartoon --da brightness 18 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain cartoon --da contrast 19 | python fine_tune_alexnet_find_da.py --seed 1 
--test_domain cartoon --da saturation 20 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain cartoon --da hue 21 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain cartoon --da rotation 22 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain cartoon --da translate 23 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain cartoon --da scale 24 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain cartoon --da shear 25 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain cartoon --da vflip 26 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain cartoon --da hflip 27 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain cartoon --da none 28 | 29 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain photo --da brightness 30 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain photo --da contrast 31 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain photo --da saturation 32 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain photo --da hue 33 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain photo --da rotation 34 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain photo --da translate 35 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain photo --da scale 36 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain photo --da shear 37 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain photo --da vflip 38 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain photo --da hflip 39 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain photo --da none 40 | 41 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain sketch --da brightness 42 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain sketch --da contrast 43 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain sketch --da saturation 44 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain sketch --da hue 45 | python 
fine_tune_alexnet_find_da.py --seed 1 --test_domain sketch --da rotation 46 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain sketch --da translate 47 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain sketch --da scale 48 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain sketch --da shear 49 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain sketch --da vflip 50 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain sketch --da hflip 51 | python fine_tune_alexnet_find_da.py --seed 1 --test_domain sketch --da none 52 | 53 | 54 | echo "Done" 55 | -------------------------------------------------------------------------------- /PACS/jobs/fine_tune_alexnet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd "${PACS_DIR:-$HOME/AdditionalDriveA/DeployedProjects/TwoLatentSpacesVAE/paper_experiments/PACS}" || exit 1  # override with PACS_DIR; stop if the directory is missing 3 | echo "Starting" 4 | python fine_tune_alexnet.py --test_domain art_painting --seed 0 5 | python fine_tune_alexnet.py --test_domain cartoon --seed 0 6 | python fine_tune_alexnet.py --test_domain photo --seed 0 7 | python fine_tune_alexnet.py --test_domain sketch --seed 0 8 | 9 | python fine_tune_alexnet.py --test_domain art_painting --seed 1 10 | python fine_tune_alexnet.py --test_domain cartoon --seed 1 11 | python fine_tune_alexnet.py --test_domain photo --seed 1 12 | python fine_tune_alexnet.py --test_domain sketch --seed 1 13 | 14 | python fine_tune_alexnet.py --test_domain art_painting --seed 2 15 | python fine_tune_alexnet.py --test_domain cartoon --seed 2 16 | python fine_tune_alexnet.py --test_domain photo --seed 2 17 | python fine_tune_alexnet.py --test_domain sketch --seed 2 18 | 19 | python fine_tune_alexnet.py --test_domain art_painting --seed 3 20 | python fine_tune_alexnet.py --test_domain cartoon --seed 3 21 | python fine_tune_alexnet.py --test_domain photo --seed 3 22 | python fine_tune_alexnet.py 
--test_domain sketch --seed 3 23 | 24 | python fine_tune_alexnet.py 
--test_domain art_painting --seed 4 25 | python fine_tune_alexnet.py --test_domain cartoon --seed 4 26 | python fine_tune_alexnet.py --test_domain photo --seed 4 27 | python fine_tune_alexnet.py --test_domain sketch --seed 4 28 | 29 | 30 | echo "Done" 31 | -------------------------------------------------------------------------------- /PACS/jobs/fine_tune_alexnet_domain_classifier.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd "${PACS_DIR:-$HOME/AdditionalDriveA/DeployedProjects/TwoLatentSpacesVAE/paper_experiments/PACS}" || exit 1  # override with PACS_DIR; stop if the directory is missing 3 | echo "Starting" 4 | python fine_tune_alexnet_domain_classifier.py --da none --seed 0 5 | python fine_tune_alexnet_domain_classifier.py --da jitter --seed 0 6 | 7 | python fine_tune_alexnet_domain_classifier.py --da none --seed 1 8 | python fine_tune_alexnet_domain_classifier.py --da jitter --seed 1 9 | 10 | python fine_tune_alexnet_domain_classifier.py --da none --seed 2 11 | python fine_tune_alexnet_domain_classifier.py --da jitter --seed 2 12 | 13 | python fine_tune_alexnet_domain_classifier.py --da none --seed 3 14 | python fine_tune_alexnet_domain_classifier.py --da jitter --seed 3 15 | 16 | python fine_tune_alexnet_domain_classifier.py --da none --seed 4 17 | python fine_tune_alexnet_domain_classifier.py --da jitter --seed 4 18 | 19 | echo "Done" 20 | -------------------------------------------------------------------------------- /PACS/jobs/fine_tune_alexnet_jitter_norm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd "${PACS_DIR:-$HOME/AdditionalDriveA/DeployedProjects/TwoLatentSpacesVAE/paper_experiments/PACS}" || exit 1  # override with PACS_DIR; stop if the directory is missing 3 | echo "Starting"  # NOTE(review): fine_tune_alexnet_only_jitter_norm.py is not listed in the repository's PACS/ tree -- confirm the script name 4 | python 
fine_tune_alexnet_only_jitter_norm.py --test_domain art_painting --seed 0 5 | python fine_tune_alexnet_only_jitter_norm.py --test_domain cartoon --seed 0 6 | python fine_tune_alexnet_only_jitter_norm.py --test_domain photo --seed 0 7 | python fine_tune_alexnet_only_jitter_norm.py --test_domain sketch --seed 0 8 | 9 | 
python fine_tune_alexnet_only_jitter_norm.py --test_domain art_painting --seed 1 10 | python fine_tune_alexnet_only_jitter_norm.py --test_domain cartoon --seed 1 11 | python fine_tune_alexnet_only_jitter_norm.py --test_domain photo --seed 1 12 | python fine_tune_alexnet_only_jitter_norm.py --test_domain sketch --seed 1 13 | 14 | python fine_tune_alexnet_only_jitter_norm.py --test_domain art_painting --seed 2 15 | python fine_tune_alexnet_only_jitter_norm.py --test_domain cartoon --seed 2 16 | python fine_tune_alexnet_only_jitter_norm.py --test_domain photo --seed 2 17 | python fine_tune_alexnet_only_jitter_norm.py --test_domain sketch --seed 2 18 | 19 | python fine_tune_alexnet_only_jitter_norm.py --test_domain art_painting --seed 3 20 | python fine_tune_alexnet_only_jitter_norm.py --test_domain cartoon --seed 3 21 | python fine_tune_alexnet_only_jitter_norm.py --test_domain photo --seed 3 22 | python fine_tune_alexnet_only_jitter_norm.py --test_domain sketch --seed 3 23 | 24 | python fine_tune_alexnet_only_jitter_norm.py --test_domain art_painting --seed 4 25 | python fine_tune_alexnet_only_jitter_norm.py --test_domain cartoon --seed 4 26 | python fine_tune_alexnet_only_jitter_norm.py --test_domain photo --seed 4 27 | python fine_tune_alexnet_only_jitter_norm.py --test_domain sketch --seed 4 28 | 29 | 30 | echo "Done" 31 | -------------------------------------------------------------------------------- /PACS/jobs/with_all_da.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd "${PACS_DIR:-$HOME/AdditionalDriveA/DeployedProjects/TwoLatentSpacesVAE/paper_experiments/PACS}" || exit 1  # override with PACS_DIR; stop if the directory is missing 3 | echo "Starting" 4 | 5 | python fine_tune_alexnet_with_all_da.py --seed 0 --test_domain art_painting 6 | python fine_tune_alexnet_with_all_da.py --seed 1 --test_domain art_painting 7 | python fine_tune_alexnet_with_all_da.py --seed 2 --test_domain art_painting 8 | python 
fine_tune_alexnet_with_all_da.py --seed 3 --test_domain art_painting 9 | python 
fine_tune_alexnet_with_all_da.py --seed 4 --test_domain art_painting 10 | 11 | python fine_tune_alexnet_with_all_da.py --seed 0 --test_domain cartoon 12 | python fine_tune_alexnet_with_all_da.py --seed 1 --test_domain cartoon 13 | python fine_tune_alexnet_with_all_da.py --seed 2 --test_domain cartoon 14 | python fine_tune_alexnet_with_all_da.py --seed 3 --test_domain cartoon 15 | python fine_tune_alexnet_with_all_da.py --seed 4 --test_domain cartoon 16 | 17 | python fine_tune_alexnet_with_all_da.py --seed 0 --test_domain photo 18 | python fine_tune_alexnet_with_all_da.py --seed 1 --test_domain photo 19 | python fine_tune_alexnet_with_all_da.py --seed 2 --test_domain photo 20 | python fine_tune_alexnet_with_all_da.py --seed 3 --test_domain photo 21 | python fine_tune_alexnet_with_all_da.py --seed 4 --test_domain photo 22 | 23 | python fine_tune_alexnet_with_all_da.py --seed 0 --test_domain sketch 24 | python fine_tune_alexnet_with_all_da.py --seed 1 --test_domain sketch 25 | python fine_tune_alexnet_with_all_da.py --seed 2 --test_domain sketch 26 | python fine_tune_alexnet_with_all_da.py --seed 3 --test_domain sketch 27 | python fine_tune_alexnet_with_all_da.py --seed 4 --test_domain sketch 28 | 29 | echo "Done" 30 | -------------------------------------------------------------------------------- /PACS/jobs/with_chosen_da.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd "${PACS_DIR:-$HOME/AdditionalDriveA/DeployedProjects/TwoLatentSpacesVAE/paper_experiments/PACS}" || exit 1  # override with PACS_DIR; stop if the directory is missing 3 | echo "Starting" 4 | 5 | python fine_tune_alexnet_with_chosen_da.py --seed 0 --test_domain art_painting 6 | python fine_tune_alexnet_with_chosen_da.py --seed 1 --test_domain art_painting 7 | python fine_tune_alexnet_with_chosen_da.py --seed 2 --test_domain art_painting 8 | python fine_tune_alexnet_with_chosen_da.py --seed 3 --test_domain art_painting 9 | python 
fine_tune_alexnet_with_chosen_da.py --seed 4 --test_domain art_painting 10 | 11 | python 
fine_tune_alexnet_with_chosen_da.py --seed 0 --test_domain cartoon 12 | python fine_tune_alexnet_with_chosen_da.py --seed 1 --test_domain cartoon 13 | python fine_tune_alexnet_with_chosen_da.py --seed 2 --test_domain cartoon 14 | python fine_tune_alexnet_with_chosen_da.py --seed 3 --test_domain cartoon 15 | python fine_tune_alexnet_with_chosen_da.py --seed 4 --test_domain cartoon 16 | 17 | python fine_tune_alexnet_with_chosen_da.py --seed 0 --test_domain photo 18 | python fine_tune_alexnet_with_chosen_da.py --seed 1 --test_domain photo 19 | python fine_tune_alexnet_with_chosen_da.py --seed 2 --test_domain photo 20 | python fine_tune_alexnet_with_chosen_da.py --seed 3 --test_domain photo 21 | python fine_tune_alexnet_with_chosen_da.py --seed 4 --test_domain photo 22 | 23 | python fine_tune_alexnet_with_chosen_da.py --seed 0 --test_domain sketch 24 | python fine_tune_alexnet_with_chosen_da.py --seed 1 --test_domain sketch 25 | python fine_tune_alexnet_with_chosen_da.py --seed 2 --test_domain sketch 26 | python fine_tune_alexnet_with_chosen_da.py --seed 3 --test_domain sketch 27 | python fine_tune_alexnet_with_chosen_da.py --seed 4 --test_domain sketch 28 | 29 | echo "Done" 30 | -------------------------------------------------------------------------------- /PACS/pacs_data_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from PIL import Image 4 | 5 | import torch 6 | import torch.utils.data as data_utils 7 | import torchvision.transforms as transforms 8 | 9 | 10 | class PacsData(data_utils.Dataset): 11 | def __init__(self, path, domain_list=None, mode='train'): 12 | self.path = path 13 | self.domain_list = domain_list 14 | self.mode = mode 15 | 16 | self.to_tensor = transforms.ToTensor() 17 | self.normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 18 | 19 | self.train_data, self.train_labels, self.train_domain = self.get_data() 20 | 21 | def get_imgs_and_labels(self, 
file_name, domain_path): 22 | 23 | with open(file_name, 'r') as data: 24 | img_paths = [] 25 | labels = [] 26 | for line in data: 27 | p = line.split() 28 | img_paths.append(domain_path + p[0]) 29 | labels.append(p[1]) 30 | 31 | img_list = [] 32 | label_list = [] 33 | 34 | for i, img_path in enumerate(img_paths): 35 | with open(img_path, 'rb') as f: 36 | with Image.open(f) as img: 37 | img = img.convert('RGB') 38 | 39 | img_list.append(self.normalize(self.to_tensor(img))) 40 | label_list.append((np.float(labels[i]) - 1.0)) 41 | 42 | return torch.stack(img_list), torch.Tensor(np.array(label_list)) 43 | 44 | def get_data(self): 45 | imgs_per_domain_list = [] 46 | labels_per_domain_list = [] 47 | domain_per_domain_list = [] 48 | 49 | for i, domain in enumerate(self.domain_list): 50 | 51 | domain_path = self.path + domain 52 | 53 | if self.mode == 'train': 54 | file_name = domain_path + '_train_kfold.txt' 55 | elif self.mode == 'val': 56 | file_name = domain_path + '_crossval_kfold.txt' 57 | elif self.mode == 'test': 58 | file_name = domain_path + '_test_kfold.txt' 59 | else: 60 | print('unkown mode found') 61 | 62 | imgs, labels = self.get_imgs_and_labels(file_name, self.path) 63 | domain_labels = torch.zeros(labels.size()) + i 64 | 65 | # append to final list 66 | imgs_per_domain_list.append(imgs) 67 | labels_per_domain_list.append(labels) 68 | domain_per_domain_list.append(domain_labels) 69 | 70 | # One last cat 71 | train_imgs = torch.cat(imgs_per_domain_list).squeeze() 72 | train_labels = torch.cat(labels_per_domain_list).long() 73 | train_domains = torch.cat(domain_per_domain_list).long() 74 | 75 | # Convert to onehot 76 | y = torch.eye(7) 77 | train_labels = y[train_labels] 78 | 79 | d = torch.eye(4) 80 | train_domains = d[train_domains] 81 | 82 | return train_imgs, train_labels, train_domains 83 | 84 | def __len__(self): 85 | return len(self.train_labels) 86 | 87 | def __getitem__(self, index): 88 | x = self.train_data[index] 89 | y = self.train_labels[index] 
90 | d = self.train_domain[index] 91 | 92 | return x, y, d 93 | 94 | 95 | if __name__ == "__main__": 96 | from torchvision.utils import save_image 97 | 98 | seed = 0 99 | torch.manual_seed(seed) 100 | torch.backends.cudnn.benchmark = False 101 | np.random.seed(seed) 102 | 103 | # Train 104 | domain_list_train = ['art_painting', 'cartoon', 'photo', 'sketch'] 105 | 106 | train_loader = data_utils.DataLoader( 107 | PacsData('./kfold/', domain_list=domain_list_train, mode='train'), 108 | batch_size=128, 109 | shuffle=True) 110 | 111 | y_array = np.zeros(7) 112 | d_array = np.zeros(4) 113 | 114 | for i, (x, y, d) in enumerate(train_loader): 115 | 116 | y_array += y.sum(dim=0).cpu().numpy() 117 | d_array += d.sum(dim=0).cpu().numpy() 118 | 119 | if i == 0: 120 | n = min(x.size(0), 36) 121 | save_image(x[:n].cpu(), 122 | 'pacs_train.png', nrow=6) 123 | 124 | print(y_array, d_array) 125 | print('\n') 126 | 127 | train_loader = data_utils.DataLoader( 128 | PacsData('./kfold/', domain_list=domain_list_train, mode='val'), 129 | batch_size=128, 130 | shuffle=True) 131 | 132 | y_array = np.zeros(7) 133 | d_array = np.zeros(4) 134 | 135 | for i, (x, y, d) in enumerate(train_loader): 136 | 137 | y_array += y.sum(dim=0).cpu().numpy() 138 | d_array += d.sum(dim=0).cpu().numpy() 139 | 140 | if i == 0: 141 | n = min(x.size(0), 36) 142 | save_image(x[:n].cpu(), 143 | 'pacs_val.png', nrow=6) 144 | 145 | print(y_array, d_array) 146 | print('\n') 147 | 148 | train_loader = data_utils.DataLoader( 149 | PacsData('./kfold/', domain_list=domain_list_train, mode='test'), 150 | batch_size=128, 151 | shuffle=True) 152 | 153 | y_array = np.zeros(7) 154 | d_array = np.zeros(4) 155 | 156 | for i, (x, y, d) in enumerate(train_loader): 157 | 158 | y_array += y.sum(dim=0).cpu().numpy() 159 | d_array += d.sum(dim=0).cpu().numpy() 160 | 161 | if i == 0: 162 | n = min(x.size(0), 36) 163 | save_image(x[:n].cpu(), 164 | 'pacs_test.png', nrow=6) 165 | 166 | print(y_array, d_array) 167 | print('\n') 168 | 
-------------------------------------------------------------------------------- /PACS/pacs_data_loader_data_augmentation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from PIL import Image, ImageFilter 4 | 5 | import torch 6 | import torch.utils.data as data_utils 7 | import torchvision.transforms as transforms 8 | 9 | 10 | class PacsDataDataAug(data_utils.Dataset): 11 | def __init__(self, path, domain_list=None, mode='train', transform=None): 12 | self.path = path 13 | self.domain_list = domain_list 14 | self.mode = mode 15 | self.transform = transform 16 | 17 | self.train_data, self.train_labels, self.train_domain = self.get_data() 18 | 19 | def get_imgs_and_labels(self, file_name, domain_path): 20 | 21 | with open(file_name, 'r') as data: 22 | img_paths = [] 23 | labels = [] 24 | for line in data: 25 | p = line.split() 26 | img_paths.append(domain_path + p[0]) 27 | labels.append(p[1]) 28 | 29 | img_list = [] 30 | label_list = [] 31 | 32 | for i, img_path in enumerate(img_paths): 33 | with open(img_path, 'rb') as f: 34 | with Image.open(f) as img: 35 | img = img.convert('RGB') 36 | 37 | img_list.append(img) 38 | label_list.append((np.float(labels[i]) - 1.0)) 39 | 40 | return img_list, torch.Tensor(np.array(label_list)) 41 | 42 | def get_data(self): 43 | imgs_per_domain_list = [] 44 | labels_per_domain_list = [] 45 | domain_per_domain_list = [] 46 | 47 | for i, domain in enumerate(self.domain_list): 48 | 49 | domain_path = self.path + domain 50 | 51 | if self.mode == 'train': 52 | file_name = domain_path + '_train_kfold.txt' 53 | elif self.mode == 'val': 54 | file_name = domain_path + '_crossval_kfold.txt' 55 | elif self.mode == 'test': 56 | file_name = domain_path + '_test_kfold.txt' 57 | else: 58 | print('unkown mode found') 59 | 60 | imgs, labels = self.get_imgs_and_labels(file_name, self.path) 61 | domain_labels = torch.zeros(labels.size()) + i 62 | 63 | # append to final list 64 | 
imgs_per_domain_list.append(imgs) 65 | labels_per_domain_list.append(labels) 66 | domain_per_domain_list.append(domain_labels) 67 | 68 | # One last cat 69 | # train_imgs = torch.cat(imgs_per_domain_list).squeeze() 70 | train_imgs = [item for sublist in imgs_per_domain_list for item in sublist] 71 | 72 | train_labels = torch.cat(labels_per_domain_list).long() 73 | train_domains = torch.cat(domain_per_domain_list).long() 74 | 75 | # Convert to onehot 76 | y = torch.eye(7) 77 | train_labels = y[train_labels] 78 | 79 | d = torch.eye(4) 80 | train_domains = d[train_domains] 81 | 82 | return train_imgs, train_labels, train_domains 83 | 84 | def __len__(self): 85 | return len(self.train_labels) 86 | 87 | def __getitem__(self, index): 88 | x = self.train_data[index] 89 | y = self.train_labels[index] 90 | d = self.train_domain[index] 91 | 92 | return self.transform(x), y, d 93 | 94 | 95 | if __name__ == "__main__": 96 | from torchvision.utils import save_image 97 | import torchvision 98 | import PIL 99 | 100 | seed = 0 101 | da = 'saturation' 102 | torch.manual_seed(seed) 103 | torch.backends.cudnn.benchmark = False 104 | np.random.seed(seed) 105 | 106 | kwargs = {'num_workers': 8, 'pin_memory': True} 107 | 108 | # Train 109 | domain_list_train = ['art_painting', 'cartoon', 'photo', 'sketch'] 110 | 111 | transform_dict = {'brightness': torchvision.transforms.ColorJitter(brightness=1.0, contrast=0, saturation=0, hue=0), 112 | 'contrast': torchvision.transforms.ColorJitter(brightness=0, contrast=10.0, saturation=0, hue=0), 113 | 'saturation': torchvision.transforms.ColorJitter(brightness=0, contrast=0, saturation=10.0, hue=0), 114 | 'hue': torchvision.transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0.5), 115 | 'rotation': torchvision.transforms.RandomAffine([0, 359], translate=None, scale=None, shear=None, 116 | resample=PIL.Image.BILINEAR, fillcolor=0), 117 | 'translate': torchvision.transforms.RandomAffine(0, translate=[0.2, 0.2], scale=None, shear=None, 
118 | resample=PIL.Image.BILINEAR, fillcolor=0), 119 | 'scale': torchvision.transforms.RandomAffine(0, translate=None, scale=[0.8, 1.2], shear=None, 120 | resample=PIL.Image.BILINEAR, fillcolor=0), 121 | 'shear': torchvision.transforms.RandomAffine(0, translate=None, scale=None, 122 | shear=[-10., 10., -10., 10.], 123 | resample=PIL.Image.BILINEAR, fillcolor=0), 124 | 'vflip': torchvision.transforms.RandomVerticalFlip(p=0.5), 125 | 'hflip': torchvision.transforms.RandomHorizontalFlip(p=0.5), 126 | 'none': None, 127 | } 128 | 129 | transforms_pacs = transforms.Compose([ 130 | transform_dict[da], 131 | transforms.ToTensor(), 132 | ]) 133 | 134 | train_loader = data_utils.DataLoader( 135 | PacsDataDataAug('./kfold/', domain_list=domain_list_train, mode='train', transform=transforms_pacs), 136 | batch_size=128, 137 | shuffle=True, 138 | **kwargs) 139 | 140 | y_array = np.zeros(7) 141 | d_array = np.zeros(4) 142 | 143 | for i, (x, y, d) in enumerate(train_loader): 144 | 145 | # y_array += y.sum(dim=0).cpu().numpy() 146 | # d_array += d.sum(dim=0).cpu().numpy() 147 | 148 | if i == 0: 149 | n = min(x.size(0), 48) 150 | save_image(x[:n].cpu(), 151 | '__pacs_train_' + da + '.png', nrow=8) 152 | 153 | break 154 | 155 | print(y_array, d_array) 156 | print('\n') 157 | -------------------------------------------------------------------------------- /PACS/pacs_data_loader_grey.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from PIL import Image 4 | 5 | import torch 6 | import torch.utils.data as data_utils 7 | import torchvision.transforms as transforms 8 | 9 | 10 | class PacsDataGrey(data_utils.Dataset): 11 | def __init__(self, path, domain_list=None, mode='train'): 12 | self.path = path 13 | self.domain_list = domain_list 14 | self.mode = mode 15 | 16 | self.to_tensor = transforms.ToTensor() 17 | self.to_grey = transforms.Grayscale(num_output_channels=3) 18 | 19 | self.train_data, self.train_labels, 
self.train_domain = self.get_data() 20 | 21 | def get_imgs_and_labels(self, file_name, domain_path): 22 | 23 | with open(file_name, 'r') as data: 24 | img_paths = [] 25 | labels = [] 26 | for line in data: 27 | p = line.split() 28 | img_paths.append(domain_path + p[0]) 29 | labels.append(p[1]) 30 | 31 | img_list = [] 32 | label_list = [] 33 | 34 | for i, img_path in enumerate(img_paths): 35 | with open(img_path, 'rb') as f: 36 | with Image.open(f) as img: 37 | img = img.convert('RGB') 38 | 39 | img_list.append(self.to_tensor(self.to_grey(img))) 40 | label_list.append((np.float(labels[i]) - 1.0)) 41 | 42 | return torch.stack(img_list), torch.Tensor(np.array(label_list)) 43 | 44 | def get_data(self): 45 | imgs_per_domain_list = [] 46 | labels_per_domain_list = [] 47 | domain_per_domain_list = [] 48 | 49 | for i, domain in enumerate(self.domain_list): 50 | 51 | domain_path = self.path + domain 52 | 53 | if self.mode == 'train': 54 | file_name = domain_path + '_train_kfold.txt' 55 | elif self.mode == 'val': 56 | file_name = domain_path + '_crossval_kfold.txt' 57 | elif self.mode == 'test': 58 | file_name = domain_path + '_test_kfold.txt' 59 | else: 60 | print('unkown mode found') 61 | 62 | imgs, labels = self.get_imgs_and_labels(file_name, self.path) 63 | domain_labels = torch.zeros(labels.size()) + i 64 | 65 | # append to final list 66 | imgs_per_domain_list.append(imgs) 67 | labels_per_domain_list.append(labels) 68 | domain_per_domain_list.append(domain_labels) 69 | 70 | # One last cat 71 | train_imgs = torch.cat(imgs_per_domain_list).squeeze() 72 | train_labels = torch.cat(labels_per_domain_list).long() 73 | train_domains = torch.cat(domain_per_domain_list).long() 74 | 75 | # Convert to onehot 76 | y = torch.eye(7) 77 | train_labels = y[train_labels] 78 | 79 | d = torch.eye(4) 80 | train_domains = d[train_domains] 81 | 82 | return train_imgs, train_labels, train_domains 83 | 84 | def __len__(self): 85 | return len(self.train_labels) 86 | 87 | def __getitem__(self, 
index): 88 | x = self.train_data[index] 89 | y = self.train_labels[index] 90 | d = self.train_domain[index] 91 | 92 | return x, y, d 93 | 94 | 95 | if __name__ == "__main__": 96 | from torchvision.utils import save_image 97 | 98 | seed = 0 99 | torch.manual_seed(seed) 100 | torch.backends.cudnn.benchmark = False 101 | np.random.seed(seed) 102 | 103 | # Train 104 | domain_list_train = ['art_painting', 'cartoon', 'photo', 'sketch'] 105 | 106 | train_loader = data_utils.DataLoader( 107 | PacsDataGrey('./kfold/', domain_list=domain_list_train, mode='train'), 108 | batch_size=128, 109 | shuffle=True) 110 | 111 | y_array = np.zeros(7) 112 | d_array = np.zeros(4) 113 | 114 | for i, (x, y, d) in enumerate(train_loader): 115 | 116 | y_array += y.sum(dim=0).cpu().numpy() 117 | d_array += d.sum(dim=0).cpu().numpy() 118 | 119 | if i == 0: 120 | n = min(x.size(0), 36) 121 | save_image(x[:n].cpu(), 122 | 'pacs_train_grey.png', nrow=6) 123 | 124 | print(y_array, d_array) 125 | print('\n') 126 | 127 | train_loader = data_utils.DataLoader( 128 | PacsDataGrey('./kfold/', domain_list=domain_list_train, mode='val'), 129 | batch_size=128, 130 | shuffle=True) 131 | 132 | y_array = np.zeros(7) 133 | d_array = np.zeros(4) 134 | 135 | for i, (x, y, d) in enumerate(train_loader): 136 | 137 | y_array += y.sum(dim=0).cpu().numpy() 138 | d_array += d.sum(dim=0).cpu().numpy() 139 | 140 | if i == 0: 141 | n = min(x.size(0), 36) 142 | save_image(x[:n].cpu(), 143 | 'pacs_val_grey.png', nrow=6) 144 | 145 | print(y_array, d_array) 146 | print('\n') 147 | 148 | train_loader = data_utils.DataLoader( 149 | PacsDataGrey('./kfold/', domain_list=domain_list_train, mode='test'), 150 | batch_size=128, 151 | shuffle=True) 152 | 153 | y_array = np.zeros(7) 154 | d_array = np.zeros(4) 155 | 156 | for i, (x, y, d) in enumerate(train_loader): 157 | 158 | y_array += y.sum(dim=0).cpu().numpy() 159 | d_array += d.sum(dim=0).cpu().numpy() 160 | 161 | if i == 0: 162 | n = min(x.size(0), 36) 163 | 
save_image(x[:n].cpu(), 164 | 'pacs_test_grey.png', nrow=6) 165 | 166 | print(y_array, d_array) 167 | print('\n') 168 | -------------------------------------------------------------------------------- /PACS/pacs_data_loader_norm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from PIL import Image 4 | 5 | import torch 6 | import torch.utils.data as data_utils 7 | import torchvision.transforms as transforms 8 | 9 | 10 | class PacsData(data_utils.Dataset): 11 | def __init__(self, path, domain_list=None, mode='train'): 12 | self.path = path 13 | self.domain_list = domain_list 14 | self.mode = mode 15 | 16 | self.to_tensor = transforms.ToTensor() 17 | self.normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 18 | 19 | self.train_data, self.train_labels, self.train_domain = self.get_data() 20 | 21 | def get_imgs_and_labels(self, file_name, domain_path): 22 | 23 | with open(file_name, 'r') as data: 24 | img_paths = [] 25 | labels = [] 26 | for line in data: 27 | p = line.split() 28 | img_paths.append(domain_path + p[0]) 29 | labels.append(p[1]) 30 | 31 | img_list = [] 32 | label_list = [] 33 | 34 | for i, img_path in enumerate(img_paths): 35 | with open(img_path, 'rb') as f: 36 | with Image.open(f) as img: 37 | img = img.convert('RGB') 38 | 39 | img_list.append(self.normalize(self.to_tensor(img))) 40 | # img_list.append((self.to_tensor(img))) 41 | label_list.append((np.float(labels[i]) - 1.0)) 42 | 43 | return torch.stack(img_list), torch.Tensor(np.array(label_list)) 44 | 45 | def get_data(self): 46 | imgs_per_domain_list = [] 47 | labels_per_domain_list = [] 48 | domain_per_domain_list = [] 49 | 50 | for i, domain in enumerate(self.domain_list): 51 | 52 | domain_path = self.path + domain 53 | 54 | if self.mode == 'train': 55 | file_name = domain_path + '_train_kfold.txt' 56 | elif self.mode == 'val': 57 | file_name = domain_path + '_crossval_kfold.txt' 58 | elif self.mode == 
'test': 59 | file_name = domain_path + '_test_kfold.txt' 60 | else: 61 | print('unkown mode found') 62 | 63 | imgs, labels = self.get_imgs_and_labels(file_name, self.path) 64 | domain_labels = torch.zeros(labels.size()) + i 65 | 66 | # append to final list 67 | imgs_per_domain_list.append(imgs) 68 | labels_per_domain_list.append(labels) 69 | domain_per_domain_list.append(domain_labels) 70 | 71 | # One last cat 72 | train_imgs = torch.cat(imgs_per_domain_list).squeeze() 73 | train_labels = torch.cat(labels_per_domain_list).long() 74 | train_domains = torch.cat(domain_per_domain_list).long() 75 | 76 | # Convert to onehot 77 | y = torch.eye(7) 78 | train_labels = y[train_labels] 79 | 80 | d = torch.eye(4) 81 | train_domains = d[train_domains] 82 | 83 | return train_imgs, train_labels, train_domains 84 | 85 | def __len__(self): 86 | return len(self.train_labels) 87 | 88 | def __getitem__(self, index): 89 | x = self.train_data[index] 90 | y = self.train_labels[index] 91 | d = self.train_domain[index] 92 | 93 | return x, y, d 94 | 95 | 96 | if __name__ == "__main__": 97 | from torchvision.utils import save_image 98 | 99 | seed = 0 100 | torch.manual_seed(seed) 101 | torch.backends.cudnn.benchmark = False 102 | np.random.seed(seed) 103 | 104 | # Train 105 | domain_list_train = ['art_painting', 'cartoon', 'photo', 'sketch'] 106 | 107 | train_loader = data_utils.DataLoader( 108 | PacsData('./kfold/', domain_list=domain_list_train, mode='train'), 109 | batch_size=128, 110 | shuffle=True) 111 | 112 | y_array = np.zeros(7) 113 | d_array = np.zeros(4) 114 | 115 | for i, (x, y, d) in enumerate(train_loader): 116 | 117 | y_array += y.sum(dim=0).cpu().numpy() 118 | d_array += d.sum(dim=0).cpu().numpy() 119 | 120 | if i == 0: 121 | n = min(x.size(0), 36) 122 | save_image(x[:n].cpu(), 123 | 'pacs_train.png', nrow=6) 124 | 125 | print(y_array, d_array) 126 | print('\n') 127 | 128 | train_loader = data_utils.DataLoader( 129 | PacsData('./kfold/', domain_list=domain_list_train, 
mode='val'), 130 | batch_size=128, 131 | shuffle=True) 132 | 133 | y_array = np.zeros(7) 134 | d_array = np.zeros(4) 135 | 136 | for i, (x, y, d) in enumerate(train_loader): 137 | 138 | y_array += y.sum(dim=0).cpu().numpy() 139 | d_array += d.sum(dim=0).cpu().numpy() 140 | 141 | if i == 0: 142 | n = min(x.size(0), 36) 143 | save_image(x[:n].cpu(), 144 | 'pacs_val.png', nrow=6) 145 | 146 | print(y_array, d_array) 147 | print('\n') 148 | 149 | train_loader = data_utils.DataLoader( 150 | PacsData('./kfold/', domain_list=domain_list_train, mode='test'), 151 | batch_size=128, 152 | shuffle=True) 153 | 154 | y_array = np.zeros(7) 155 | d_array = np.zeros(4) 156 | 157 | for i, (x, y, d) in enumerate(train_loader): 158 | 159 | y_array += y.sum(dim=0).cpu().numpy() 160 | d_array += d.sum(dim=0).cpu().numpy() 161 | 162 | if i == 0: 163 | n = min(x.size(0), 36) 164 | save_image(x[:n].cpu(), 165 | 'pacs_test.png', nrow=6) 166 | 167 | print(y_array, d_array) 168 | print('\n') 169 | -------------------------------------------------------------------------------- /PACS/pacs_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AMLab-Amsterdam/DataAugmentationInterventions/78ce67174db487e9697b9a842e69818305bb41ef/PACS/pacs_example.png -------------------------------------------------------------------------------- /PACS/pacs_examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AMLab-Amsterdam/DataAugmentationInterventions/78ce67174db487e9697b9a842e69818305bb41ef/PACS/pacs_examples.png -------------------------------------------------------------------------------- /PACS/pretrained/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AMLab-Amsterdam/DataAugmentationInterventions/78ce67174db487e9697b9a842e69818305bb41ef/PACS/pretrained/__init__.py 
-------------------------------------------------------------------------------- /PACS/samples_for_paper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from PIL import Image 4 | 5 | import torch 6 | import torch.utils.data as data_utils 7 | import torchvision.transforms as transforms 8 | from torchvision.utils import save_image 9 | 10 | to_tensor = transforms.ToTensor() 11 | 12 | img_paths = ['./kfold/art_painting/dog/pic_001.jpg', 13 | './kfold/cartoon/dog/pic_001.jpg', 14 | './kfold/photo/dog/056_0002.jpg', 15 | './kfold/sketch/dog/5281.png', 16 | './kfold/art_painting/elephant/pic_001.jpg', 17 | './kfold/cartoon/elephant/pic_001.jpg', 18 | './kfold/photo/elephant/064_0001.jpg', 19 | './kfold/sketch/elephant/5921.png', 20 | './kfold/art_painting/giraffe/pic_001.jpg', 21 | './kfold/cartoon/giraffe/pic_001.jpg', 22 | './kfold/photo/giraffe/084_0001.jpg', 23 | './kfold/sketch/giraffe/7361.png', 24 | './kfold/art_painting/guitar/pic_001.jpg', 25 | './kfold/cartoon/guitar/pic_001.jpg', 26 | './kfold/photo/guitar/063_0001.jpg', 27 | './kfold/sketch/guitar/7601.png', 28 | ] 29 | 30 | img_list = [] 31 | for i, img_path in enumerate(img_paths): 32 | with open(img_path, 'rb') as f: 33 | with Image.open(f) as img: 34 | img = img.convert('RGB') 35 | 36 | img_list.append(to_tensor(img)) 37 | 38 | 39 | 40 | save_image(img_list, 41 | 'pacs_example.png', nrow=4) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Selecting Data Augmentation for Simulating Interventions 2 | ================================================ 3 | by Maximilian Ilse (), Jakub M. Tomczak and Patrick Forré 4 | 5 | Overview 6 | -------- 7 | PyTorch implementation of our paper "Selecting Data Augmentation for Simulating Interventions": 8 | * Ilse, M., Tomczak, J. M., & Forré, P. (2020). 
Selecting Data Augmentation for Simulating Interventions. https://arxiv.org/abs/2005.01856 9 | 10 | Used modules 11 | ------------ 12 | - Python 3.6 13 | - PyTorch 1.0.1 14 | 15 | Datasets 16 | -------- 17 | - MNIST: http://yann.lecun.com/exdb/mnist/ 18 | - PACS: https://domaingeneralization.github.io/ 19 | 20 | Pre-trained AlexNet 21 | ------------------- 22 | To reproduce our results on the PACS dataset, please use the pre-trained AlexNet weights available at: https://drive.google.com/file/d/1wUJTH1Joq2KAgrUDeKJghP1Wf7Q9w4z-/view?usp=sharing 23 | 24 | Story behind the paper 25 | ---------------------- 26 | Everyone who works with medical imaging data eventually comes across the following problem: appearance variability. This variability is usually caused by the equipment used to generate medical imaging data; e.g., CT scanners from different vendors generate images with different intensity patterns. If we train a CNN on data from a single scanner, it is likely to overfit to the specific intensity pattern of that scanner. As a result, it is likely to fail to generalize to data from a different scanner. 27 | 28 | In late 2018, motivated by the appearance variability in medical imaging data described above, we started to work on the problem of domain generalization (learning invariant representations). In domain generalization, one tries to find a representation that generalizes across different environments, called domains, each with a different shift of the input. 29 | 30 | This eventually led to a model that we called the Domain Invariant Variational Autoencoder (DIVA, https://arxiv.org/abs/1905.10427, thanks to my co-authors!). While the results of DIVA are promising, there were a couple of experiments that didn’t make it into the paper since the performance of DIVA didn’t match that of a simple baseline CNN. For a while, we thought this was probably due to optimization issues, etc. During 2019, we realized that we had a very poor understanding of the problem itself. 
31 | 32 | Questions and Issues 33 | -------------------- 34 | 35 | If you find any bugs or have any questions about this code, please contact Maximilian Ilse. We cannot guarantee any support for this software. 36 | 37 | 38 | Citation 39 | -------------------- 40 | 41 | Please cite our paper if you use this code in your research: 42 | ``` 43 | @article{ilse_selecting_2020, 44 | title = {Selecting {Data} {Augmentation} for {Simulating} {Interventions}}, 45 | url = {http://arxiv.org/abs/2005.01856}, 46 | urldate = {2020-05-06}, 47 | journal = {arXiv:2005.01856 [cs, stat]}, 48 | author = {Ilse, Maximilian and Tomczak, Jakub M. and Forré, Patrick}, 49 | month = may, 50 | year = {2020}, 51 | note = {arXiv: 2005.01856} 52 | } 53 | ``` 54 | 55 | Acknowledgments 56 | -------------------- 57 | 58 | The work conducted by Maximilian Ilse was funded by the Nederlandse Organisatie voor Wetenschappelijk Onderzoek (Grant DLMedIa: Deep Learning for Medical Image Analysis). 59 | -------------------------------------------------------------------------------- /colored_mnist/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AMLab-Amsterdam/DataAugmentationInterventions/78ce67174db487e9697b9a842e69818305bb41ef/colored_mnist/__init__.py -------------------------------------------------------------------------------- /colored_mnist/evaulate_domain_experiments.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import re 4 | import numpy as np 5 | 6 | os.chdir('./') 7 | 8 | brightness_train = [] 9 | brightness_val = [] 10 | contrast_train = [] 11 | contrast_val = [] 12 | saturation_train = [] 13 | saturation_val = [] 14 | hue_train = [] 15 | hue_val = [] 16 | rotation_train = [] 17 | rotation_val = [] 18 | translate_train = [] 19 | translate_val = [] 20 | scale_train = [] 21 | scale_val = [] 22 | shear_train = [] 23 | shear_val = [] 24 | vflip_train = [] 25 |
vflip_val = [] 26 | hflip_train = [] 27 | hflip_val = [] 28 | # Route every *.txt file in the working directory by the augmentation name in its filename; the first line of each file must contain exactly two numbers, read as (train, val), or the two-name unpacking below raises ValueError. 29 | for file in glob.glob("*.txt"): 30 | if 'brightness' in file: 31 | with open(file) as f: 32 | content = f.readlines() 33 | train, val = re.findall(r"[-+]?\d*\.\d+|\d+", content[0]) 34 | train, val = float(train), float(val) 35 | brightness_train.append(train) 36 | brightness_val.append(val) 37 | 38 | if 'contrast' in file: 39 | with open(file) as f: 40 | content = f.readlines() 41 | train, val = re.findall(r"[-+]?\d*\.\d+|\d+", content[0]) 42 | train, val = float(train), float(val) 43 | contrast_train.append(train) 44 | contrast_val.append(val) 45 | 46 | if 'saturation' in file: 47 | with open(file) as f: 48 | content = f.readlines() 49 | train, val = re.findall(r"[-+]?\d*\.\d+|\d+", content[0]) 50 | train, val = float(train), float(val) 51 | saturation_train.append(train) 52 | saturation_val.append(val) 53 | 54 | if 'hue' in file: 55 | with open(file) as f: 56 | content = f.readlines() 57 | train, val = re.findall(r"[-+]?\d*\.\d+|\d+", content[0]) 58 | train, val = float(train), float(val) 59 | hue_train.append(train) 60 | hue_val.append(val) 61 | 62 | if 'rotation' in file: 63 | with open(file) as f: 64 | content = f.readlines() 65 | train, val = re.findall(r"[-+]?\d*\.\d+|\d+", content[0]) 66 | train, val = float(train), float(val) 67 | rotation_train.append(train) 68 | rotation_val.append(val) 69 | 70 | if 'translate' in file: 71 | with open(file) as f: 72 | content = f.readlines() 73 | train, val = re.findall(r"[-+]?\d*\.\d+|\d+", content[0]) 74 | train, val = float(train), float(val) 75 | translate_train.append(train) 76 | translate_val.append(val) 77 | 78 | if 'scale' in file: 79 | with open(file) as f: 80 | content = f.readlines() 81 | train, val = re.findall(r"[-+]?\d*\.\d+|\d+", content[0]) 82 | train, val = float(train), float(val) 83 | scale_train.append(train) 84 | scale_val.append(val) 85 | 86 | if 'shear' in file: 87 | with open(file) as f: 88 | content = f.readlines() 89 | train, val
= re.findall(r"[-+]?\d*\.\d+|\d+", content[0]) 90 | train, val = float(train), float(val) 91 | shear_train.append(train) 92 | shear_val.append(val) 93 | 94 | if 'vflip' in file: 95 | with open(file) as f: 96 | content = f.readlines() 97 | train, val = re.findall(r"[-+]?\d*\.\d+|\d+", content[0]) 98 | train, val = float(train), float(val) 99 | vflip_train.append(train) 100 | vflip_val.append(val) 101 | 102 | if 'hflip' in file: 103 | with open(file) as f: 104 | content = f.readlines() 105 | train, val = re.findall(r"[-+]?\d*\.\d+|\d+", content[0]) 106 | train, val = float(train), float(val) 107 | hflip_train.append(train) 108 | hflip_val.append(val) 109 | # Population std (np.std default ddof=0) over all files matched per augmentation; an augmentation with no matching files yields nan. 110 | brightness_train_std = np.std(np.array(brightness_train)) 111 | brightness_val_std = np.std(np.array(brightness_val)) 112 | contrast_train_std = np.std(np.array(contrast_train)) 113 | contrast_val_std = np.std(np.array(contrast_val)) 114 | saturation_train_std = np.std(np.array(saturation_train)) 115 | saturation_val_std = np.std(np.array(saturation_val)) 116 | hue_train_std = np.std(np.array(hue_train)) 117 | hue_val_std = np.std(np.array(hue_val)) 118 | rotation_train_std = np.std(np.array(rotation_train)) 119 | rotation_val_std = np.std(np.array(rotation_val)) 120 | translate_train_std = np.std(np.array(translate_train)) 121 | translate_val_std = np.std(np.array(translate_val)) 122 | scale_train_std = np.std(np.array(scale_train)) 123 | scale_val_std = np.std(np.array(scale_val)) 124 | shear_train_std = np.std(np.array(shear_train)) 125 | shear_val_std = np.std(np.array(shear_val)) 126 | vflip_train_std = np.std(np.array(vflip_train)) 127 | vflip_val_std = np.std(np.array(vflip_val)) 128 | hflip_train_std = np.std(np.array(hflip_train)) 129 | hflip_val_std = np.std(np.array(hflip_val)) 130 | 131 | # Means are computed after the stds because the lists below are rebound to their means. 132 | brightness_train = np.mean(np.array(brightness_train)) 133 | brightness_val = np.mean(np.array(brightness_val)) 134 | contrast_train = np.mean(np.array(contrast_train)) 135 | contrast_val =
np.mean(np.array(contrast_val)) 136 | saturation_train = np.mean(np.array(saturation_train)) 137 | saturation_val = np.mean(np.array(saturation_val)) 138 | hue_train = np.mean(np.array(hue_train)) 139 | hue_val = np.mean(np.array(hue_val)) 140 | rotation_train = np.mean(np.array(rotation_train)) 141 | rotation_val = np.mean(np.array(rotation_val)) 142 | translate_train = np.mean(np.array(translate_train)) 143 | translate_val = np.mean(np.array(translate_val)) 144 | scale_train = np.mean(np.array(scale_train)) 145 | scale_val = np.mean(np.array(scale_val)) 146 | shear_train = np.mean(np.array(shear_train)) 147 | shear_val = np.mean(np.array(shear_val)) 148 | vflip_train = np.mean(np.array(vflip_train)) 149 | vflip_val = np.mean(np.array(vflip_val)) 150 | hflip_train = np.mean(np.array(hflip_train)) 151 | hflip_val = np.mean(np.array(hflip_val)) 152 | # Output is unlabeled: all means (train then val per augmentation, in the order below), then all stds in the same order. 153 | print('mean') 154 | print(brightness_train) 155 | print(brightness_val) 156 | print(contrast_train) 157 | print(contrast_val) 158 | print(saturation_train) 159 | print(saturation_val) 160 | print(hue_train) 161 | print(hue_val) 162 | print(rotation_train) 163 | print(rotation_val) 164 | print(translate_train) 165 | print(translate_val) 166 | print(scale_train) 167 | print(scale_val) 168 | print(shear_train) 169 | print(shear_val) 170 | print(vflip_train) 171 | print(vflip_val) 172 | print(hflip_train) 173 | print(hflip_val) 174 | print('std') 175 | print(brightness_train_std) 176 | print(brightness_val_std) 177 | print(contrast_train_std) 178 | print(contrast_val_std) 179 | print(saturation_train_std) 180 | print(saturation_val_std) 181 | print(hue_train_std) 182 | print(hue_val_std) 183 | print(rotation_train_std) 184 | print(rotation_val_std) 185 | print(translate_train_std) 186 | print(translate_val_std) 187 | print(scale_train_std) 188 | print(scale_val_std) 189 | print(shear_train_std) 190 | print(shear_val_std) 191 | print(vflip_train_std) 192 | print(vflip_val_std) 193 | print(hflip_train_std) 194 |
print(hflip_val_std) 195 | 196 | -------------------------------------------------------------------------------- /colored_mnist/jobs_da.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd ~/AdditionalDriveA/DeployedProjects/TwoLatentSpacesVAE/paper_experiments/colored_mnist/ 3 | echo "Starting" 4 | #python choose_da_with_domain_classifier.py --seed 0 --da brightness 5 | python choose_da_with_domain_classifier.py --seed 0 --da contrast 6 | python choose_da_with_domain_classifier.py --seed 0 --da saturation 7 | #python choose_da_with_domain_classifier.py --seed 0 --da hue 8 | #python choose_da_with_domain_classifier.py --seed 0 --da rotation 9 | #python choose_da_with_domain_classifier.py --seed 0 --da translate 10 | #python choose_da_with_domain_classifier.py --seed 0 --da scale 11 | #python choose_da_with_domain_classifier.py --seed 0 --da shear 12 | #python choose_da_with_domain_classifier.py --seed 0 --da hflip 13 | #python choose_da_with_domain_classifier.py --seed 0 --da vflip 14 | #python choose_da_with_domain_classifier.py --seed 0 --da none 15 | 16 | #python choose_da_with_domain_classifier.py --seed 1 --da brightness 17 | python choose_da_with_domain_classifier.py --seed 1 --da contrast 18 | python choose_da_with_domain_classifier.py --seed 1 --da saturation 19 | #python choose_da_with_domain_classifier.py --seed 1 --da hue 20 | #python choose_da_with_domain_classifier.py --seed 1 --da rotation 21 | #python choose_da_with_domain_classifier.py --seed 1 --da translate 22 | #python choose_da_with_domain_classifier.py --seed 1 --da scale 23 | #python choose_da_with_domain_classifier.py --seed 1 --da shear 24 | #python choose_da_with_domain_classifier.py --seed 1 --da hflip 25 | #python choose_da_with_domain_classifier.py --seed 1 --da vflip 26 | #python choose_da_with_domain_classifier.py --seed 1 --da none 27 | 28 | #python choose_da_with_domain_classifier.py --seed 2 --da brightness 29 | python 
choose_da_with_domain_classifier.py --seed 2 --da contrast 30 | python choose_da_with_domain_classifier.py --seed 2 --da saturation 31 | #python choose_da_with_domain_classifier.py --seed 2 --da hue 32 | #python choose_da_with_domain_classifier.py --seed 2 --da rotation 33 | #python choose_da_with_domain_classifier.py --seed 2 --da translate 34 | #python choose_da_with_domain_classifier.py --seed 2 --da scale 35 | #python choose_da_with_domain_classifier.py --seed 2 --da shear 36 | #python choose_da_with_domain_classifier.py --seed 2 --da hflip 37 | #python choose_da_with_domain_classifier.py --seed 2 --da vflip 38 | #python choose_da_with_domain_classifier.py --seed 2 --da none 39 | 40 | #python choose_da_with_domain_classifier.py --seed 3 --da brightness 41 | python choose_da_with_domain_classifier.py --seed 3 --da contrast 42 | python choose_da_with_domain_classifier.py --seed 3 --da saturation 43 | #python choose_da_with_domain_classifier.py --seed 3 --da hue 44 | #python choose_da_with_domain_classifier.py --seed 3 --da rotation 45 | #python choose_da_with_domain_classifier.py --seed 3 --da translate 46 | #python choose_da_with_domain_classifier.py --seed 3 --da scale 47 | #python choose_da_with_domain_classifier.py --seed 3 --da shear 48 | #python choose_da_with_domain_classifier.py --seed 3 --da hflip 49 | #python choose_da_with_domain_classifier.py --seed 3 --da vflip 50 | #python choose_da_with_domain_classifier.py --seed 3 --da none 51 | 52 | #python choose_da_with_domain_classifier.py --seed 4 --da brightness 53 | python choose_da_with_domain_classifier.py --seed 4 --da contrast 54 | python choose_da_with_domain_classifier.py --seed 4 --da saturation 55 | #python choose_da_with_domain_classifier.py --seed 4 --da hue 56 | #python choose_da_with_domain_classifier.py --seed 4 --da rotation 57 | #python choose_da_with_domain_classifier.py --seed 4 --da translate 58 | #python choose_da_with_domain_classifier.py --seed 4 --da scale 59 | #python 
choose_da_with_domain_classifier.py --seed 4 --da shear 60 | #python choose_da_with_domain_classifier.py --seed 4 --da hflip 61 | #python choose_da_with_domain_classifier.py --seed 4 --da vflip 62 | #python choose_da_with_domain_classifier.py --seed 4 --da none 63 | 64 | 65 | echo "Done" 66 | -------------------------------------------------------------------------------- /colored_mnist/main_random_baseline.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import argparse 9 | import numpy as np 10 | import torch 11 | from torchvision import datasets 12 | from torch import nn, optim, autograd 13 | from torch.utils.data import TensorDataset, DataLoader 14 | import torch.utils.data 15 | import torch.nn.functional as F 16 | from torchvision.utils import save_image 17 | 18 | 19 | # Define and instantiate the model 20 | class MLP(nn.Module): 21 | def __init__(self): 22 | super(MLP, self).__init__() 23 | 24 | lin1 = nn.Linear(2 * 14 * 14, args.hidden_dim) 25 | lin2 = nn.Linear(args.hidden_dim, args.hidden_dim) 26 | lin3 = nn.Linear(args.hidden_dim, 1) 27 | for lin in [lin1, lin2, lin3]: 28 | nn.init.xavier_uniform_(lin.weight) 29 | nn.init.zeros_(lin.bias) 30 | self._main = nn.Sequential(lin1, nn.ReLU(True), lin2, nn.ReLU(True), lin3) 31 | 32 | def forward(self, input): 33 | out = input.view(input.shape[0], 2 * 14 * 14) 34 | out = self._main(out) 35 | return F.sigmoid(out) 36 | 37 | 38 | # Build environments 39 | def torch_bernoulli(p, size): 40 | return (torch.rand(size) < p).float() 41 | 42 | 43 | def make_environment(images, labels, e): 44 | def torch_xor(a, b): 45 | return (a - b).abs() # Assumes both inputs are either 0 or 1 46 | 47 | # 2x subsample for computational convenience 48 | images = 
images.reshape((-1, 28, 28))[:, ::2, ::2] 49 | # Assign a binary label based on the digit; flip label with probability 0.25 50 | labels = (labels < 5).float() 51 | labels = torch_xor(labels, torch_bernoulli(0.25, len(labels))) 52 | # Assign a color based on the label; flip the color with probability e 53 | colors = torch_xor(labels, torch_bernoulli(e, len(labels))) 54 | # Apply the color to the image by zeroing out the other color channel 55 | images = torch.stack([images, images], dim=1) 56 | images[torch.tensor(range(len(images))), (1 - colors).long(), :, :] *= 0 57 | return { 58 | 'images': (images.float() / 255.), 59 | 'labels': labels[:, None]} 60 | 61 | 62 | def train(args, model, device, train_loader, optimizer, epoch): 63 | model.train() 64 | for batch_idx, (data, target, _) in enumerate(train_loader): 65 | data, target = data.to(device), target.to(device) 66 | 67 | optimizer.zero_grad() 68 | output = model(data) 69 | loss = nn.BCELoss()(output, target) 70 | loss.backward() 71 | optimizer.step() 72 | 73 | # save_image(data.cpu(), 74 | # 'baseline.png', nrow=1) 75 | 76 | 77 | def test(args, model, device, test_loader): 78 | model.eval() 79 | test_loss = 0 80 | correct = 0 81 | with torch.no_grad(): 82 | for data, target, _ in test_loader: 83 | data, target = data.to(device), target.to(device) 84 | 85 | output = model(data) 86 | test_loss += nn.BCELoss()(output, target).item() # sum up batch loss 87 | pred = output >= 0.5 88 | correct += pred.eq(target.view_as(pred)).sum().item() 89 | 90 | test_loss /= len(test_loader.dataset) 91 | 92 | return test_loss, 100. 
* correct / len(test_loader.dataset) 93 | 94 | 95 | parser = argparse.ArgumentParser(description='Colored MNIST') 96 | parser.add_argument('--hidden_dim', type=int, default=256) 97 | # parser.add_argument('--l2_regularizer_weight', type=float,default=0.001) 98 | parser.add_argument('--lr', type=float, default=0.001) 99 | parser.add_argument('--epochs', type=int, default=100) 100 | args = parser.parse_args() 101 | 102 | # torch.manual_seed(0) 103 | # torch.backends.cudnn.benchmark = False 104 | # np.random.seed(0) 105 | 106 | # Load MNIST, make train/val splits, and shuffle train set examples 107 | 108 | mnist = datasets.MNIST('~/datasets/mnist', train=True, download=True) 109 | mnist_train = (mnist.data[:50000], mnist.targets[:50000]) 110 | mnist_val = (mnist.data[50000:], mnist.targets[50000:]) 111 | 112 | rng_state = np.random.get_state() 113 | np.random.shuffle(mnist_train[0].numpy()) 114 | np.random.set_state(rng_state) 115 | np.random.shuffle(mnist_train[1].numpy()) 116 | 117 | 118 | envs = [ 119 | make_environment(mnist_train[0][::2], mnist_train[1][::2], 0.2), 120 | make_environment(mnist_train[0][1::2], mnist_train[1][1::2], 0.1), 121 | make_environment(mnist_val[0], mnist_val[1], 0.9)] 122 | 123 | # Put into data loader 124 | device = torch.device("cuda") 125 | kwargs = {'num_workers': 1, 'pin_memory': False} 126 | 127 | env0 = TensorDataset(envs[0]['images'], envs[0]['labels'], torch.zeros_like(envs[0]['labels'])) 128 | env1 = TensorDataset(envs[1]['images'], envs[1]['labels'], torch.zeros_like(envs[1]['labels'])+1.) 
129 | env2 = TensorDataset(envs[2]['images'], envs[2]['labels'], torch.zeros_like(envs[2]['labels'])) 130 | cmnist = torch.utils.data.ConcatDataset([env0, env1]) 131 | train_loader = torch.utils.data.DataLoader(cmnist, 132 | batch_size=100, 133 | shuffle=True, **kwargs) 134 | test_loader = torch.utils.data.DataLoader(env2, 135 | batch_size=100, 136 | shuffle=False, **kwargs) 137 | 138 | # Init model and optimizer 139 | model = MLP().to(device) 140 | optimizer = optim.Adam(model.parameters(), lr=0.01) 141 | 142 | for epoch in range(1, args.epochs + 1): 143 | print('\n Epoch: ' + str(epoch)) 144 | train(args, model, device, train_loader, optimizer, epoch) 145 | train_loss, train_acc = test(args, model, device, train_loader) 146 | test_loss, test_acc = test(args, model, device, test_loader) 147 | 148 | print(epoch, train_loss, train_acc, test_loss, test_acc) 149 | 150 | # # Save best 151 | # if val_acc >= best_val_acc: 152 | # best_val_acc = val_acc 153 | # 154 | # torch.save(model, model_name + '.model') 155 | # torch.save(args, model_name + '.config') 156 | # early_stopping = 0 157 | # 158 | # early_stopping += 1 159 | # 160 | # if early_stopping >= args.early_stop_after: 161 | # break 162 | -------------------------------------------------------------------------------- /colored_mnist/with_all_da.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd ~/AdditionalDriveA/DeployedProjects/TwoLatentSpacesVAE/paper_experiments/colored_mnist/ 3 | echo "Starting" 4 | python main_with_all_da.py --seed 0 5 | python main_with_all_da.py --seed 1 6 | python main_with_all_da.py --seed 2 7 | python main_with_all_da.py --seed 3 8 | python main_with_all_da.py --seed 4 9 | python main_with_all_da.py --seed 5 10 | python main_with_all_da.py --seed 6 11 | python main_with_all_da.py --seed 7 12 | python main_with_all_da.py --seed 8 13 | python main_with_all_da.py --seed 9 14 | 15 | 16 | echo "Done" 17 | 
-------------------------------------------------------------------------------- /colored_mnist/with_chosen_da.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd ~/AdditionalDriveA/DeployedProjects/TwoLatentSpacesVAE/paper_experiments/colored_mnist/ 3 | echo "Starting" 4 | python main_with_chosen_da.py --seed 0 5 | python main_with_chosen_da.py --seed 1 6 | python main_with_chosen_da.py --seed 2 7 | python main_with_chosen_da.py --seed 3 8 | python main_with_chosen_da.py --seed 4 9 | python main_with_chosen_da.py --seed 5 10 | python main_with_chosen_da.py --seed 6 11 | python main_with_chosen_da.py --seed 7 12 | python main_with_chosen_da.py --seed 8 13 | python main_with_chosen_da.py --seed 9 14 | 15 | 16 | echo "Done" 17 | -------------------------------------------------------------------------------- /rotated_MNIST/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AMLab-Amsterdam/DataAugmentationInterventions/78ce67174db487e9697b9a842e69818305bb41ef/rotated_MNIST/__init__.py -------------------------------------------------------------------------------- /rotated_MNIST/augmentations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AMLab-Amsterdam/DataAugmentationInterventions/78ce67174db487e9697b9a842e69818305bb41ef/rotated_MNIST/augmentations/__init__.py -------------------------------------------------------------------------------- /rotated_MNIST/augmentations/experiment_augmentations_test_0.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, "../../../") 3 | 4 | import argparse 5 | import torch 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | import torch.utils.data as data_utils 9 | 10 | import numpy as np 11 | 12 | from 
paper_experiments.rotated_MNIST.mnist_loader_shifted_label_distribution_rotate import MnistRotatedDist 13 | from paper_experiments.rotated_MNIST.mnist_loader_shifted_label_distribution_flip import MnistRotatedDistFlip 14 | 15 | from paper_experiments.rotated_MNIST.mnist_loader import MnistRotated 16 | from paper_experiments.rotated_MNIST.augmentations.model_baseline import Net 17 | 18 | 19 | def train(args, model, device, train_loader, optimizer, epoch): 20 | model.train() 21 | for batch_idx, (data, target, _) in enumerate(train_loader): 22 | data, target = data.to(device), target.to(device) 23 | _, target = target.max(dim=1) 24 | 25 | optimizer.zero_grad() 26 | output = model(data) 27 | loss = F.nll_loss(output, target) 28 | loss.backward() 29 | optimizer.step() 30 | 31 | 32 | def test(args, model, device, test_loader): 33 | model.eval() 34 | test_loss = 0 35 | correct = 0 36 | with torch.no_grad(): 37 | for data, target, _ in test_loader: 38 | data, target = data.to(device), target.to(device) 39 | _, target = target.max(dim=1) 40 | 41 | output = model(data) 42 | test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss 43 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 44 | correct += pred.eq(target.view_as(pred)).sum().item() 45 | 46 | test_loss /= len(test_loader.dataset) 47 | 48 | return test_loss, 100. 
* correct / len(test_loader.dataset) 49 | 50 | 51 | def main(): 52 | # Training settings 53 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 54 | parser.add_argument('--no-cuda', action='store_true', default=False, 55 | help='disables CUDA training') 56 | parser.add_argument('--seed', type=int, default=0, 57 | help='random seed (default: 1)') 58 | parser.add_argument('--batch-size', type=int, default=128, 59 | help='input batch size for training (default: 64)') 60 | parser.add_argument('--epochs', type=int, default=200, 61 | help='number of epochs to train (default: 10)') 62 | parser.add_argument('--lr', type=float, default=0.001, 63 | help='learning rate (default: 0.01)') 64 | parser.add_argument('--da', type=str, default='rotate', choices=['rotate', 'flip'], 65 | help='type of data augmentation') 66 | 67 | args = parser.parse_args() 68 | use_cuda = not args.no_cuda and torch.cuda.is_available() 69 | 70 | # Set seed 71 | torch.manual_seed(args.seed) 72 | torch.backends.cudnn.benchmark = False 73 | np.random.seed(args.seed) 74 | 75 | device = torch.device("cuda") 76 | kwargs = {'num_workers': 8, 'pin_memory': True} if use_cuda else {} 77 | 78 | # Load supervised training 79 | if args.da == 'rotate': 80 | mnist_30 = MnistRotatedDist('../dataset/', train=True, thetas=[30.0], d_label=0, transform=True) 81 | mnist_60 = MnistRotatedDist('../dataset/', train=True, thetas=[60.0], d_label=1, transform=True) 82 | mnist_90 = MnistRotatedDist('../dataset/', train=True, thetas=[90.0], d_label=2, transform=True) 83 | model_name = 'baseline_test_0_random_rotate_seed_' + str(args.seed) 84 | 85 | elif args.da == 'flip': 86 | mnist_30 = MnistRotatedDistFlip('../dataset/', train=True, thetas=[30.0], d_label=0) 87 | mnist_60 = MnistRotatedDistFlip('../dataset/', train=True, thetas=[60.0], d_label=1) 88 | mnist_90 = MnistRotatedDistFlip('../dataset/', train=True, thetas=[90.0], d_label=2) 89 | model_name = 'baseline_test_0_random_flips_seed_' + str(args.seed) 90 
| 91 | mnist = data_utils.ConcatDataset([mnist_30, mnist_60, mnist_90]) 92 | 93 | train_size = int(0.9 * len(mnist)) 94 | val_size = len(mnist) - train_size 95 | train_dataset, val_dataset = torch.utils.data.random_split(mnist, [train_size, val_size]) 96 | 97 | train_loader = data_utils.DataLoader(train_dataset, 98 | batch_size=args.batch_size, 99 | shuffle=True, **kwargs) 100 | 101 | val_loader = data_utils.DataLoader(val_dataset, 102 | batch_size=args.batch_size, 103 | shuffle=False, **kwargs) 104 | 105 | model = Net().to(device) 106 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 107 | 108 | best_val_acc = 0 109 | 110 | for epoch in range(1, args.epochs + 1): 111 | print('\n Epoch: ' + str(epoch)) 112 | train(args, model, device, train_loader, optimizer, epoch) 113 | val_loss, val_acc = test(args, model, device, val_loader) 114 | 115 | print(epoch, val_loss, val_acc) 116 | 117 | # Save best 118 | if val_acc >= best_val_acc: 119 | best_val_acc = val_acc 120 | 121 | torch.save(model, model_name + '.model') 122 | torch.save(args, model_name + '.config') 123 | 124 | # Test loader 125 | mnist_0 = MnistRotated('../dataset/', train=False, thetas=[0.0], d_label=0) 126 | test_loader = data_utils.DataLoader(mnist_0, 127 | batch_size=args.batch_size, 128 | shuffle=False, **kwargs) 129 | 130 | model = torch.load(model_name + '.model').to(device) 131 | _, test_acc = test(args, model, device, test_loader) 132 | 133 | with open(model_name + '.txt', "w") as text_file: 134 | text_file.write("Test Acc: " + str(test_acc)) 135 | 136 | 137 | if __name__ == '__main__': 138 | main() -------------------------------------------------------------------------------- /rotated_MNIST/augmentations/experiment_augmentations_test_0_all_da.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, "../../../") 3 | 4 | import argparse 5 | import torch 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | 
import torch.utils.data as data_utils 9 | 10 | import numpy as np 11 | 12 | from paper_experiments.rotated_MNIST.mnist_loader_shifted_label_distribution_all_da import MnistAllDaDist 13 | 14 | from paper_experiments.rotated_MNIST.mnist_loader import MnistRotated 15 | from paper_experiments.rotated_MNIST.augmentations.model_baseline import Net 16 | 17 | 18 | def train(args, model, device, train_loader, optimizer, epoch): 19 | model.train() 20 | for batch_idx, (data, target, _) in enumerate(train_loader): 21 | data, target = data.to(device), target.to(device) 22 | _, target = target.max(dim=1) 23 | 24 | optimizer.zero_grad() 25 | output = model(data) 26 | loss = F.nll_loss(output, target) 27 | loss.backward() 28 | optimizer.step() 29 | 30 | 31 | def test(args, model, device, test_loader): 32 | model.eval() 33 | test_loss = 0 34 | correct = 0 35 | with torch.no_grad(): 36 | for data, target, _ in test_loader: 37 | data, target = data.to(device), target.to(device) 38 | _, target = target.max(dim=1) 39 | 40 | output = model(data) 41 | test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss 42 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 43 | correct += pred.eq(target.view_as(pred)).sum().item() 44 | 45 | test_loss /= len(test_loader.dataset) 46 | 47 | return test_loss, 100. 
* correct / len(test_loader.dataset) 48 | 49 | 50 | def main(): 51 | # Training settings 52 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 53 | parser.add_argument('--no-cuda', action='store_true', default=False, 54 | help='disables CUDA training') 55 | parser.add_argument('--seed', type=int, default=0, 56 | help='random seed (default: 1)') 57 | parser.add_argument('--batch-size', type=int, default=128, 58 | help='input batch size for training (default: 64)') 59 | parser.add_argument('--epochs', type=int, default=200, 60 | help='number of epochs to train (default: 10)') 61 | parser.add_argument('--lr', type=float, default=0.001, 62 | help='learning rate (default: 0.01)') 63 | parser.add_argument('--da', type=str, default='rotate', choices=['rotate', 'flip'], 64 | help='type of data augmentation') 65 | 66 | args = parser.parse_args() 67 | use_cuda = not args.no_cuda and torch.cuda.is_available() 68 | 69 | # Set seed 70 | torch.manual_seed(args.seed) 71 | torch.backends.cudnn.benchmark = False 72 | np.random.seed(args.seed) 73 | 74 | device = torch.device("cuda") 75 | kwargs = {'num_workers': 8, 'pin_memory': True} if use_cuda else {} 76 | 77 | mnist_30 = MnistAllDaDist('../dataset/', train=True, thetas=[30.0], d_label=0) 78 | mnist_60 = MnistAllDaDist('../dataset/', train=True, thetas=[60.0], d_label=1) 79 | mnist_90 = MnistAllDaDist('../dataset/', train=True, thetas=[90.0], d_label=2) 80 | model_name = 'baseline_test_0_all_da_seed_' + str(args.seed) 81 | 82 | mnist = data_utils.ConcatDataset([mnist_30, mnist_60, mnist_90]) 83 | 84 | train_size = int(0.9 * len(mnist)) 85 | val_size = len(mnist) - train_size 86 | train_dataset, val_dataset = torch.utils.data.random_split(mnist, [train_size, val_size]) 87 | 88 | train_loader = data_utils.DataLoader(train_dataset, 89 | batch_size=args.batch_size, 90 | shuffle=True, **kwargs) 91 | 92 | val_loader = data_utils.DataLoader(val_dataset, 93 | batch_size=args.batch_size, 94 | shuffle=False, **kwargs) 
95 | 96 | model = Net().to(device) 97 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 98 | 99 | best_val_acc = 0 100 | 101 | for epoch in range(1, args.epochs + 1): 102 | print('\n Epoch: ' + str(epoch)) 103 | train(args, model, device, train_loader, optimizer, epoch) 104 | val_loss, val_acc = test(args, model, device, val_loader) 105 | 106 | print(epoch, val_loss, val_acc) 107 | 108 | # Save best 109 | if val_acc >= best_val_acc: 110 | best_val_acc = val_acc 111 | 112 | torch.save(model, model_name + '.model') 113 | torch.save(args, model_name + '.config') 114 | 115 | # Test loader 116 | mnist_0 = MnistRotated('../dataset/', train=False, thetas=[0.0], d_label=0) 117 | test_loader = data_utils.DataLoader(mnist_0, 118 | batch_size=args.batch_size, 119 | shuffle=False, **kwargs) 120 | 121 | model = torch.load(model_name + '.model').to(device) 122 | _, test_acc = test(args, model, device, test_loader) 123 | 124 | with open(model_name + '.txt', "w") as text_file: 125 | text_file.write("Test Acc: " + str(test_acc)) 126 | 127 | 128 | if __name__ == '__main__': 129 | main() -------------------------------------------------------------------------------- /rotated_MNIST/augmentations/experiment_augmentations_test_0_individual_classes.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, "../../../") 3 | 4 | import argparse 5 | import torch 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | import torch.utils.data as data_utils 9 | 10 | import numpy as np 11 | 12 | from paper_experiments.rotated_MNIST.mnist_loader_shifted_label_distribution_rotate import MnistRotatedDist 13 | from paper_experiments.rotated_MNIST.mnist_loader_shifted_label_distribution_flip import MnistRotatedDistFlip 14 | 15 | from paper_experiments.rotated_MNIST.mnist_loader import MnistRotated 16 | from paper_experiments.rotated_MNIST.augmentations.model_baseline import Net 17 | 18 | 19 | def train(args, 
model, device, train_loader, optimizer, epoch): 20 | model.train() 21 | for batch_idx, (data, target, _) in enumerate(train_loader): 22 | data, target = data.to(device), target.to(device) 23 | _, target = target.max(dim=1) 24 | 25 | optimizer.zero_grad() 26 | output = model(data) 27 | loss = F.nll_loss(output, target) 28 | loss.backward() 29 | optimizer.step() 30 | 31 | 32 | def test(args, model, device, test_loader): 33 | model.eval() 34 | test_loss = 0 35 | correct = 0 36 | correct_dict = {'0': 0, '1': 0, '2': 0, '3': 0, '4': 0, '5': 0, '6': 0, '7': 0, '8': 0, '9': 0} 37 | length_dict = {'0': 0, '1': 0, '2': 0, '3': 0, '4': 0, '5': 0, '6': 0, '7': 0, '8': 0, '9': 0} 38 | with torch.no_grad(): 39 | for data, target, _ in test_loader: 40 | data, target = data.to(device), target.to(device) 41 | _, target = target.max(dim=1) 42 | 43 | output = model(data) 44 | test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss 45 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 46 | 47 | for i in range(10): 48 | class_idx = target == i 49 | pred_idx = pred[class_idx] 50 | correct_dict[str(i)] += pred_idx.eq(target[class_idx].view_as(pred_idx)).sum().item() 51 | length_dict[str(i)] += len(pred_idx) 52 | 53 | correct += pred.eq(target.view_as(pred)).sum().item() 54 | 55 | for key in correct_dict: 56 | correct_dict[key] /= length_dict[key] 57 | 58 | test_loss /= len(test_loader.dataset) 59 | 60 | return test_loss, 100. 
* correct / len(test_loader.dataset), correct_dict 61 | 62 | 63 | def main(): 64 | # Training settings 65 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 66 | parser.add_argument('--no-cuda', action='store_true', default=False, 67 | help='disables CUDA training') 68 | parser.add_argument('--seed', type=int, default=1, 69 | help='random seed (default: 1)') 70 | parser.add_argument('--batch-size', type=int, default=128, 71 | help='input batch size for training (default: 64)') 72 | parser.add_argument('--epochs', type=int, default=50, 73 | help='number of epochs to train (default: 10)') 74 | parser.add_argument('--lr', type=float, default=0.001, 75 | help='learning rate (default: 0.01)') 76 | parser.add_argument('--da', type=str, default='flip', choices=['rotate', 'flip'], 77 | help='type of data augmentation') 78 | 79 | args = parser.parse_args() 80 | use_cuda = not args.no_cuda and torch.cuda.is_available() 81 | 82 | # Set seed 83 | torch.manual_seed(args.seed) 84 | torch.backends.cudnn.benchmark = False 85 | np.random.seed(args.seed) 86 | 87 | device = torch.device("cuda") 88 | 89 | model_name = 'baseline_test_0_random_flips_seed_' + str(args.seed) 90 | 91 | kwargs = {'num_workers': 8, 'pin_memory': True} if use_cuda else {} 92 | 93 | # Load supervised training 94 | if args.da == 'rotate': 95 | mnist_30 = MnistRotatedDist('../dataset/', train=True, thetas=[30.0], d_label=0, transform=True) 96 | mnist_60 = MnistRotatedDist('../dataset/', train=True, thetas=[60.0], d_label=1, transform=True) 97 | mnist_90 = MnistRotatedDist('../dataset/', train=True, thetas=[90.0], d_label=2, transform=True) 98 | model_name = 'baseline_test_0_random_rotate_seed_' + str(args.seed) 99 | 100 | elif args.da == 'flip': 101 | mnist_30 = MnistRotatedDistFlip('../dataset/', train=True, thetas=[30.0], d_label=0) 102 | mnist_60 = MnistRotatedDistFlip('../dataset/', train=True, thetas=[60.0], d_label=1) 103 | mnist_90 = MnistRotatedDistFlip('../dataset/', train=True, 
thetas=[90.0], d_label=2) 104 | model_name = 'baseline_test_0_random_flips_seed_' + str(args.seed) 105 | 106 | mnist = data_utils.ConcatDataset([mnist_30, mnist_60, mnist_90]) 107 | 108 | mnist = data_utils.ConcatDataset([mnist_30, mnist_60, mnist_90]) 109 | 110 | train_size = int(0.9 * len(mnist)) 111 | val_size = len(mnist) - train_size 112 | train_dataset, val_dataset = torch.utils.data.random_split(mnist, [train_size, val_size]) 113 | 114 | train_loader = data_utils.DataLoader(train_dataset, 115 | batch_size=args.batch_size, 116 | shuffle=True, **kwargs) 117 | 118 | val_loader = data_utils.DataLoader(val_dataset, 119 | batch_size=args.batch_size, 120 | shuffle=False, **kwargs) 121 | 122 | model = Net().to(device) 123 | # model = NetFlat().to(device) 124 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 125 | 126 | best_val_acc = 0 127 | 128 | for epoch in range(1, args.epochs + 1): 129 | print('\n Epoch: ' + str(epoch)) 130 | train(args, model, device, train_loader, optimizer, epoch) 131 | val_loss, val_acc, val_dict = test(args, model, device, val_loader) 132 | 133 | print(epoch, val_loss, val_acc) 134 | for key in val_dict: 135 | print(val_dict[key]) 136 | 137 | # Save best 138 | if val_acc >= best_val_acc: 139 | best_val_acc = val_acc 140 | 141 | torch.save(model, model_name + '.model') 142 | torch.save(args, model_name + '.config') 143 | 144 | # Test loader 145 | mnist_0 = MnistRotated('../dataset/', train=False, thetas=[0.0], d_label=0) 146 | test_loader = data_utils.DataLoader(mnist_0, 147 | batch_size=args.batch_size, 148 | shuffle=False, **kwargs) 149 | 150 | model = torch.load(model_name + '.model').to(device) 151 | _, test_acc, test_dict = test(args, model, device, test_loader) 152 | 153 | print(epoch, test_acc) 154 | for key in val_dict: 155 | print(test_dict[key]) 156 | 157 | with open(model_name + '.txt', "w") as text_file: 158 | text_file.write("Test Acc: " + str(test_acc)) 159 | 160 | 161 | if __name__ == '__main__': 162 | main() 
-------------------------------------------------------------------------------- /rotated_MNIST/augmentations/experiment_augmentations_test_30.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, "../../../") 3 | 4 | import argparse 5 | import torch 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | import torch.utils.data as data_utils 9 | 10 | import numpy as np 11 | 12 | from paper_experiments.rotated_MNIST.mnist_loader_shifted_label_distribution_rotate import MnistRotatedDist 13 | from paper_experiments.rotated_MNIST.mnist_loader_shifted_label_distribution_flip import MnistRotatedDistFlip 14 | 15 | from paper_experiments.rotated_MNIST.mnist_loader import MnistRotated 16 | from paper_experiments.rotated_MNIST.augmentations.model_baseline import Net 17 | 18 | 19 | def train(args, model, device, train_loader, optimizer, epoch): 20 | model.train() 21 | for batch_idx, (data, target, _) in enumerate(train_loader): 22 | data, target = data.to(device), target.to(device) 23 | _, target = target.max(dim=1) 24 | 25 | optimizer.zero_grad() 26 | output = model(data) 27 | loss = F.nll_loss(output, target) 28 | loss.backward() 29 | optimizer.step() 30 | 31 | 32 | def test(args, model, device, test_loader): 33 | model.eval() 34 | test_loss = 0 35 | correct = 0 36 | with torch.no_grad(): 37 | for data, target, _ in test_loader: 38 | data, target = data.to(device), target.to(device) 39 | _, target = target.max(dim=1) 40 | 41 | output = model(data) 42 | test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss 43 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 44 | correct += pred.eq(target.view_as(pred)).sum().item() 45 | 46 | test_loss /= len(test_loader.dataset) 47 | 48 | return test_loss, 100. 
* correct / len(test_loader.dataset) 49 | 50 | 51 | def main(): 52 | # Training settings 53 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 54 | parser.add_argument('--no-cuda', action='store_true', default=False, 55 | help='disables CUDA training') 56 | parser.add_argument('--seed', type=int, default=0, 57 | help='random seed (default: 1)') 58 | parser.add_argument('--batch-size', type=int, default=128, 59 | help='input batch size for training (default: 64)') 60 | parser.add_argument('--epochs', type=int, default=200, 61 | help='number of epochs to train (default: 10)') 62 | parser.add_argument('--lr', type=float, default=0.001, 63 | help='learning rate (default: 0.01)') 64 | parser.add_argument('--da', type=str, default='rotate', choices=['rotate', 'flip'], 65 | help='type of data augmentation') 66 | 67 | args = parser.parse_args() 68 | use_cuda = not args.no_cuda and torch.cuda.is_available() 69 | 70 | # Set seed 71 | torch.manual_seed(args.seed) 72 | torch.backends.cudnn.benchmark = False 73 | np.random.seed(args.seed) 74 | 75 | device = torch.device("cuda") 76 | kwargs = {'num_workers': 8, 'pin_memory': True} if use_cuda else {} 77 | 78 | # Load supervised training 79 | if args.da == 'rotate': 80 | mnist_0 = MnistRotatedDist('../dataset/', train=True, thetas=[0.0], d_label=0, transform=True) 81 | mnist_60 = MnistRotatedDist('../dataset/', train=True, thetas=[60.0], d_label=1, transform=True) 82 | mnist_90 = MnistRotatedDist('../dataset/', train=True, thetas=[90.0], d_label=2, transform=True) 83 | model_name = 'baseline_test_0_random_rotate_seed_' + str(args.seed) 84 | 85 | elif args.da == 'flip': 86 | mnist_0 = MnistRotatedDistFlip('../dataset/', train=True, thetas=[0.0], d_label=0) 87 | mnist_60 = MnistRotatedDistFlip('../dataset/', train=True, thetas=[60.0], d_label=1) 88 | mnist_90 = MnistRotatedDistFlip('../dataset/', train=True, thetas=[90.0], d_label=2) 89 | model_name = 'baseline_test_0_random_flips_seed_' + str(args.seed) 90 | 
91 | mnist = data_utils.ConcatDataset([mnist_0, mnist_60, mnist_90]) 92 | 93 | train_size = int(0.9 * len(mnist)) 94 | val_size = len(mnist) - train_size 95 | train_dataset, val_dataset = torch.utils.data.random_split(mnist, [train_size, val_size]) 96 | 97 | train_loader = data_utils.DataLoader(train_dataset, 98 | batch_size=args.batch_size, 99 | shuffle=True, **kwargs) 100 | 101 | val_loader = data_utils.DataLoader(val_dataset, 102 | batch_size=args.batch_size, 103 | shuffle=False, **kwargs) 104 | 105 | model = Net().to(device) 106 | # model = NetFlat().to(device) 107 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 108 | 109 | best_val_acc = 0 110 | 111 | for epoch in range(1, args.epochs + 1): 112 | print('\n Epoch: ' + str(epoch)) 113 | train(args, model, device, train_loader, optimizer, epoch) 114 | val_loss, val_acc = test(args, model, device, val_loader) 115 | 116 | print(epoch, val_loss, val_acc) 117 | 118 | # Save best 119 | if val_acc >= best_val_acc: 120 | best_val_acc = val_acc 121 | 122 | torch.save(model, model_name + '.model') 123 | torch.save(args, model_name + '.config') 124 | 125 | 126 | # Test loader 127 | mnist_30 = MnistRotated('../dataset/', train=False, thetas=[30.0], d_label=0) 128 | test_loader = data_utils.DataLoader(mnist_30, 129 | batch_size=args.batch_size, 130 | shuffle=False, **kwargs) 131 | 132 | model = torch.load(model_name + '.model').to(device) 133 | _, test_acc = test(args, model, device, test_loader) 134 | 135 | with open(model_name + '.txt', "w") as text_file: 136 | text_file.write("Test Acc: " + str(test_acc)) 137 | 138 | 139 | if __name__ == '__main__': 140 | main() -------------------------------------------------------------------------------- /rotated_MNIST/augmentations/experiment_augmentations_test_30_all_da.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, "../../../") 3 | 4 | import argparse 5 | import torch 6 | import torch.nn.functional 
as F 7 | import torch.optim as optim 8 | import torch.utils.data as data_utils 9 | 10 | import numpy as np 11 | 12 | from paper_experiments.rotated_MNIST.mnist_loader_shifted_label_distribution_all_da import MnistAllDaDist 13 | 14 | from paper_experiments.rotated_MNIST.mnist_loader import MnistRotated 15 | from paper_experiments.rotated_MNIST.augmentations.model_baseline import Net 16 | 17 | 18 | def train(args, model, device, train_loader, optimizer, epoch): 19 | model.train() 20 | for batch_idx, (data, target, _) in enumerate(train_loader): 21 | data, target = data.to(device), target.to(device) 22 | _, target = target.max(dim=1) 23 | 24 | optimizer.zero_grad() 25 | output = model(data) 26 | loss = F.nll_loss(output, target) 27 | loss.backward() 28 | optimizer.step() 29 | 30 | 31 | def test(args, model, device, test_loader): 32 | model.eval() 33 | test_loss = 0 34 | correct = 0 35 | with torch.no_grad(): 36 | for data, target, _ in test_loader: 37 | data, target = data.to(device), target.to(device) 38 | _, target = target.max(dim=1) 39 | 40 | output = model(data) 41 | test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss 42 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 43 | correct += pred.eq(target.view_as(pred)).sum().item() 44 | 45 | test_loss /= len(test_loader.dataset) 46 | 47 | return test_loss, 100. 
* correct / len(test_loader.dataset) 48 | 49 | 50 | def main(): 51 | # Training settings 52 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 53 | parser.add_argument('--no-cuda', action='store_true', default=False, 54 | help='disables CUDA training') 55 | parser.add_argument('--seed', type=int, default=0, 56 | help='random seed (default: 1)') 57 | parser.add_argument('--batch-size', type=int, default=128, 58 | help='input batch size for training (default: 64)') 59 | parser.add_argument('--epochs', type=int, default=200, 60 | help='number of epochs to train (default: 10)') 61 | parser.add_argument('--lr', type=float, default=0.001, 62 | help='learning rate (default: 0.01)') 63 | parser.add_argument('--da', type=str, default='rotate', choices=['rotate', 'flip'], 64 | help='type of data augmentation') 65 | 66 | args = parser.parse_args() 67 | use_cuda = not args.no_cuda and torch.cuda.is_available() 68 | 69 | # Set seed 70 | torch.manual_seed(args.seed) 71 | torch.backends.cudnn.benchmark = False 72 | np.random.seed(args.seed) 73 | 74 | device = torch.device("cuda") 75 | kwargs = {'num_workers': 8, 'pin_memory': True} if use_cuda else {} 76 | 77 | mnist_0 = MnistAllDaDist('../dataset/', train=True, thetas=[0.0], d_label=0) 78 | mnist_60 = MnistAllDaDist('../dataset/', train=True, thetas=[60.0], d_label=1) 79 | mnist_90 = MnistAllDaDist('../dataset/', train=True, thetas=[90.0], d_label=2) 80 | model_name = 'baseline_test_30_all_da_seed_' + str(args.seed) 81 | 82 | mnist = data_utils.ConcatDataset([mnist_0, mnist_60, mnist_90]) 83 | 84 | train_size = int(0.9 * len(mnist)) 85 | val_size = len(mnist) - train_size 86 | train_dataset, val_dataset = torch.utils.data.random_split(mnist, [train_size, val_size]) 87 | 88 | train_loader = data_utils.DataLoader(train_dataset, 89 | batch_size=args.batch_size, 90 | shuffle=True, **kwargs) 91 | 92 | val_loader = data_utils.DataLoader(val_dataset, 93 | batch_size=args.batch_size, 94 | shuffle=False, **kwargs) 95 
| 96 | model = Net().to(device) 97 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 98 | 99 | best_val_acc = 0 100 | 101 | for epoch in range(1, args.epochs + 1): 102 | print('\n Epoch: ' + str(epoch)) 103 | train(args, model, device, train_loader, optimizer, epoch) 104 | val_loss, val_acc = test(args, model, device, val_loader) 105 | 106 | print(epoch, val_loss, val_acc) 107 | 108 | # Save best 109 | if val_acc >= best_val_acc: 110 | best_val_acc = val_acc 111 | 112 | torch.save(model, model_name + '.model') 113 | torch.save(args, model_name + '.config') 114 | 115 | # Test loader 116 | mnist_30 = MnistRotated('../dataset/', train=False, thetas=[30.0], d_label=0) 117 | test_loader = data_utils.DataLoader(mnist_30, 118 | batch_size=args.batch_size, 119 | shuffle=False, **kwargs) 120 | 121 | model = torch.load(model_name + '.model').to(device) 122 | _, test_acc = test(args, model, device, test_loader) 123 | 124 | with open(model_name + '.txt', "w") as text_file: 125 | text_file.write("Test Acc: " + str(test_acc)) 126 | 127 | 128 | if __name__ == '__main__': 129 | main() -------------------------------------------------------------------------------- /rotated_MNIST/augmentations/experiment_augmentations_test_60.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, "../../../") 3 | 4 | import argparse 5 | import torch 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | import torch.utils.data as data_utils 9 | 10 | import numpy as np 11 | 12 | from paper_experiments.rotated_MNIST.mnist_loader_shifted_label_distribution_rotate import MnistRotatedDist 13 | from paper_experiments.rotated_MNIST.mnist_loader_shifted_label_distribution_flip import MnistRotatedDistFlip 14 | 15 | from paper_experiments.rotated_MNIST.mnist_loader import MnistRotated 16 | from paper_experiments.rotated_MNIST.augmentations.model_baseline import Net 17 | 18 | 19 | def train(args, model, device, 
train_loader, optimizer, epoch): 20 | model.train() 21 | for batch_idx, (data, target, _) in enumerate(train_loader): 22 | data, target = data.to(device), target.to(device) 23 | _, target = target.max(dim=1) 24 | 25 | optimizer.zero_grad() 26 | output = model(data) 27 | loss = F.nll_loss(output, target) 28 | loss.backward() 29 | optimizer.step() 30 | 31 | 32 | def test(args, model, device, test_loader): 33 | model.eval() 34 | test_loss = 0 35 | correct = 0 36 | with torch.no_grad(): 37 | for data, target, _ in test_loader: 38 | data, target = data.to(device), target.to(device) 39 | _, target = target.max(dim=1) 40 | 41 | output = model(data) 42 | test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss 43 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 44 | correct += pred.eq(target.view_as(pred)).sum().item() 45 | 46 | test_loss /= len(test_loader.dataset) 47 | 48 | return test_loss, 100. * correct / len(test_loader.dataset) 49 | 50 | 51 | def main(): 52 | # Training settings 53 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 54 | parser.add_argument('--no-cuda', action='store_true', default=False, 55 | help='disables CUDA training') 56 | parser.add_argument('--seed', type=int, default=0, 57 | help='random seed (default: 1)') 58 | parser.add_argument('--batch-size', type=int, default=128, 59 | help='input batch size for training (default: 64)') 60 | parser.add_argument('--epochs', type=int, default=200, 61 | help='number of epochs to train (default: 10)') 62 | parser.add_argument('--lr', type=float, default=0.001, 63 | help='learning rate (default: 0.01)') 64 | parser.add_argument('--da', type=str, default='rotate', choices=['rotate', 'flip'], 65 | help='type of data augmentation') 66 | 67 | args = parser.parse_args() 68 | use_cuda = not args.no_cuda and torch.cuda.is_available() 69 | 70 | # Set seed 71 | torch.manual_seed(args.seed) 72 | torch.backends.cudnn.benchmark = 
False 73 | np.random.seed(args.seed) 74 | 75 | device = torch.device("cuda") 76 | kwargs = {'num_workers': 8, 'pin_memory': True} if use_cuda else {} 77 | 78 | # Load supervised training 79 | if args.da == 'rotate': 80 | mnist_0 = MnistRotatedDist('../dataset/', train=True, thetas=[0.0], d_label=0, transform=True) 81 | mnist_30 = MnistRotatedDist('../dataset/', train=True, thetas=[30.0], d_label=1, transform=True) 82 | mnist_90 = MnistRotatedDist('../dataset/', train=True, thetas=[90.0], d_label=2, transform=True) 83 | model_name = 'baseline_test_0_random_rotate_seed_' + str(args.seed) 84 | 85 | elif args.da == 'flip': 86 | mnist_0 = MnistRotatedDistFlip('../dataset/', train=True, thetas=[0.0], d_label=0) 87 | mnist_30 = MnistRotatedDistFlip('../dataset/', train=True, thetas=[30.0], d_label=1) 88 | mnist_90 = MnistRotatedDistFlip('../dataset/', train=True, thetas=[90.0], d_label=2) 89 | model_name = 'baseline_test_0_random_flips_seed_' + str(args.seed) 90 | 91 | mnist = data_utils.ConcatDataset([mnist_0, mnist_30, mnist_90]) 92 | 93 | train_size = int(0.9 * len(mnist)) 94 | val_size = len(mnist) - train_size 95 | train_dataset, val_dataset = torch.utils.data.random_split(mnist, [train_size, val_size]) 96 | 97 | train_loader = data_utils.DataLoader(train_dataset, 98 | batch_size=args.batch_size, 99 | shuffle=True, **kwargs) 100 | 101 | val_loader = data_utils.DataLoader(val_dataset, 102 | batch_size=args.batch_size, 103 | shuffle=False, **kwargs) 104 | 105 | model = Net().to(device) 106 | # model = NetFlat().to(device) 107 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 108 | 109 | best_val_acc = 0 110 | 111 | for epoch in range(1, args.epochs + 1): 112 | print('\n Epoch: ' + str(epoch)) 113 | train(args, model, device, train_loader, optimizer, epoch) 114 | val_loss, val_acc = test(args, model, device, val_loader) 115 | 116 | print(epoch, val_loss, val_acc) 117 | 118 | # Save best 119 | if val_acc >= best_val_acc: 120 | best_val_acc = val_acc 121 | 122 | 
torch.save(model, model_name + '.model') 123 | torch.save(args, model_name + '.config') 124 | 125 | 126 | # Test loader 127 | mnist_60 = MnistRotated('../dataset/', train=False, thetas=[60.0], d_label=0) 128 | test_loader = data_utils.DataLoader(mnist_60, 129 | batch_size=args.batch_size, 130 | shuffle=False, **kwargs) 131 | 132 | model = torch.load(model_name + '.model').to(device) 133 | _, test_acc = test(args, model, device, test_loader) 134 | 135 | with open(model_name + '.txt', "w") as text_file: 136 | text_file.write("Test Acc: " + str(test_acc)) 137 | 138 | 139 | if __name__ == '__main__': 140 | main() -------------------------------------------------------------------------------- /rotated_MNIST/augmentations/experiment_augmentations_test_60_all_da.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, "../../../") 3 | 4 | import argparse 5 | import torch 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | import torch.utils.data as data_utils 9 | 10 | import numpy as np 11 | 12 | from paper_experiments.rotated_MNIST.mnist_loader_shifted_label_distribution_all_da import MnistAllDaDist 13 | 14 | from paper_experiments.rotated_MNIST.mnist_loader import MnistRotated 15 | from paper_experiments.rotated_MNIST.augmentations.model_baseline import Net 16 | 17 | 18 | def train(args, model, device, train_loader, optimizer, epoch): 19 | model.train() 20 | for batch_idx, (data, target, _) in enumerate(train_loader): 21 | data, target = data.to(device), target.to(device) 22 | _, target = target.max(dim=1) 23 | 24 | optimizer.zero_grad() 25 | output = model(data) 26 | loss = F.nll_loss(output, target) 27 | loss.backward() 28 | optimizer.step() 29 | 30 | 31 | def test(args, model, device, test_loader): 32 | model.eval() 33 | test_loss = 0 34 | correct = 0 35 | with torch.no_grad(): 36 | for data, target, _ in test_loader: 37 | data, target = data.to(device), target.to(device) 38 | _, 
target = target.max(dim=1) 39 | 40 | output = model(data) 41 | test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss 42 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 43 | correct += pred.eq(target.view_as(pred)).sum().item() 44 | 45 | test_loss /= len(test_loader.dataset) 46 | 47 | return test_loss, 100. * correct / len(test_loader.dataset) 48 | 49 | 50 | def main(): 51 | # Training settings 52 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 53 | parser.add_argument('--no-cuda', action='store_true', default=False, 54 | help='disables CUDA training') 55 | parser.add_argument('--seed', type=int, default=0, 56 | help='random seed (default: 1)') 57 | parser.add_argument('--batch-size', type=int, default=128, 58 | help='input batch size for training (default: 64)') 59 | parser.add_argument('--epochs', type=int, default=200, 60 | help='number of epochs to train (default: 10)') 61 | parser.add_argument('--lr', type=float, default=0.001, 62 | help='learning rate (default: 0.01)') 63 | parser.add_argument('--da', type=str, default='rotate', choices=['rotate', 'flip'], 64 | help='type of data augmentation') 65 | 66 | args = parser.parse_args() 67 | use_cuda = not args.no_cuda and torch.cuda.is_available() 68 | 69 | # Set seed 70 | torch.manual_seed(args.seed) 71 | torch.backends.cudnn.benchmark = False 72 | np.random.seed(args.seed) 73 | 74 | device = torch.device("cuda") 75 | kwargs = {'num_workers': 8, 'pin_memory': True} if use_cuda else {} 76 | 77 | mnist_30 = MnistAllDaDist('../dataset/', train=True, thetas=[30.0], d_label=0) 78 | mnist_0 = MnistAllDaDist('../dataset/', train=True, thetas=[0.0], d_label=1) 79 | mnist_90 = MnistAllDaDist('../dataset/', train=True, thetas=[90.0], d_label=2) 80 | model_name = 'baseline_test_60_all_da_seed_' + str(args.seed) 81 | 82 | mnist = data_utils.ConcatDataset([mnist_30, mnist_0, mnist_90]) 83 | 84 | train_size = int(0.9 * len(mnist)) 85 | 
val_size = len(mnist) - train_size 86 | train_dataset, val_dataset = torch.utils.data.random_split(mnist, [train_size, val_size]) 87 | 88 | train_loader = data_utils.DataLoader(train_dataset, 89 | batch_size=args.batch_size, 90 | shuffle=True, **kwargs) 91 | 92 | val_loader = data_utils.DataLoader(val_dataset, 93 | batch_size=args.batch_size, 94 | shuffle=False, **kwargs) 95 | 96 | model = Net().to(device) 97 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 98 | 99 | best_val_acc = 0 100 | 101 | for epoch in range(1, args.epochs + 1): 102 | print('\n Epoch: ' + str(epoch)) 103 | train(args, model, device, train_loader, optimizer, epoch) 104 | val_loss, val_acc = test(args, model, device, val_loader) 105 | 106 | print(epoch, val_loss, val_acc) 107 | 108 | # Save best 109 | if val_acc >= best_val_acc: 110 | best_val_acc = val_acc 111 | 112 | torch.save(model, model_name + '.model') 113 | torch.save(args, model_name + '.config') 114 | 115 | # Test loader 116 | mnist_60 = MnistRotated('../dataset/', train=False, thetas=[60.0], d_label=0) 117 | test_loader = data_utils.DataLoader(mnist_60, 118 | batch_size=args.batch_size, 119 | shuffle=False, **kwargs) 120 | 121 | model = torch.load(model_name + '.model').to(device) 122 | _, test_acc = test(args, model, device, test_loader) 123 | 124 | with open(model_name + '.txt', "w") as text_file: 125 | text_file.write("Test Acc: " + str(test_acc)) 126 | 127 | 128 | if __name__ == '__main__': 129 | main() -------------------------------------------------------------------------------- /rotated_MNIST/augmentations/experiment_augmentations_test_90.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, "../../../") 3 | 4 | import argparse 5 | import torch 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | import torch.utils.data as data_utils 9 | 10 | import numpy as np 11 | 12 | from 
paper_experiments.rotated_MNIST.mnist_loader_shifted_label_distribution_rotate import MnistRotatedDist 13 | from paper_experiments.rotated_MNIST.mnist_loader_shifted_label_distribution_flip import MnistRotatedDistFlip 14 | 15 | from paper_experiments.rotated_MNIST.mnist_loader import MnistRotated 16 | from paper_experiments.rotated_MNIST.augmentations.model_baseline import Net 17 | 18 | 19 | def train(args, model, device, train_loader, optimizer, epoch): 20 | model.train() 21 | for batch_idx, (data, target, _) in enumerate(train_loader): 22 | data, target = data.to(device), target.to(device) 23 | _, target = target.max(dim=1) 24 | 25 | optimizer.zero_grad() 26 | output = model(data) 27 | loss = F.nll_loss(output, target) 28 | loss.backward() 29 | optimizer.step() 30 | 31 | 32 | def test(args, model, device, test_loader): 33 | model.eval() 34 | test_loss = 0 35 | correct = 0 36 | with torch.no_grad(): 37 | for data, target, _ in test_loader: 38 | data, target = data.to(device), target.to(device) 39 | _, target = target.max(dim=1) 40 | 41 | output = model(data) 42 | test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss 43 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 44 | correct += pred.eq(target.view_as(pred)).sum().item() 45 | 46 | test_loss /= len(test_loader.dataset) 47 | 48 | return test_loss, 100. 
* correct / len(test_loader.dataset) 49 | 50 | 51 | def main(): 52 | # Training settings 53 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 54 | parser.add_argument('--no-cuda', action='store_true', default=False, 55 | help='disables CUDA training') 56 | parser.add_argument('--seed', type=int, default=0, 57 | help='random seed (default: 1)') 58 | parser.add_argument('--batch-size', type=int, default=128, 59 | help='input batch size for training (default: 64)') 60 | parser.add_argument('--epochs', type=int, default=200, 61 | help='number of epochs to train (default: 10)') 62 | parser.add_argument('--lr', type=float, default=0.001, 63 | help='learning rate (default: 0.01)') 64 | parser.add_argument('--da', type=str, default='rotate', choices=['rotate', 'flip'], 65 | help='type of data augmentation') 66 | 67 | args = parser.parse_args() 68 | use_cuda = not args.no_cuda and torch.cuda.is_available() 69 | 70 | # Set seed 71 | torch.manual_seed(args.seed) 72 | torch.backends.cudnn.benchmark = False 73 | np.random.seed(args.seed) 74 | 75 | device = torch.device("cuda") 76 | kwargs = {'num_workers': 8, 'pin_memory': True} if use_cuda else {} 77 | 78 | # Load supervised training 79 | if args.da == 'rotate': 80 | mnist_0 = MnistRotatedDist('../dataset/', train=True, thetas=[0.0], d_label=0, transform=True) 81 | mnist_30 = MnistRotatedDist('../dataset/', train=True, thetas=[30.0], d_label=1, transform=True) 82 | mnist_60 = MnistRotatedDist('../dataset/', train=True, thetas=[60.0], d_label=2, transform=True) 83 | model_name = 'baseline_test_0_random_rotate_seed_' + str(args.seed) 84 | 85 | elif args.da == 'flip': 86 | mnist_0 = MnistRotatedDistFlip('../dataset/', train=True, thetas=[0.0], d_label=0) 87 | mnist_30 = MnistRotatedDistFlip('../dataset/', train=True, thetas=[30.0], d_label=1) 88 | mnist_60 = MnistRotatedDistFlip('../dataset/', train=True, thetas=[60.0], d_label=2) 89 | model_name = 'baseline_test_0_random_flips_seed_' + str(args.seed) 90 | 
91 | mnist = data_utils.ConcatDataset([mnist_0, mnist_30, mnist_60]) 92 | 93 | train_size = int(0.9 * len(mnist)) 94 | val_size = len(mnist) - train_size 95 | train_dataset, val_dataset = torch.utils.data.random_split(mnist, [train_size, val_size]) 96 | 97 | train_loader = data_utils.DataLoader(train_dataset, 98 | batch_size=args.batch_size, 99 | shuffle=True, **kwargs) 100 | 101 | val_loader = data_utils.DataLoader(val_dataset, 102 | batch_size=args.batch_size, 103 | shuffle=False, **kwargs) 104 | 105 | model = Net().to(device) 106 | # model = NetFlat().to(device) 107 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 108 | 109 | best_val_acc = 0 110 | 111 | for epoch in range(1, args.epochs + 1): 112 | print('\n Epoch: ' + str(epoch)) 113 | train(args, model, device, train_loader, optimizer, epoch) 114 | val_loss, val_acc = test(args, model, device, val_loader) 115 | 116 | print(epoch, val_loss, val_acc) 117 | 118 | # Save best 119 | if val_acc >= best_val_acc: 120 | best_val_acc = val_acc 121 | 122 | torch.save(model, model_name + '.model') 123 | torch.save(args, model_name + '.config') 124 | 125 | 126 | # Test loader 127 | mnist_90 = MnistRotated('../dataset/', train=False, thetas=[90.0], d_label=0) 128 | test_loader = data_utils.DataLoader(mnist_90, 129 | batch_size=args.batch_size, 130 | shuffle=False, **kwargs) 131 | 132 | model = torch.load(model_name + '.model').to(device) 133 | _, test_acc = test(args, model, device, test_loader) 134 | 135 | with open(model_name + '.txt', "w") as text_file: 136 | text_file.write("Test Acc: " + str(test_acc)) 137 | 138 | 139 | if __name__ == '__main__': 140 | main() -------------------------------------------------------------------------------- /rotated_MNIST/augmentations/experiment_augmentations_test_90_all_da.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, "../../../") 3 | 4 | import argparse 5 | import torch 6 | import torch.nn.functional 
as F 7 | import torch.optim as optim 8 | import torch.utils.data as data_utils 9 | 10 | import numpy as np 11 | 12 | from paper_experiments.rotated_MNIST.mnist_loader_shifted_label_distribution_all_da import MnistAllDaDist 13 | 14 | from paper_experiments.rotated_MNIST.mnist_loader import MnistRotated 15 | from paper_experiments.rotated_MNIST.augmentations.model_baseline import Net 16 | 17 | 18 | def train(args, model, device, train_loader, optimizer, epoch): 19 | model.train() 20 | for batch_idx, (data, target, _) in enumerate(train_loader): 21 | data, target = data.to(device), target.to(device) 22 | _, target = target.max(dim=1) 23 | 24 | optimizer.zero_grad() 25 | output = model(data) 26 | loss = F.nll_loss(output, target) 27 | loss.backward() 28 | optimizer.step() 29 | 30 | 31 | def test(args, model, device, test_loader): 32 | model.eval() 33 | test_loss = 0 34 | correct = 0 35 | with torch.no_grad(): 36 | for data, target, _ in test_loader: 37 | data, target = data.to(device), target.to(device) 38 | _, target = target.max(dim=1) 39 | 40 | output = model(data) 41 | test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss 42 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 43 | correct += pred.eq(target.view_as(pred)).sum().item() 44 | 45 | test_loss /= len(test_loader.dataset) 46 | 47 | return test_loss, 100. 
* correct / len(test_loader.dataset) 48 | 49 | 50 | def main(): 51 | # Training settings 52 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 53 | parser.add_argument('--no-cuda', action='store_true', default=False, 54 | help='disables CUDA training') 55 | parser.add_argument('--seed', type=int, default=0, 56 | help='random seed (default: 1)') 57 | parser.add_argument('--batch-size', type=int, default=128, 58 | help='input batch size for training (default: 64)') 59 | parser.add_argument('--epochs', type=int, default=200, 60 | help='number of epochs to train (default: 10)') 61 | parser.add_argument('--lr', type=float, default=0.001, 62 | help='learning rate (default: 0.01)') 63 | parser.add_argument('--da', type=str, default='rotate', choices=['rotate', 'flip'], 64 | help='type of data augmentation') 65 | 66 | args = parser.parse_args() 67 | use_cuda = not args.no_cuda and torch.cuda.is_available() 68 | 69 | # Set seed 70 | torch.manual_seed(args.seed) 71 | torch.backends.cudnn.benchmark = False 72 | np.random.seed(args.seed) 73 | 74 | device = torch.device("cuda") 75 | kwargs = {'num_workers': 8, 'pin_memory': True} if use_cuda else {} 76 | 77 | mnist_30 = MnistAllDaDist('../dataset/', train=True, thetas=[30.0], d_label=0) 78 | mnist_60 = MnistAllDaDist('../dataset/', train=True, thetas=[60.0], d_label=1) 79 | mnist_0 = MnistAllDaDist('../dataset/', train=True, thetas=[0.0], d_label=2) 80 | model_name = 'baseline_test_90_all_da_seed_' + str(args.seed) 81 | 82 | mnist = data_utils.ConcatDataset([mnist_30, mnist_60, mnist_0]) 83 | 84 | train_size = int(0.9 * len(mnist)) 85 | val_size = len(mnist) - train_size 86 | train_dataset, val_dataset = torch.utils.data.random_split(mnist, [train_size, val_size]) 87 | 88 | train_loader = data_utils.DataLoader(train_dataset, 89 | batch_size=args.batch_size, 90 | shuffle=True, **kwargs) 91 | 92 | val_loader = data_utils.DataLoader(val_dataset, 93 | batch_size=args.batch_size, 94 | shuffle=False, **kwargs) 95 
| 96 | model = Net().to(device) 97 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 98 | 99 | best_val_acc = 0 100 | 101 | for epoch in range(1, args.epochs + 1): 102 | print('\n Epoch: ' + str(epoch)) 103 | train(args, model, device, train_loader, optimizer, epoch) 104 | val_loss, val_acc = test(args, model, device, val_loader) 105 | 106 | print(epoch, val_loss, val_acc) 107 | 108 | # Save best 109 | if val_acc >= best_val_acc: 110 | best_val_acc = val_acc 111 | 112 | torch.save(model, model_name + '.model') 113 | torch.save(args, model_name + '.config') 114 | 115 | # Test loader 116 | mnist_90 = MnistRotated('../dataset/', train=False, thetas=[90.0], d_label=0) 117 | test_loader = data_utils.DataLoader(mnist_90, 118 | batch_size=args.batch_size, 119 | shuffle=False, **kwargs) 120 | 121 | model = torch.load(model_name + '.model').to(device) 122 | _, test_acc = test(args, model, device, test_loader) 123 | 124 | with open(model_name + '.txt', "w") as text_file: 125 | text_file.write("Test Acc: " + str(test_acc)) 126 | 127 | 128 | if __name__ == '__main__': 129 | main() -------------------------------------------------------------------------------- /rotated_MNIST/augmentations/jobs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AMLab-Amsterdam/DataAugmentationInterventions/78ce67174db487e9697b9a842e69818305bb41ef/rotated_MNIST/augmentations/jobs/__init__.py -------------------------------------------------------------------------------- /rotated_MNIST/augmentations/jobs/test_0_all_da.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd ~/AdditionalDriveA/DeployedProjects/TwoLatentSpacesVAE/paper_experiments/rotated_MNIST/augmentations 3 | echo "Starting" 4 | python experiment_augmentations_test_0_all_da.py --seed 0 5 | python experiment_augmentations_test_0_all_da.py --seed 1 6 | python 
experiment_augmentations_test_0_all_da.py --seed 2 7 | python experiment_augmentations_test_0_all_da.py --seed 3 8 | python experiment_augmentations_test_0_all_da.py --seed 4 9 | python experiment_augmentations_test_0_all_da.py --seed 5 10 | python experiment_augmentations_test_0_all_da.py --seed 6 11 | python experiment_augmentations_test_0_all_da.py --seed 7 12 | python experiment_augmentations_test_0_all_da.py --seed 8 13 | python experiment_augmentations_test_0_all_da.py --seed 9 14 | 15 | echo "Done" 16 | -------------------------------------------------------------------------------- /rotated_MNIST/augmentations/jobs/test_0_flip.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd ~/AdditionalDriveA/DeployedProjects/TwoLatentSpacesVAE/paper_experiments/rotated_MNIST/augmentations 3 | echo "Starting" 4 | python experiment_augmentations_test_0.py --seed 0 --da flip 5 | python experiment_augmentations_test_0.py --seed 1 --da flip 6 | python experiment_augmentations_test_0.py --seed 2 --da flip 7 | python experiment_augmentations_test_0.py --seed 3 --da flip 8 | python experiment_augmentations_test_0.py --seed 4 --da flip 9 | python experiment_augmentations_test_0.py --seed 5 --da flip 10 | python experiment_augmentations_test_0.py --seed 6 --da flip 11 | python experiment_augmentations_test_0.py --seed 7 --da flip 12 | python experiment_augmentations_test_0.py --seed 8 --da flip 13 | python experiment_augmentations_test_0.py --seed 9 --da flip 14 | 15 | echo "Done" 16 | -------------------------------------------------------------------------------- /rotated_MNIST/augmentations/jobs/test_0_rotate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd ~/AdditionalDriveA/DeployedProjects/TwoLatentSpacesVAE/paper_experiments/rotated_MNIST/augmentations 3 | echo "Starting" 4 | python experiment_augmentations_test_0.py --seed 0 --da rotate 5 | python 
experiment_augmentations_test_0.py --seed 1 --da rotate 6 | python experiment_augmentations_test_0.py --seed 2 --da rotate 7 | python experiment_augmentations_test_0.py --seed 3 --da rotate 8 | python experiment_augmentations_test_0.py --seed 4 --da rotate 9 | python experiment_augmentations_test_0.py --seed 5 --da rotate 10 | python experiment_augmentations_test_0.py --seed 6 --da rotate 11 | python experiment_augmentations_test_0.py --seed 7 --da rotate 12 | python experiment_augmentations_test_0.py --seed 8 --da rotate 13 | python experiment_augmentations_test_0.py --seed 9 --da rotate 14 | 15 | echo "Done" 16 | -------------------------------------------------------------------------------- /rotated_MNIST/augmentations/jobs/test_30_all_da.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd ~/AdditionalDriveA/DeployedProjects/TwoLatentSpacesVAE/paper_experiments/rotated_MNIST/augmentations 3 | echo "Starting" 4 | python experiment_augmentations_test_30_all_da.py --seed 0 5 | python experiment_augmentations_test_30_all_da.py --seed 1 6 | python experiment_augmentations_test_30_all_da.py --seed 2 7 | python experiment_augmentations_test_30_all_da.py --seed 3 8 | python experiment_augmentations_test_30_all_da.py --seed 4 9 | python experiment_augmentations_test_30_all_da.py --seed 5 10 | python experiment_augmentations_test_30_all_da.py --seed 6 11 | python experiment_augmentations_test_30_all_da.py --seed 7 12 | python experiment_augmentations_test_30_all_da.py --seed 8 13 | python experiment_augmentations_test_30_all_da.py --seed 9 14 | 15 | echo "Done" 16 | -------------------------------------------------------------------------------- /rotated_MNIST/augmentations/jobs/test_30_flip.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd ~/AdditionalDriveA/DeployedProjects/TwoLatentSpacesVAE/paper_experiments/rotated_MNIST/augmentations 3 | echo "Starting" 
4 | python experiment_augmentations_test_30.py --seed 0 --da flip 5 | python experiment_augmentations_test_30.py --seed 1 --da flip 6 | python experiment_augmentations_test_30.py --seed 2 --da flip 7 | python experiment_augmentations_test_30.py --seed 3 --da flip 8 | python experiment_augmentations_test_30.py --seed 4 --da flip 9 | python experiment_augmentations_test_30.py --seed 5 --da flip 10 | python experiment_augmentations_test_30.py --seed 6 --da flip 11 | python experiment_augmentations_test_30.py --seed 7 --da flip 12 | python experiment_augmentations_test_30.py --seed 8 --da flip 13 | python experiment_augmentations_test_30.py --seed 9 --da flip 14 | 15 | echo "Done" 16 | -------------------------------------------------------------------------------- /rotated_MNIST/augmentations/jobs/test_30_rotate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd ~/AdditionalDriveA/DeployedProjects/TwoLatentSpacesVAE/paper_experiments/rotated_MNIST/augmentations 3 | echo "Starting" 4 | python experiment_augmentations_test_30.py --seed 0 --da rotate 5 | python experiment_augmentations_test_30.py --seed 1 --da rotate 6 | python experiment_augmentations_test_30.py --seed 2 --da rotate 7 | python experiment_augmentations_test_30.py --seed 3 --da rotate 8 | python experiment_augmentations_test_30.py --seed 4 --da rotate 9 | python experiment_augmentations_test_30.py --seed 5 --da rotate 10 | python experiment_augmentations_test_30.py --seed 6 --da rotate 11 | python experiment_augmentations_test_30.py --seed 7 --da rotate 12 | python experiment_augmentations_test_30.py --seed 8 --da rotate 13 | python experiment_augmentations_test_30.py --seed 9 --da rotate 14 | 15 | 16 | echo "Done" 17 | -------------------------------------------------------------------------------- /rotated_MNIST/augmentations/jobs/test_60_all_da.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 
| cd ~/AdditionalDriveA/DeployedProjects/TwoLatentSpacesVAE/paper_experiments/rotated_MNIST/augmentations 3 | echo "Starting" 4 | python experiment_augmentations_test_60_all_da.py --seed 0 5 | python experiment_augmentations_test_60_all_da.py --seed 1 6 | python experiment_augmentations_test_60_all_da.py --seed 2 7 | python experiment_augmentations_test_60_all_da.py --seed 3 8 | python experiment_augmentations_test_60_all_da.py --seed 4 9 | python experiment_augmentations_test_60_all_da.py --seed 5 10 | python experiment_augmentations_test_60_all_da.py --seed 6 11 | python experiment_augmentations_test_60_all_da.py --seed 7 12 | python experiment_augmentations_test_60_all_da.py --seed 8 13 | python experiment_augmentations_test_60_all_da.py --seed 9 14 | 15 | echo "Done" 16 | -------------------------------------------------------------------------------- /rotated_MNIST/augmentations/jobs/test_60_flip.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd ~/AdditionalDriveA/DeployedProjects/TwoLatentSpacesVAE/paper_experiments/rotated_MNIST/augmentations 3 | echo "Starting" 4 | python experiment_augmentations_test_60.py --seed 0 --da flip 5 | python experiment_augmentations_test_60.py --seed 1 --da flip 6 | python experiment_augmentations_test_60.py --seed 2 --da flip 7 | python experiment_augmentations_test_60.py --seed 3 --da flip 8 | python experiment_augmentations_test_60.py --seed 4 --da flip 9 | python experiment_augmentations_test_60.py --seed 5 --da flip 10 | python experiment_augmentations_test_60.py --seed 6 --da flip 11 | python experiment_augmentations_test_60.py --seed 7 --da flip 12 | python experiment_augmentations_test_60.py --seed 8 --da flip 13 | python experiment_augmentations_test_60.py --seed 9 --da flip 14 | 15 | echo "Done" 16 | -------------------------------------------------------------------------------- /rotated_MNIST/augmentations/jobs/test_60_rotate.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd ~/AdditionalDriveA/DeployedProjects/TwoLatentSpacesVAE/paper_experiments/rotated_MNIST/augmentations 3 | echo "Starting" 4 | python experiment_augmentations_test_60.py --seed 0 --da rotate 5 | python experiment_augmentations_test_60.py --seed 1 --da rotate 6 | python experiment_augmentations_test_60.py --seed 2 --da rotate 7 | python experiment_augmentations_test_60.py --seed 3 --da rotate 8 | python experiment_augmentations_test_60.py --seed 4 --da rotate 9 | python experiment_augmentations_test_60.py --seed 5 --da rotate 10 | python experiment_augmentations_test_60.py --seed 6 --da rotate 11 | python experiment_augmentations_test_60.py --seed 7 --da rotate 12 | python experiment_augmentations_test_60.py --seed 8 --da rotate 13 | python experiment_augmentations_test_60.py --seed 9 --da rotate 14 | 15 | 16 | echo "Done" 17 | -------------------------------------------------------------------------------- /rotated_MNIST/augmentations/jobs/test_90_all_da.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd ~/AdditionalDriveA/DeployedProjects/TwoLatentSpacesVAE/paper_experiments/rotated_MNIST/augmentations 3 | echo "Starting" 4 | python experiment_augmentations_test_90_all_da.py --seed 0 5 | python experiment_augmentations_test_90_all_da.py --seed 1 6 | python experiment_augmentations_test_90_all_da.py --seed 2 7 | python experiment_augmentations_test_90_all_da.py --seed 3 8 | python experiment_augmentations_test_90_all_da.py --seed 4 9 | python experiment_augmentations_test_90_all_da.py --seed 5 10 | python experiment_augmentations_test_90_all_da.py --seed 6 11 | python experiment_augmentations_test_90_all_da.py --seed 7 12 | python experiment_augmentations_test_90_all_da.py --seed 8 13 | python experiment_augmentations_test_90_all_da.py --seed 9 14 | 15 | echo "Done" 16 | 
-------------------------------------------------------------------------------- /rotated_MNIST/augmentations/jobs/test_90_flip.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd ~/AdditionalDriveA/DeployedProjects/TwoLatentSpacesVAE/paper_experiments/rotated_MNIST/augmentations 3 | echo "Starting" 4 | python experiment_augmentations_test_90.py --seed 0 --da flip 5 | python experiment_augmentations_test_90.py --seed 1 --da flip 6 | python experiment_augmentations_test_90.py --seed 2 --da flip 7 | python experiment_augmentations_test_90.py --seed 3 --da flip 8 | python experiment_augmentations_test_90.py --seed 4 --da flip 9 | python experiment_augmentations_test_90.py --seed 5 --da flip 10 | python experiment_augmentations_test_90.py --seed 6 --da flip 11 | python experiment_augmentations_test_90.py --seed 7 --da flip 12 | python experiment_augmentations_test_90.py --seed 8 --da flip 13 | python experiment_augmentations_test_90.py --seed 9 --da flip 14 | 15 | echo "Done" 16 | -------------------------------------------------------------------------------- /rotated_MNIST/augmentations/jobs/test_90_rotate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd ~/AdditionalDriveA/DeployedProjects/TwoLatentSpacesVAE/paper_experiments/rotated_MNIST/augmentations 3 | echo "Starting" 4 | python experiment_augmentations_test_90.py --seed 0 --da rotate 5 | python experiment_augmentations_test_90.py --seed 1 --da rotate 6 | python experiment_augmentations_test_90.py --seed 2 --da rotate 7 | python experiment_augmentations_test_90.py --seed 3 --da rotate 8 | python experiment_augmentations_test_90.py --seed 4 --da rotate 9 | python experiment_augmentations_test_90.py --seed 5 --da rotate 10 | python experiment_augmentations_test_90.py --seed 6 --da rotate 11 | python experiment_augmentations_test_90.py --seed 7 --da rotate 12 | python 
experiment_augmentations_test_90.py --seed 8 --da rotate 13 | python experiment_augmentations_test_90.py --seed 9 --da rotate 14 | 15 | echo "Done" 16 | -------------------------------------------------------------------------------- /rotated_MNIST/augmentations/model_baseline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class Net(nn.Module): 7 | def __init__(self): 8 | super(Net, self).__init__() 9 | self.encoder = nn.Sequential( 10 | nn.Conv2d(1, 32, kernel_size=5, stride=1, bias=False), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2, 2), 11 | nn.Conv2d(32, 64, kernel_size=5, stride=1, bias=False), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2, 2), 12 | ) 13 | self.classifier = nn.Sequential(nn.Linear(64 * 4 * 4, 128), 14 | nn.Dropout(0.5), 15 | nn.ReLU(), 16 | nn.Linear(128, 10)) 17 | 18 | def forward(self, x): 19 | h = self.encoder(x) 20 | h = torch.flatten(h, 1) 21 | h = self.classifier(h) 22 | output = F.log_softmax(h, dim=1) 23 | return output -------------------------------------------------------------------------------- /rotated_MNIST/choose_da_with_domain_classifier/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AMLab-Amsterdam/DataAugmentationInterventions/78ce67174db487e9697b9a842e69818305bb41ef/rotated_MNIST/choose_da_with_domain_classifier/__init__.py -------------------------------------------------------------------------------- /rotated_MNIST/choose_da_with_domain_classifier/evaulate_domain_experiments.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import re 4 | import numpy as np 5 | 6 | os.chdir('./') 7 | 8 | brightness_train = [] 9 | brightness_val = [] 10 | contrast_train = [] 11 | contrast_val = [] 12 | saturation_train = [] 13 | saturation_val = [] 14 | 
hue_train = [] 15 | hue_val = [] 16 | rotation_train = [] 17 | rotation_val = [] 18 | translate_train = [] 19 | translate_val = [] 20 | scale_train = [] 21 | scale_val = [] 22 | shear_train = [] 23 | shear_val = [] 24 | vflip_train = [] 25 | vflip_val = [] 26 | hflip_train = [] 27 | hflip_val = [] 28 | 29 | for file in glob.glob("*.txt"): 30 | if 'brightness' in file: 31 | with open(file) as f: 32 | content = f.readlines() 33 | train, val = re.findall(r"[-+]?\d*\.\d+|\d+", content[0]) 34 | train, val = float(train), float(val) 35 | brightness_train.append(train) 36 | brightness_val.append(val) 37 | 38 | if 'contrast' in file: 39 | with open(file) as f: 40 | content = f.readlines() 41 | train, val = re.findall(r"[-+]?\d*\.\d+|\d+", content[0]) 42 | train, val = float(train), float(val) 43 | contrast_train.append(train) 44 | contrast_val.append(val) 45 | 46 | if 'saturation' in file: 47 | with open(file) as f: 48 | content = f.readlines() 49 | train, val = re.findall(r"[-+]?\d*\.\d+|\d+", content[0]) 50 | train, val = float(train), float(val) 51 | saturation_train.append(train) 52 | saturation_val.append(val) 53 | 54 | if 'hue' in file: 55 | with open(file) as f: 56 | content = f.readlines() 57 | train, val = re.findall(r"[-+]?\d*\.\d+|\d+", content[0]) 58 | train, val = float(train), float(val) 59 | hue_train.append(train) 60 | hue_val.append(val) 61 | 62 | if 'rotation' in file: 63 | with open(file) as f: 64 | content = f.readlines() 65 | train, val = re.findall(r"[-+]?\d*\.\d+|\d+", content[0]) 66 | train, val = float(train), float(val) 67 | rotation_train.append(train) 68 | rotation_val.append(val) 69 | 70 | if 'translate' in file: 71 | with open(file) as f: 72 | content = f.readlines() 73 | train, val = re.findall(r"[-+]?\d*\.\d+|\d+", content[0]) 74 | train, val = float(train), float(val) 75 | translate_train.append(train) 76 | translate_val.append(val) 77 | 78 | if 'scale' in file: 79 | with open(file) as f: 80 | content = f.readlines() 81 | train, val = 
re.findall(r"[-+]?\d*\.\d+|\d+", content[0]) 82 | train, val = float(train), float(val) 83 | scale_train.append(train) 84 | scale_val.append(val) 85 | 86 | if 'shear' in file: 87 | with open(file) as f: 88 | content = f.readlines() 89 | train, val = re.findall(r"[-+]?\d*\.\d+|\d+", content[0]) 90 | train, val = float(train), float(val) 91 | shear_train.append(train) 92 | shear_val.append(val) 93 | 94 | if 'vflip' in file: 95 | with open(file) as f: 96 | content = f.readlines() 97 | train, val = re.findall(r"[-+]?\d*\.\d+|\d+", content[0]) 98 | train, val = float(train), float(val) 99 | vflip_train.append(train) 100 | vflip_val.append(val) 101 | 102 | if 'hflip' in file: 103 | with open(file) as f: 104 | content = f.readlines() 105 | train, val = re.findall(r"[-+]?\d*\.\d+|\d+", content[0]) 106 | train, val = float(train), float(val) 107 | hflip_train.append(train) 108 | hflip_val.append(val) 109 | 110 | brightness_train_std = np.std(np.array(brightness_train)) 111 | brightness_val_std = np.std(np.array(brightness_val)) 112 | contrast_train_std = np.std(np.array(contrast_train)) 113 | contrast_val_std = np.std(np.array(contrast_val)) 114 | saturation_train_std = np.std(np.array(saturation_train)) 115 | saturation_val_std = np.std(np.array(saturation_val)) 116 | hue_train_std = np.std(np.array(hue_train)) 117 | hue_val_std = np.std(np.array(hue_val)) 118 | rotation_train_std = np.std(np.array(rotation_train)) 119 | rotation_val_std = np.std(np.array(rotation_val)) 120 | translate_train_std = np.std(np.array(translate_train)) 121 | translate_val_std = np.std(np.array(translate_val)) 122 | scale_train_std = np.std(np.array(scale_train)) 123 | scale_val_std = np.std(np.array(scale_val)) 124 | shear_train_std = np.std(np.array(shear_train)) 125 | shear_val_std = np.std(np.array(shear_val)) 126 | vflip_train_std = np.std(np.array(vflip_train)) 127 | vflip_val_std = np.std(np.array(vflip_val)) 128 | hflip_train_std = np.std(np.array(hflip_train)) 129 | hflip_val_std = 
np.std(np.array(hflip_val)) 130 | 131 | 132 | brightness_train = np.mean(np.array(brightness_train)) 133 | brightness_val = np.mean(np.array(brightness_val)) 134 | contrast_train = np.mean(np.array(contrast_train)) 135 | contrast_val = np.mean(np.array(contrast_val)) 136 | saturation_train = np.mean(np.array(saturation_train)) 137 | saturation_val = np.mean(np.array(saturation_val)) 138 | hue_train = np.mean(np.array(hue_train)) 139 | hue_val = np.mean(np.array(hue_val)) 140 | rotation_train = np.mean(np.array(rotation_train)) 141 | rotation_val = np.mean(np.array(rotation_val)) 142 | translate_train = np.mean(np.array(translate_train)) 143 | translate_val = np.mean(np.array(translate_val)) 144 | scale_train = np.mean(np.array(scale_train)) 145 | scale_val = np.mean(np.array(scale_val)) 146 | shear_train = np.mean(np.array(shear_train)) 147 | shear_val = np.mean(np.array(shear_val)) 148 | vflip_train = np.mean(np.array(vflip_train)) 149 | vflip_val = np.mean(np.array(vflip_val)) 150 | hflip_train = np.mean(np.array(hflip_train)) 151 | hflip_val = np.mean(np.array(hflip_val)) 152 | 153 | print('mean') 154 | print(brightness_train) 155 | print(brightness_val) 156 | print(contrast_train) 157 | print(contrast_val) 158 | print(saturation_train) 159 | print(saturation_val) 160 | print(hue_train) 161 | print(hue_val) 162 | print(rotation_train) 163 | print(rotation_val) 164 | print(translate_train) 165 | print(translate_val) 166 | print(scale_train) 167 | print(scale_val) 168 | print(shear_train) 169 | print(shear_val) 170 | print(vflip_train) 171 | print(vflip_val) 172 | print(hflip_train) 173 | print(hflip_val) 174 | print('std') 175 | print(brightness_train_std) 176 | print(brightness_val_std) 177 | print(contrast_train_std) 178 | print(contrast_val_std) 179 | print(saturation_train_std) 180 | print(saturation_val_std) 181 | print(hue_train_std) 182 | print(hue_val_std) 183 | print(rotation_train_std) 184 | print(rotation_val_std) 185 | print(translate_train_std) 186 | 
print(translate_val_std) 187 | print(scale_train_std) 188 | print(scale_val_std) 189 | print(shear_train_std) 190 | print(shear_val_std) 191 | print(vflip_train_std) 192 | print(vflip_val_std) 193 | print(hflip_train_std) 194 | print(hflip_val_std) -------------------------------------------------------------------------------- /rotated_MNIST/choose_da_with_domain_classifier/experiment_test_0_only_rotation.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, "../../../") 3 | 4 | import argparse 5 | import torch 6 | import numpy as np 7 | import torch.nn.functional as F 8 | import torch.optim as optim 9 | import torch.utils.data as data_utils 10 | import torchvision 11 | import PIL 12 | import wandb 13 | 14 | from paper_experiments.rotated_MNIST.mnist_loader_shifted_label_distribution_rotate_da import MnistRotatedDistDa 15 | from paper_experiments.rotated_MNIST.augmentations.model_baseline import Net 16 | 17 | 18 | def train(args, model, device, train_loader, optimizer, epoch): 19 | model.train() 20 | for batch_idx, (data, _, domain) in enumerate(train_loader): 21 | data, domain = data.to(device), domain.to(device) 22 | _, domain = domain.max(dim=1) 23 | 24 | optimizer.zero_grad() 25 | output = model(data) 26 | loss = F.nll_loss(output, domain) 27 | loss.backward() 28 | optimizer.step() 29 | 30 | 31 | def test(args, model, device, test_loader): 32 | model.eval() 33 | test_loss = 0 34 | correct = 0 35 | with torch.no_grad(): 36 | for data, _, domain in test_loader: 37 | data, domain = data.to(device), domain.to(device) 38 | _, domain = domain.max(dim=1) 39 | 40 | output = model(data) 41 | test_loss += F.nll_loss(output, domain, reduction='sum').item() # sum up batch loss 42 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 43 | correct += pred.eq(domain.view_as(pred)).sum().item() 44 | 45 | test_loss /= len(test_loader.dataset) 46 | 47 | return test_loss, 100. 
* correct / len(test_loader.dataset) 48 | 49 | 50 | def main(): 51 | # Training settings 52 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 53 | parser.add_argument('--no-cuda', action='store_true', default=False, 54 | help='disables CUDA training') 55 | parser.add_argument('--seed', type=int, default=0, 56 | help='random seed (default: 1)') 57 | parser.add_argument('--batch-size', type=int, default=128, 58 | help='input batch size for training (default: 64)') 59 | parser.add_argument('--epochs', type=int, default=50, 60 | help='number of epochs to train (default: 10)') 61 | parser.add_argument('--lr', type=float, default=0.001, 62 | help='learning rate (default: 0.01)') 63 | parser.add_argument('--da', type=str, default='rotation1', choices=['rotation1', 64 | 'rotation2', 65 | 'rotation3', 66 | 'rotation4', 67 | 'rotation5', 68 | ]) 69 | parser.add_argument('-dd', '--data_dir', type=str, default='./data', 70 | help='Directory to download data to and load data from') 71 | parser.add_argument('-wd', '--wandb_dir', type=str, default='./', 72 | help='(OVERRIDDEN BY ENV_VAR for sweep) Directory to download data to and load data from') 73 | 74 | args = parser.parse_args() 75 | use_cuda = not args.no_cuda and torch.cuda.is_available() 76 | 77 | print(args.da) 78 | 79 | # Set seed 80 | torch.manual_seed(args.seed) 81 | np.random.seed(args.seed) 82 | 83 | device = torch.device("cuda") 84 | kwargs = {'num_workers': 8, 'pin_memory': True} if use_cuda else {} 85 | transform_dict = {'rotation1': torchvision.transforms.RandomAffine(15, translate=None, scale=None, shear=None), 86 | 'rotation2': torchvision.transforms.RandomAffine(45, translate=None, scale=None, shear=None), 87 | 'rotation3': torchvision.transforms.RandomAffine(90, translate=None, scale=None, 88 | shear=None), 89 | 'rotation4': torchvision.transforms.RandomAffine([0, 180], translate=None, scale=None, 90 | shear=None), 91 | 'rotation5': torchvision.transforms.RandomAffine([0, 359], 
translate=None, scale=None, 92 | shear=None) 93 | } 94 | 95 | rng_state = np.random.get_state() 96 | mnist_0_train = MnistRotatedDistDa('../dataset/', train=True, thetas=[0], d_label=0, transform=transform_dict[args.da], 97 | rng_state=rng_state) 98 | mnist_0_val = MnistRotatedDistDa('../dataset/', train=False, thetas=[0], d_label=0, transform=None, 99 | rng_state=rng_state) 100 | rng_state = np.random.get_state() 101 | mnist_30_train = MnistRotatedDistDa('../dataset/', train=True, thetas=[30.0], d_label=1, 102 | transform=transform_dict[args.da], rng_state=rng_state) 103 | mnist_30_val = MnistRotatedDistDa('../dataset/', train=False, thetas=[30.0], d_label=1, transform=None, 104 | rng_state=rng_state) 105 | rng_state = np.random.get_state() 106 | mnist_60_train = MnistRotatedDistDa('../dataset/', train=True, thetas=[60.0], d_label=2, 107 | transform=transform_dict[args.da], rng_state=rng_state) 108 | mnist_60_val = MnistRotatedDistDa('../dataset/', train=False, thetas=[60.0], d_label=2, 109 | transform=None, rng_state=rng_state) 110 | mnist_train = data_utils.ConcatDataset([mnist_0_train, mnist_30_train, mnist_60_train]) 111 | train_loader = data_utils.DataLoader(mnist_train, 112 | batch_size=100, 113 | shuffle=True, 114 | **kwargs) 115 | 116 | mnist_val = data_utils.ConcatDataset([mnist_0_val, mnist_30_val, mnist_60_val]) 117 | val_loader = data_utils.DataLoader(mnist_val, 118 | batch_size=100, 119 | shuffle=True, 120 | **kwargs) 121 | 122 | wandb.init(project="NewRotated_MNISTOnlyRot", config=args, name=args.da) 123 | model_name = 'baseline_test_0_'+ args.da +'_seed_' + str(args.seed) 124 | 125 | 126 | model = Net().to(device) 127 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 128 | 129 | train_accs = [] 130 | val_accs = [] 131 | for epoch in range(1, args.epochs + 1): 132 | print('\n Epoch: ' + str(epoch)) 133 | train(args, model, device, train_loader, optimizer, epoch) 134 | train_loss, train_acc = test(args, model, device, train_loader) 135 | 
val_loss, val_acc = test(args, model, device, val_loader) 136 | 137 | wandb.log({'train accuracy': train_acc, 'val accuracy': val_acc}) 138 | train_accs.append(train_acc) 139 | val_accs.append(val_acc) 140 | 141 | train_accs = np.array(train_accs) 142 | mean_train_accs = np.mean(train_accs[-10:]) 143 | print(mean_train_accs) 144 | 145 | val_accs = np.array(val_accs) 146 | mean_val_accs = np.mean(val_accs[-10:]) 147 | print(mean_val_accs) 148 | 149 | with open(model_name + '.txt', "w") as text_file: 150 | text_file.write("Mean train acc: " + str(mean_train_accs)) 151 | text_file.write("Mean val acc: " + str(mean_val_accs)) 152 | 153 | wandb.run.summary["mean_train_accs"] = mean_train_accs 154 | wandb.run.summary["mean_val_accs"] = mean_val_accs 155 | 156 | 157 | if __name__ == '__main__': 158 | main() -------------------------------------------------------------------------------- /rotated_MNIST/choose_da_with_domain_classifier/jobs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AMLab-Amsterdam/DataAugmentationInterventions/78ce67174db487e9697b9a842e69818305bb41ef/rotated_MNIST/choose_da_with_domain_classifier/jobs/__init__.py -------------------------------------------------------------------------------- /rotated_MNIST/choose_da_with_domain_classifier/jobs/test_0.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd ~/AdditionalDriveA/DeployedProjects/TwoLatentSpacesVAE/paper_experiments/rotated_MNIST/choose_da_with_domain_classifier 3 | echo "Starting" 4 | python experiment_test_0.py --seed 0 --da brightness 5 | python experiment_test_0.py --seed 0 --da contrast 6 | python experiment_test_0.py --seed 0 --da saturation 7 | python experiment_test_0.py --seed 0 --da hue 8 | python experiment_test_0.py --seed 0 --da rotation 9 | python experiment_test_0.py --seed 0 --da translate 10 | python experiment_test_0.py --seed 0 --da scale 11 
| python experiment_test_0.py --seed 0 --da shear 12 | python experiment_test_0.py --seed 0 --da none 13 | python experiment_test_0.py --seed 0 --da hflip 14 | python experiment_test_0.py --seed 0 --da vflip 15 | 16 | python experiment_test_0.py --seed 1 --da brightness 17 | python experiment_test_0.py --seed 1 --da contrast 18 | python experiment_test_0.py --seed 1 --da saturation 19 | python experiment_test_0.py --seed 1 --da hue 20 | python experiment_test_0.py --seed 1 --da rotation 21 | python experiment_test_0.py --seed 1 --da translate 22 | python experiment_test_0.py --seed 1 --da scale 23 | python experiment_test_0.py --seed 1 --da shear 24 | python experiment_test_0.py --seed 1 --da none 25 | python experiment_test_0.py --seed 1 --da hflip 26 | python experiment_test_0.py --seed 1 --da vflip 27 | 28 | python experiment_test_0.py --seed 2 --da brightness 29 | python experiment_test_0.py --seed 2 --da contrast 30 | python experiment_test_0.py --seed 2 --da saturation 31 | python experiment_test_0.py --seed 2 --da hue 32 | python experiment_test_0.py --seed 2 --da rotation 33 | python experiment_test_0.py --seed 2 --da translate 34 | python experiment_test_0.py --seed 2 --da scale 35 | python experiment_test_0.py --seed 2 --da shear 36 | python experiment_test_0.py --seed 2 --da none 37 | python experiment_test_0.py --seed 2 --da hflip 38 | python experiment_test_0.py --seed 2 --da vflip 39 | 40 | python experiment_test_0.py --seed 3 --da brightness 41 | python experiment_test_0.py --seed 3 --da contrast 42 | python experiment_test_0.py --seed 3 --da saturation 43 | python experiment_test_0.py --seed 3 --da hue 44 | python experiment_test_0.py --seed 3 --da rotation 45 | python experiment_test_0.py --seed 3 --da translate 46 | python experiment_test_0.py --seed 3 --da scale 47 | python experiment_test_0.py --seed 3 --da shear 48 | python experiment_test_0.py --seed 3 --da none 49 | python experiment_test_0.py --seed 3 --da hflip 50 | python 
experiment_test_0.py --seed 3 --da vflip 51 | 52 | python experiment_test_0.py --seed 4 --da brightness 53 | python experiment_test_0.py --seed 4 --da contrast 54 | python experiment_test_0.py --seed 4 --da saturation 55 | python experiment_test_0.py --seed 4 --da hue 56 | python experiment_test_0.py --seed 4 --da rotation 57 | python experiment_test_0.py --seed 4 --da translate 58 | python experiment_test_0.py --seed 4 --da scale 59 | python experiment_test_0.py --seed 4 --da shear 60 | python experiment_test_0.py --seed 4 --da none 61 | python experiment_test_0.py --seed 4 --da hflip 62 | python experiment_test_0.py --seed 4 --da vflip 63 | 64 | echo "Done" 65 | -------------------------------------------------------------------------------- /rotated_MNIST/choose_da_with_domain_classifier/jobs/test_0_rot_only.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd ~/AdditionalDriveA/DeployedProjects/TwoLatentSpacesVAE/paper_experiments/rotated_MNIST/choose_da_with_domain_classifier 3 | echo "Starting" 4 | python experiment_test_0_only_rotation.py --seed 0 --da rotation1 5 | python experiment_test_0_only_rotation.py --seed 0 --da rotation2 6 | python experiment_test_0_only_rotation.py --seed 0 --da rotation3 7 | python experiment_test_0_only_rotation.py --seed 0 --da rotation4 8 | python experiment_test_0_only_rotation.py --seed 0 --da rotation5 9 | 10 | python experiment_test_0_only_rotation.py --seed 1 --da rotation1 11 | python experiment_test_0_only_rotation.py --seed 1 --da rotation2 12 | python experiment_test_0_only_rotation.py --seed 1 --da rotation3 13 | python experiment_test_0_only_rotation.py --seed 1 --da rotation4 14 | python experiment_test_0_only_rotation.py --seed 1 --da rotation5 15 | 16 | python experiment_test_0_only_rotation.py --seed 2 --da rotation1 17 | python experiment_test_0_only_rotation.py --seed 2 --da rotation2 18 | python experiment_test_0_only_rotation.py --seed 2 --da rotation3 
19 | python experiment_test_0_only_rotation.py --seed 2 --da rotation4 20 | python experiment_test_0_only_rotation.py --seed 2 --da rotation5 21 | 22 | python experiment_test_0_only_rotation.py --seed 3 --da rotation1 23 | python experiment_test_0_only_rotation.py --seed 3 --da rotation2 24 | python experiment_test_0_only_rotation.py --seed 3 --da rotation3 25 | python experiment_test_0_only_rotation.py --seed 3 --da rotation4 26 | python experiment_test_0_only_rotation.py --seed 3 --da rotation5 27 | 28 | python experiment_test_0_only_rotation.py --seed 4 --da rotation1 29 | python experiment_test_0_only_rotation.py --seed 4 --da rotation2 30 | python experiment_test_0_only_rotation.py --seed 4 --da rotation3 31 | python experiment_test_0_only_rotation.py --seed 4 --da rotation4 32 | python experiment_test_0_only_rotation.py --seed 4 --da rotation5 33 | 34 | 35 | echo "Done" 36 | -------------------------------------------------------------------------------- /rotated_MNIST/domain_classifier/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AMLab-Amsterdam/DataAugmentationInterventions/78ce67174db487e9697b9a842e69818305bb41ef/rotated_MNIST/domain_classifier/__init__.py -------------------------------------------------------------------------------- /rotated_MNIST/domain_classifier/domain_classifier.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, "../../../") 3 | 4 | import argparse 5 | import torch 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | import torch.utils.data as data_utils 9 | import torch.nn as nn 10 | 11 | import numpy as np 12 | 13 | from paper_experiments.rotated_MNIST.mnist_loader_shifted_label_distribution_rotate import MnistRotatedDist 14 | from paper_experiments.rotated_MNIST.mnist_loader import MnistRotated 15 | 16 | 17 | class Net(nn.Module): 18 | def __init__(self): 19 
| super(Net, self).__init__() 20 | self.encoder = nn.Sequential( 21 | nn.Conv2d(1, 32, kernel_size=5, stride=1, bias=False), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2, 2), 22 | nn.Conv2d(32, 64, kernel_size=5, stride=1, bias=False), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2, 2), 23 | ) 24 | self.classifier = nn.Sequential(nn.Linear(64 * 4 * 4, 128), 25 | nn.Dropout(0.5), 26 | nn.ReLU(), 27 | nn.Linear(128, 4)) 28 | 29 | def forward(self, x): 30 | h = self.encoder(x) 31 | h = torch.flatten(h, 1) 32 | h = self.classifier(h) 33 | output = F.log_softmax(h, dim=1) 34 | return output 35 | 36 | 37 | def train(args, model, device, train_loader, optimizer, epoch): 38 | model.train() 39 | for batch_idx, (data, _, domain) in enumerate(train_loader): 40 | data, domain = data.to(device), domain.to(device) 41 | _, domain = domain.max(dim=1) 42 | 43 | optimizer.zero_grad() 44 | output = model(data) 45 | loss = F.nll_loss(output, domain) 46 | loss.backward() 47 | optimizer.step() 48 | 49 | 50 | def test(args, model, device, test_loader): 51 | model.eval() 52 | test_loss = 0 53 | correct = 0 54 | with torch.no_grad(): 55 | for data, _, domain in test_loader: 56 | data, domain = data.to(device), domain.to(device) 57 | _, domain = domain.max(dim=1) 58 | 59 | output = model(data) 60 | test_loss += F.nll_loss(output, domain, reduction='sum').item() # sum up batch loss 61 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 62 | correct += pred.eq(domain.view_as(pred)).sum().item() 63 | 64 | test_loss /= len(test_loader.dataset) 65 | 66 | return test_loss, 100. 
* correct / len(test_loader.dataset) 67 | 68 | 69 | def main(): 70 | # Training settings 71 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 72 | parser.add_argument('--no-cuda', action='store_true', default=False, 73 | help='disables CUDA training') 74 | parser.add_argument('--seed', type=int, default=0, 75 | help='random seed (default: 1)') 76 | parser.add_argument('--batch-size', type=int, default=128, 77 | help='input batch size for training (default: 64)') 78 | parser.add_argument('--epochs', type=int, default=200, 79 | help='number of epochs to train (default: 10)') 80 | parser.add_argument('--lr', type=float, default=0.001, 81 | help='learning rate (default: 0.01)') 82 | parser.add_argument('--da', type=str, default='none', choices=['none', 'rotate', 'flip'], 83 | help='type of data augmentation') 84 | 85 | args = parser.parse_args() 86 | use_cuda = not args.no_cuda and torch.cuda.is_available() 87 | 88 | # Set seed 89 | torch.manual_seed(args.seed) 90 | torch.backends.cudnn.benchmark = False 91 | np.random.seed(args.seed) 92 | 93 | device = torch.device("cuda") 94 | kwargs = {'num_workers': 8, 'pin_memory': True} if use_cuda else {} 95 | 96 | # Load supervised training 97 | if args.da == 'none': 98 | mnist_0 = MnistRotatedDist('../dataset/', train=True, thetas=[0.0], d_label=0, transform=False) 99 | mnist_30 = MnistRotatedDist('../dataset/', train=True, thetas=[30.0], d_label=1, transform=False) 100 | mnist_60 = MnistRotatedDist('../dataset/', train=True, thetas=[60.0], d_label=2, transform=False) 101 | mnist_90 = MnistRotatedDist('../dataset/', train=True, thetas=[90.0], d_label=3, transform=False) 102 | model_name = 'domain_classifier_none_seed_' + str(args.seed) 103 | 104 | elif args.da == 'rotate': 105 | mnist_0 = MnistRotatedDist('../dataset/', train=True, thetas=[0.0], d_label=0, transform=True) 106 | mnist_30 = MnistRotatedDist('../dataset/', train=True, thetas=[30.0], d_label=1, transform=True) 107 | mnist_60 = 
MnistRotatedDist('../dataset/', train=True, thetas=[60.0], d_label=2, transform=True) 108 | mnist_90 = MnistRotatedDist('../dataset/', train=True, thetas=[90.0], d_label=3, transform=True) 109 | model_name = 'domain_classifier_rotate_seed_' + str(args.seed) 110 | 111 | mnist = data_utils.ConcatDataset([mnist_0, mnist_30, mnist_60, mnist_90]) 112 | 113 | train_size = int(0.9 * len(mnist)) 114 | val_size = len(mnist) - train_size 115 | train_dataset, val_dataset = torch.utils.data.random_split(mnist, [train_size, val_size]) 116 | 117 | train_loader = data_utils.DataLoader(train_dataset, 118 | batch_size=args.batch_size, 119 | shuffle=True, **kwargs) 120 | 121 | val_loader = data_utils.DataLoader(val_dataset, 122 | batch_size=args.batch_size, 123 | shuffle=False, **kwargs) 124 | 125 | model = Net().to(device) 126 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 127 | 128 | best_val_acc = 0 129 | 130 | for epoch in range(1, args.epochs + 1): 131 | print('\n Epoch: ' + str(epoch)) 132 | train(args, model, device, train_loader, optimizer, epoch) 133 | val_loss, val_acc = test(args, model, device, val_loader) 134 | 135 | print(epoch, val_loss, val_acc) 136 | 137 | # Save best 138 | if val_acc >= best_val_acc: 139 | best_val_acc = val_acc 140 | 141 | torch.save(model, model_name + '.model') 142 | torch.save(args, model_name + '.config') 143 | 144 | # Test loader 145 | mnist_0 = MnistRotated('../dataset/', train=False, thetas=[0.0], d_label=0) 146 | mnist_30 = MnistRotated('../dataset/', train=False, thetas=[30.0], d_label=1) 147 | mnist_60 = MnistRotated('../dataset/', train=False, thetas=[60.0], d_label=2) 148 | mnist_90 = MnistRotated('../dataset/', train=False, thetas=[90.0], d_label=3) 149 | 150 | mnist = data_utils.ConcatDataset([mnist_0, mnist_30, mnist_60, mnist_90]) 151 | 152 | test_loader = data_utils.DataLoader(mnist, 153 | batch_size=args.batch_size, 154 | shuffle=False, **kwargs) 155 | 156 | model = torch.load(model_name + '.model').to(device) 157 | _, 
test_acc = test(args, model, device, test_loader) 158 | 159 | with open(model_name + '.txt', "w") as text_file: 160 | text_file.write("Test Acc: " + str(test_acc)) 161 | 162 | 163 | if __name__ == '__main__': 164 | main() -------------------------------------------------------------------------------- /rotated_MNIST/domain_classifier/domain_classifier_none.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | #SBATCH --job-name=baseline 3 | #SBATCH --partition=gpu 4 | #SBATCH --ntasks=1 5 | #SBATCH --gres=gpu:1 6 | #SBATCH --verbose 7 | #SBATCH -t 08:00:00 8 | 9 | cd ~/AdditionalDriveA/DeployedProjects/TwoLatentSpacesVAE/paper_experiments/rotated_MNIST/domain_classifier 10 | 11 | source activate twoVAE 12 | 13 | echo "Starting" 14 | python domain_classifier.py --da none --seed 0 15 | python domain_classifier.py --da none --seed 1 16 | python domain_classifier.py --da none --seed 2 17 | python domain_classifier.py --da none --seed 3 18 | python domain_classifier.py --da none --seed 4 19 | python domain_classifier.py --da none --seed 5 20 | python domain_classifier.py --da none --seed 6 21 | python domain_classifier.py --da none --seed 7 22 | python domain_classifier.py --da none --seed 8 23 | python domain_classifier.py --da none --seed 9 24 | 25 | wait # Waits for parallel jobs to finish 26 | echo "Done" 27 | -------------------------------------------------------------------------------- /rotated_MNIST/domain_classifier/domain_classifier_rotate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | #SBATCH --job-name=baseline 3 | #SBATCH --partition=gpu 4 | #SBATCH --ntasks=1 5 | #SBATCH --gres=gpu:1 6 | #SBATCH --verbose 7 | #SBATCH -t 08:00:00 8 | 9 | cd ~/AdditionalDriveA/DeployedProjects/TwoLatentSpacesVAE/paper_experiments/rotated_MNIST/domain_classifier 10 | 11 | source activate twoVAE 12 | 13 | echo "Starting" 14 | python domain_classifier.py --da rotate --seed 
0 15 | python domain_classifier.py --da rotate --seed 1 16 | python domain_classifier.py --da rotate --seed 2 17 | python domain_classifier.py --da rotate --seed 3 18 | python domain_classifier.py --da rotate --seed 4 19 | python domain_classifier.py --da rotate --seed 5 20 | python domain_classifier.py --da rotate --seed 6 21 | python domain_classifier.py --da rotate --seed 7 22 | python domain_classifier.py --da rotate --seed 8 23 | python domain_classifier.py --da rotate --seed 9 24 | 25 | wait # Waits for parallel jobs to finish 26 | echo "Done" 27 | -------------------------------------------------------------------------------- /rotated_MNIST/mnist_loader.py: -------------------------------------------------------------------------------- 1 | """Pytorch Dataset object that loads MNIST and SVHN. It returns x,y,s where s=0 when x,y is taken from MNIST.""" 2 | 3 | import os 4 | import numpy as np 5 | import torch 6 | import torch.utils.data as data_utils 7 | from torchvision import datasets, transforms 8 | 9 | 10 | class MnistRotated(data_utils.Dataset): 11 | def __init__(self, root, train=True, thetas=[0], d_label=0, download=True, transform=False): 12 | self.root = os.path.expanduser(root) 13 | self.train = train 14 | self.thetas = thetas 15 | self.d_label = d_label 16 | self.download = download 17 | self.transform = transform 18 | 19 | self.to_pil = transforms.ToPILImage() 20 | self.to_tensor = transforms.ToTensor() 21 | self.y_to_categorical = torch.eye(10) 22 | self.d_to_categorical = torch.eye(4) 23 | 24 | self.imgs, self.labels = self._get_data() 25 | 26 | def _get_data(self): 27 | mnist_loader = torch.utils.data.DataLoader(datasets.MNIST(self.root, 28 | train=self.train, 29 | download=self.download, 30 | transform=transforms.ToTensor()), 31 | batch_size=60000, 32 | shuffle=False) 33 | 34 | for i, (x, y) in enumerate(mnist_loader): 35 | mnist_imgs = x 36 | mnist_labels = y 37 | 38 | pil_list = [] 39 | for x in mnist_imgs: 40 | 
pil_list.append(self.to_pil(x)) 41 | 42 | return pil_list, mnist_labels 43 | 44 | def __len__(self): 45 | return len(self.labels) 46 | 47 | def __getitem__(self, index): 48 | x = self.imgs[index] 49 | y = self.labels[index] 50 | 51 | d = np.random.choice(range(len(self.thetas))) 52 | 53 | if self.transform: # data augmentation random rotation by +- 90 degrees 54 | pass 55 | # random_rotation = np.random.randint(0, 360, 1) 56 | # return self.to_tensor(transforms.functional.rotate(x, self.thetas[d] + random_rotation)), self.y_to_categorical[y], \ 57 | # self.d_to_categorical[self.d_label] 58 | else: 59 | return self.to_tensor(transforms.functional.rotate(x, self.thetas[d])), self.y_to_categorical[y], self.d_to_categorical[self.d_label] 60 | 61 | 62 | if __name__ == "__main__": 63 | from torchvision.utils import save_image 64 | 65 | seed = 0 66 | 67 | torch.manual_seed(seed) 68 | torch.backends.cudnn.benchmark = False 69 | np.random.seed(seed) 70 | 71 | mnist_30 = MnistRotated('../dataset/', train=True, thetas=[30.0], d_label=0, transform=False) 72 | mnist_60 = MnistRotated('../dataset/', train=True, thetas=[60.0], d_label=1, transform=False) 73 | mnist_90 = MnistRotated('../dataset/', train=True, thetas=[90.0], d_label=2, transform=False) 74 | 75 | mnist = data_utils.ConcatDataset([mnist_30, mnist_60, mnist_90]) 76 | 77 | train_loader = data_utils.DataLoader(mnist, 78 | batch_size=100, 79 | shuffle=True) 80 | 81 | for i, (x, y, d) in enumerate(train_loader): 82 | _, d = d.max(dim=1) 83 | 84 | 85 | # y = y.argmax(-1) 86 | # 87 | # index = y == 5 88 | # x = x[index] 89 | 90 | save_image(x.cpu(), 91 | 'rotated_mnist.png', nrow=1) 92 | 93 | print(d) 94 | break -------------------------------------------------------------------------------- /rotated_MNIST/mnist_loader_da.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.utils.data as data_utils 5 | from torchvision 
import datasets, transforms 6 | 7 | 8 | class MnistRotatedDistDa(data_utils.Dataset): 9 | def __init__(self, root, train=True, thetas=[0], d_label=0, download=True, transform=None, rng_state=0): 10 | self.root = os.path.expanduser(root) 11 | self.train = train 12 | self.thetas = thetas 13 | self.d_label = d_label 14 | self.download = download 15 | self.transform = transform 16 | self.rng_state = rng_state 17 | 18 | self.to_pil = transforms.ToPILImage() 19 | self.to_tensor = transforms.ToTensor() 20 | self.y_to_categorical = torch.eye(10) 21 | self.d_to_categorical = torch.eye(4) 22 | 23 | self.imgs, self.labels = self._get_data() 24 | len_train = int(0.8*len(self.imgs)) 25 | 26 | if self.train: 27 | self.imgs = self.imgs[:len_train] 28 | self.labels = self.labels[:len_train] 29 | else: 30 | self.imgs = self.imgs[len_train:] 31 | self.labels = self.labels[len_train:] 32 | 33 | def _get_data(self): 34 | mnist_loader = torch.utils.data.DataLoader(datasets.MNIST(self.root, 35 | train=True, 36 | download=self.download, 37 | transform=transforms.ToTensor()), 38 | batch_size=60000, 39 | shuffle=False) 40 | 41 | for i, (x, y) in enumerate(mnist_loader): 42 | mnist_imgs = x 43 | mnist_labels = y 44 | 45 | # Get 10 random ints between 80 and 160 46 | np.random.set_state(self.rng_state) 47 | label_dist = np.random.randint(80, 160, 10) 48 | 49 | mnist_imgs_dist, mnist_labels_dist = [], [] 50 | for i in range(10): 51 | idx = np.where(mnist_labels == i)[0] 52 | np.random.shuffle(idx) 53 | idx = idx[:label_dist[i]] # select the right amount of labels for each class 54 | mnist_imgs_dist.append(mnist_imgs[idx]) 55 | mnist_labels_dist.append(mnist_labels[idx]) 56 | 57 | mnist_imgs_dist = torch.cat(mnist_imgs_dist) 58 | mnist_labels_dist = torch.cat(mnist_labels_dist) 59 | 60 | pil_list = [] 61 | for x in mnist_imgs_dist: 62 | pil_list.append(self.to_pil(x)) 63 | 64 | return pil_list, mnist_labels_dist 65 | 66 | def __len__(self): 67 | return len(self.labels) 68 | 69 | def 
__getitem__(self, index): 70 | x = self.imgs[index] 71 | y = self.labels[index] 72 | 73 | d = np.random.choice(range(len(self.thetas))) 74 | 75 | if self.transform is not None: # data augmentation 76 | return self.to_tensor(self.transform(transforms.functional.rotate(x, self.thetas[d]))), self.y_to_categorical[y], self.d_to_categorical[self.d_label] 77 | else: 78 | return self.to_tensor(transforms.functional.rotate(x, self.thetas[d])), self.y_to_categorical[y], self.d_to_categorical[self.d_label] -------------------------------------------------------------------------------- /rotated_MNIST/mnist_loader_shifted_label_distribution_all_da.py: -------------------------------------------------------------------------------- 1 | """Pytorch Dataset object that loads MNIST and SVHN. It returns x,y,s where s=0 when x,y is taken from MNIST.""" 2 | 3 | import os 4 | import numpy as np 5 | import torch 6 | import torch.utils.data as data_utils 7 | from torchvision import datasets, transforms 8 | import torchvision 9 | import PIL 10 | 11 | 12 | class MnistAllDaDist(data_utils.Dataset): 13 | def __init__(self, root, train=True, thetas=[0], d_label=0, download=True): 14 | self.root = os.path.expanduser(root) 15 | self.train = train 16 | self.thetas = thetas 17 | self.d_label = d_label 18 | self.download = download 19 | transform_dict = { 20 | 'brightness': torchvision.transforms.ColorJitter(brightness=1.0, contrast=0, saturation=0, hue=0), 21 | 'contrast': torchvision.transforms.ColorJitter(brightness=0, contrast=1.0, saturation=0, hue=0), 22 | 'saturation': torchvision.transforms.ColorJitter(brightness=0, contrast=0, saturation=1.0, hue=0), 23 | 'hue': torchvision.transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0.5), 24 | 'rotation': torchvision.transforms.RandomAffine([0, 359], translate=None, scale=None, shear=None, 25 | resample=PIL.Image.BILINEAR, fillcolor=0), 26 | 'translate': torchvision.transforms.RandomAffine(0, translate=[0.2, 0.2], scale=None, 
shear=None, 27 | resample=PIL.Image.BILINEAR, fillcolor=0), 28 | 'scale': torchvision.transforms.RandomAffine(0, translate=None, scale=[0.8, 1.2], shear=None, 29 | resample=PIL.Image.BILINEAR, fillcolor=0), 30 | 'shear': torchvision.transforms.RandomAffine(0, translate=None, scale=None, 31 | shear=[-10., 10., -10., 10.], 32 | resample=PIL.Image.BILINEAR, fillcolor=0), 33 | 'vflip': torchvision.transforms.RandomVerticalFlip(p=0.5), 34 | 'hflip': torchvision.transforms.RandomHorizontalFlip(p=0.5), 35 | 'none': None, 36 | } 37 | 38 | self.transforms = torchvision.transforms.Compose([transform_dict['brightness'], 39 | transform_dict['contrast'], 40 | transform_dict['saturation'], 41 | transform_dict['hue'], 42 | transform_dict['rotation'], 43 | transform_dict['translate'], 44 | transform_dict['scale'], 45 | transform_dict['shear'], 46 | transform_dict['vflip'], 47 | transform_dict['hflip']]) 48 | 49 | self.to_pil = transforms.ToPILImage() 50 | self.to_tensor = transforms.ToTensor() 51 | self.y_to_categorical = torch.eye(10) 52 | self.d_to_categorical = torch.eye(4) 53 | 54 | self.imgs, self.labels = self._get_data() 55 | 56 | 57 | def _get_data(self): 58 | mnist_loader = torch.utils.data.DataLoader(datasets.MNIST(self.root, 59 | train=self.train, 60 | download=self.download, 61 | transform=transforms.ToTensor()), 62 | batch_size=60000, 63 | shuffle=False) 64 | 65 | for i, (x, y) in enumerate(mnist_loader): 66 | mnist_imgs = x 67 | mnist_labels = y 68 | 69 | # Get 10 random ints between 80 and 160 70 | label_dist = np.random.randint(80, 160, 10) 71 | 72 | mnist_imgs_dist, mnist_labels_dist = [], [] 73 | for i in range(10): 74 | idx = np.where(mnist_labels == i)[0] 75 | np.random.shuffle(idx) 76 | idx = idx[:label_dist[i]] # select the right amount of labels for each class 77 | mnist_imgs_dist.append(mnist_imgs[idx]) 78 | mnist_labels_dist.append(mnist_labels[idx]) 79 | 80 | mnist_imgs_dist = torch.cat(mnist_imgs_dist) 81 | mnist_labels_dist = 
torch.cat(mnist_labels_dist) 82 | 83 | pil_list = [] 84 | for x in mnist_imgs_dist: 85 | pil_list.append(self.to_pil(x)) 86 | 87 | return pil_list, mnist_labels_dist 88 | 89 | def __len__(self): 90 | return len(self.labels) 91 | 92 | def __getitem__(self, index): 93 | x = self.imgs[index] 94 | y = self.labels[index] 95 | 96 | d = np.random.choice(range(len(self.thetas))) 97 | 98 | return self.to_tensor(self.transforms(transforms.functional.rotate(x, self.thetas[d]))), self.y_to_categorical[y], self.d_to_categorical[self.d_label] 99 | -------------------------------------------------------------------------------- /rotated_MNIST/mnist_loader_shifted_label_distribution_rotate.py: -------------------------------------------------------------------------------- 1 | """Pytorch Dataset object that loads MNIST and SVHN. It returns x,y,s where s=0 when x,y is taken from MNIST.""" 2 | 3 | import os 4 | import numpy as np 5 | import torch 6 | import torch.utils.data as data_utils 7 | from torchvision import datasets, transforms 8 | 9 | 10 | class MnistRotatedDist(data_utils.Dataset): 11 | def __init__(self, root, train=True, thetas=[0], d_label=0, download=True, transform=False): 12 | self.root = os.path.expanduser(root) 13 | self.train = train 14 | self.thetas = thetas 15 | self.d_label = d_label 16 | self.download = download 17 | self.transform = transform 18 | 19 | self.to_pil = transforms.ToPILImage() 20 | self.to_tensor = transforms.ToTensor() 21 | self.y_to_categorical = torch.eye(10) 22 | self.d_to_categorical = torch.eye(4) 23 | 24 | self.imgs, self.labels = self._get_data() 25 | 26 | 27 | def _get_data(self): 28 | mnist_loader = torch.utils.data.DataLoader(datasets.MNIST(self.root, 29 | train=self.train, 30 | download=self.download, 31 | transform=transforms.ToTensor()), 32 | batch_size=60000, 33 | shuffle=False) 34 | 35 | for i, (x, y) in enumerate(mnist_loader): 36 | mnist_imgs = x 37 | mnist_labels = y 38 | 39 | # Get 10 random ints between 80 and 160 40 | 
label_dist = np.random.randint(80, 160, 10) 41 | 42 | mnist_imgs_dist, mnist_labels_dist = [], [] 43 | for i in range(10): 44 | idx = np.where(mnist_labels == i)[0] 45 | np.random.shuffle(idx) 46 | idx = idx[:label_dist[i]] # select the right amount of labels for each class 47 | mnist_imgs_dist.append(mnist_imgs[idx]) 48 | mnist_labels_dist.append(mnist_labels[idx]) 49 | 50 | mnist_imgs_dist = torch.cat(mnist_imgs_dist) 51 | mnist_labels_dist = torch.cat(mnist_labels_dist) 52 | 53 | pil_list = [] 54 | for x in mnist_imgs_dist: 55 | pil_list.append(self.to_pil(x)) 56 | 57 | return pil_list, mnist_labels_dist 58 | 59 | def __len__(self): 60 | return len(self.labels) 61 | 62 | def __getitem__(self, index): 63 | x = self.imgs[index] 64 | y = self.labels[index] 65 | 66 | d = np.random.choice(range(len(self.thetas))) 67 | 68 | if self.transform: # data augmentation, random rotation by 90 degrees 69 | random_rotation = np.random.randint(0, 360, 1) 70 | return self.to_tensor(transforms.functional.rotate(x, self.thetas[d] + random_rotation)), self.y_to_categorical[y], self.d_to_categorical[self.d_label] 71 | else: 72 | return self.to_tensor(transforms.functional.rotate(x, self.thetas[d])), self.y_to_categorical[y], self.d_to_categorical[self.d_label] 73 | 74 | 75 | if __name__ == "__main__": 76 | from torchvision.utils import save_image 77 | 78 | seed = 0 79 | 80 | torch.manual_seed(seed) 81 | torch.backends.cudnn.benchmark = False 82 | np.random.seed(seed) 83 | 84 | mnist_0 = MnistRotatedDist('../dataset/', train=True, thetas=[0], d_label=0, transform=True) 85 | mnist_30 = MnistRotatedDist('../dataset/', train=True, thetas=[30.0], d_label=1, transform=True) 86 | mnist_60 = MnistRotatedDist('../dataset/', train=True, thetas=[60.0], d_label=2, transform=True) 87 | mnist = data_utils.ConcatDataset([mnist_0, mnist_30, mnist_60]) 88 | train_loader = data_utils.DataLoader(mnist, 89 | batch_size=100, 90 | shuffle=True) 91 | 92 | y_array = np.zeros(10) 93 | d_array = np.zeros(3) 
94 | 95 | for i, (x, y, d) in enumerate(train_loader): 96 | y_array += y.sum(dim=0).cpu().numpy() 97 | d_array += d.sum(dim=0).cpu().numpy() 98 | 99 | if i == 0: 100 | n = min(x.size(0), 36) 101 | comparison = x[:n].view(-1, 1, 28, 28) 102 | save_image(comparison.cpu(), 103 | 'rotated_mnist_dist.png', nrow=6) 104 | 105 | print(y_array, d_array) -------------------------------------------------------------------------------- /rotated_MNIST/mnist_loader_shifted_label_distribution_rotate_da.py: -------------------------------------------------------------------------------- 1 | """Pytorch Dataset object that loads MNIST and SVHN. It returns x,y,s where s=0 when x,y is taken from MNIST.""" 2 | 3 | import os 4 | import numpy as np 5 | import torch 6 | import torch.utils.data as data_utils 7 | from torchvision import datasets, transforms 8 | from paper_experiments.rotated_MNIST.mnist_loader_da import MnistRotatedDistDa 9 | 10 | 11 | if __name__ == "__main__": 12 | from torchvision.utils import save_image 13 | import torchvision 14 | import PIL 15 | 16 | seed = 0 17 | da = 'hflip' 18 | 19 | torch.manual_seed(seed) 20 | torch.backends.cudnn.benchmark = False 21 | np.random.seed(seed) 22 | 23 | transform_dict = {'brightness': torchvision.transforms.ColorJitter(brightness=1.0, contrast=0, saturation=0, hue=0), 24 | 'contrast': torchvision.transforms.ColorJitter(brightness=0, contrast=1.0, saturation=0, hue=0), 25 | 'saturation': torchvision.transforms.ColorJitter(brightness=0, contrast=0, saturation=1.0, hue=0), 26 | 'hue': torchvision.transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0.5), 27 | 'rotation': torchvision.transforms.RandomAffine([0, 359], translate=None, scale=None, shear=None, 28 | resample=PIL.Image.BILINEAR, fillcolor=0), 29 | 'translate': torchvision.transforms.RandomAffine(0, translate=[0.2, 0.2], scale=None, shear=None, 30 | resample=PIL.Image.BILINEAR, fillcolor=0), 31 | 'scale': torchvision.transforms.RandomAffine(0, translate=None, 
scale=[0.8, 1.2], shear=None, 32 | resample=PIL.Image.BILINEAR, fillcolor=0), 33 | 'shear': torchvision.transforms.RandomAffine(0, translate=None, scale=None, 34 | shear=[-10., 10., -10., 10.], 35 | resample=PIL.Image.BILINEAR, fillcolor=0), 36 | 'vflip': torchvision.transforms.RandomVerticalFlip(p=0.5), 37 | 'hflip': torchvision.transforms.RandomHorizontalFlip(p=0.5), 38 | 'none': None, 39 | } 40 | 41 | rng_state = np.random.get_state() 42 | mnist_0_train = MnistRotatedDistDa('../dataset/', train=True, thetas=[0], d_label=0, transform=transform_dict[da], rng_state=rng_state) 43 | mnist_0_val = MnistRotatedDistDa('../dataset/', train=False, thetas=[0], d_label=0, transform=None, 44 | rng_state=rng_state) 45 | rng_state = np.random.get_state() 46 | mnist_30_train = MnistRotatedDistDa('../dataset/', train=True, thetas=[30.0], d_label=1, transform=transform_dict[da], rng_state=rng_state) 47 | mnist_30_val = MnistRotatedDistDa('../dataset/', train=False, thetas=[30.0], d_label=1, transform=None, 48 | rng_state=rng_state) 49 | rng_state = np.random.get_state() 50 | mnist_60_train = MnistRotatedDistDa('../dataset/', train=True, thetas=[60.0], d_label=2, transform=transform_dict[da], rng_state=rng_state) 51 | mnist_60_val = MnistRotatedDistDa('../dataset/', train=False, thetas=[60.0], d_label=2, 52 | transform=transform_dict[da], rng_state=rng_state) 53 | mnist_train = data_utils.ConcatDataset([mnist_0_train, mnist_30_train, mnist_60_train]) 54 | train_loader = data_utils.DataLoader(mnist_train, 55 | batch_size=100, 56 | shuffle=True) 57 | 58 | mnist_val = data_utils.ConcatDataset([mnist_0_val, mnist_30_val, mnist_60_val]) 59 | val_loader = data_utils.DataLoader(mnist_val, 60 | batch_size=100, 61 | shuffle=True) 62 | 63 | y_array = np.zeros(10) 64 | d_array = np.zeros(3) 65 | 66 | for i, (x, y, d) in enumerate(train_loader): 67 | # y_array += y.sum(dim=0).cpu().numpy() 68 | # d_array += d.sum(dim=0).cpu().numpy() 69 | 70 | if i == 0: 71 | n = min(x.size(0), 36) 72 | 
comparison = x[:n].view(-1, 1, 28, 28) 73 | save_image(comparison.cpu(), 74 | 'rotated_mnist_' + da + '.png', nrow=6) 75 | 76 | # print(y_array, d_array) -------------------------------------------------------------------------------- /synthetic_data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AMLab-Amsterdam/DataAugmentationInterventions/78ce67174db487e9697b9a842e69818305bb41ef/synthetic_data/__init__.py -------------------------------------------------------------------------------- /synthetic_data/domain_gen_sem.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class DomainGenToyData(object): # toy linear SEM: confounder c -> (y, d), then d -> x1 and y -> x2 5 | def __init__(self, 6 | dim, 7 | transform_c=False, 8 | transform_y_and_d=False, 9 | additional_noise_y_and_d=False, 10 | additional_noise_x1_and_x2=False): 11 | 12 | self.transform_c = transform_c 13 | self.transform_y_and_d = transform_y_and_d 14 | self.additional_noise_y_and_d = additional_noise_y_and_d 15 | self.additional_noise_x1_and_x2 = additional_noise_x1_and_x2 16 | 17 | # dim is the number of dimensions of x 18 | self.dim_half = dim//2 # NOTE(review): an odd dim silently drops one feature; x has 2*(dim//2) columns 19 | 20 | # Linear transformation c to y and d 21 | if transform_c: 22 | self.Wcd = torch.randn(self.dim_half, self.dim_half) / dim 23 | self.Wcy = torch.randn(self.dim_half, self.dim_half) / dim 24 | else: 25 | self.Wcd = torch.eye(self.dim_half) 26 | self.Wcy = torch.eye(self.dim_half) 27 | 28 | # Linear transformation from d to x1 and y to x2 29 | if transform_y_and_d: 30 | self.Wdx1 = torch.randn(self.dim_half, self.dim_half) / dim 31 | self.Wyx2 = torch.randn(self.dim_half, self.dim_half) / dim 32 | else: 33 | self.Wdx1 = torch.eye(self.dim_half) 34 | self.Wyx2 = torch.eye(self.dim_half) 35 | 36 | def __call__(self, N, train=True): # draw N samples; returns numpy (x = [x1, x2], y summed over its dims, d summed over its dims) 37 | 38 | # Add noise to y and d 39 | if self.additional_noise_y_and_d: 40 | noise_y = torch.randn(N, self.dim_half)*0.1 41 | noise_d =
torch.randn(N, self.dim_half)*0.1 42 | else: 43 | noise_y = 0 44 | noise_d = 0 45 | 46 | # Add noise to x1 and x2 47 | if self.additional_noise_x1_and_x2: 48 | noise_x1 = torch.randn(N, self.dim_half)*0.1 49 | noise_x2 = torch.randn(N, self.dim_half)*0.1 50 | else: 51 | noise_x1 = 0 52 | noise_x2 = 0 53 | 54 | # Sample values for confounder 55 | c = torch.randn(N, self.dim_half) 56 | 57 | # Compute d 58 | if train: # with train=False, d is resampled independently of c, which breaks the c -> d link 59 | d = c @ self.Wcd + noise_d 60 | 61 | else: 62 | d = torch.randn(N, self.dim_half) 63 | 64 | # Compute y 65 | y = c @ self.Wcy + noise_y 66 | 67 | # Compute x1 and x2 68 | x1 = d @ self.Wdx1 + noise_x1 69 | x2 = y @ self.Wyx2 + noise_x2 70 | 71 | return torch.cat((x1, x2), dim=1).numpy(), y.sum(dim=1).numpy(), d.sum(dim=1).numpy() 72 | 73 | 74 | if __name__ == "__main__": 75 | import matplotlib.pyplot as plt 76 | 77 | seed = 0 78 | torch.manual_seed(seed) 79 | torch.cuda.manual_seed(seed) 80 | 81 | number_of_samples = 1000 82 | dim = 10 83 | 84 | data_loader = DomainGenToyData(dim, 85 | transform_c=False, 86 | transform_y_and_d=False, 87 | additional_noise_y_and_d=True, 88 | additional_noise_x1_and_x2=True) 89 | 90 | x, y, d = data_loader(number_of_samples, train=True) 91 | 92 | plt.figure() 93 | plt.scatter(y, d) 94 | plt.xlabel('$y$') 95 | plt.ylabel('$d$') 96 | plt.savefig('load_domain_gen_toy_train_FFTT.png', bbox_inches='tight') 97 | 98 | 99 | x, y, d = data_loader(number_of_samples, train=False) 100 | 101 | plt.figure() 102 | plt.scatter(y, d) 103 | plt.xlabel('$y$') 104 | plt.ylabel('$d$') 105 | plt.savefig('load_domain_gen_toy_test_FFTT.png', bbox_inches='tight') -------------------------------------------------------------------------------- /synthetic_data/jobs/ICP.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd ~/AdditionalDriveA/DeployedProjects/TwoLatentSpacesVAE/paper_experiments/synthetic_data # NOTE(review): main_ICP.py (called below) is not in this repository's file tree 3 | echo "Starting" 4 | python main_ICP.py --setup_hidden 0 --setup_hetero 0
--setup_scramble 0 > ICP_FOU.txt 5 | python main_ICP.py --setup_hidden 0 --setup_hetero 1 --setup_scramble 0 > ICP_FEU.txt 6 | python main_ICP.py --setup_hidden 1 --setup_hetero 0 --setup_scramble 0 > ICP_POU.txt 7 | python main_ICP.py --setup_hidden 1 --setup_hetero 1 --setup_scramble 0 > ICP_PEU.txt 8 | 9 | echo "Done" 10 | -------------------------------------------------------------------------------- /synthetic_data/jobs/IRM.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd ~/AdditionalDriveA/DeployedProjects/TwoLatentSpacesVAE/paper_experiments/synthetic_data # NOTE(review): main_IRM.py (called below) is not in this repository's file tree 3 | echo "Starting" 4 | python main_IRM.py --setup_hidden 0 --setup_hetero 0 --setup_scramble 0 > IRM_FOU.txt 5 | python main_IRM.py --setup_hidden 0 --setup_hetero 1 --setup_scramble 0 > IRM_FEU.txt 6 | python main_IRM.py --setup_hidden 1 --setup_hetero 0 --setup_scramble 0 > IRM_POU.txt 7 | python main_IRM.py --setup_hidden 1 --setup_hetero 1 --setup_scramble 0 > IRM_PEU.txt 8 | 9 | echo "Done" 10 | -------------------------------------------------------------------------------- /synthetic_data/jobs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AMLab-Amsterdam/DataAugmentationInterventions/78ce67174db487e9697b9a842e69818305bb41ef/synthetic_data/jobs/__init__.py -------------------------------------------------------------------------------- /synthetic_data/main_domain_gen.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | # 7 | 8 | import argparse 9 | 10 | import numpy 11 | 12 | from paper_experiments.synthetic_data.irm_scripts.models import * 13 | from paper_experiments.synthetic_data.irm_scripts.sem import ChainEquationModel 14 | 15 | 16 | def pretty(vector): 17 | vlist = vector.view(-1).tolist() 18 | return "[" + ", ".join("{:+.3f}".format(vi) for vi in vlist) + "]" 19 | 20 | 21 | def errors(w, w_hat): 22 | w = w.view(-1) 23 | w_hat = w_hat.view(-1) 24 | 25 | i_causal = (w != 0).nonzero().view(-1) 26 | i_noncausal = (w == 0).nonzero().view(-1) 27 | 28 | if len(i_causal): 29 | error_causal = (w[i_causal] - w_hat[i_causal]).pow(2).mean() 30 | error_causal = error_causal.item() 31 | else: 32 | error_causal = 0 33 | 34 | if len(i_noncausal): 35 | error_noncausal = (w[i_noncausal] - w_hat[i_noncausal]).pow(2).mean() 36 | error_noncausal = error_noncausal.item() 37 | else: 38 | error_noncausal = 0 39 | 40 | return error_causal, error_noncausal 41 | 42 | 43 | def run_experiment(args): 44 | if args["seed"] >= 0: 45 | torch.manual_seed(args["seed"]) 46 | numpy.random.seed(args["seed"]) 47 | torch.set_num_threads(1) 48 | 49 | if args["setup_sem"] == "chain": 50 | setup_str = "chain_hidden={}_hetero={}_scramble={}".format( 51 | args["setup_hidden"], 52 | args["setup_hetero"], 53 | args["setup_scramble"]) 54 | elif args["setup_sem"] == "icp": 55 | setup_str = "sem_icp" 56 | else: 57 | raise NotImplementedError 58 | 59 | all_methods = { 60 | "ERM": EmpiricalRiskMinimizer, 61 | } 62 | 63 | if args["methods"] == "all": 64 | methods = all_methods 65 | else: 66 | methods = {m: all_methods[m] for m in args["methods"].split(',')} 67 | 68 | all_sems = [] 69 | all_solutions = [] 70 | all_environments = [] 71 | 72 | for rep_i in range(args["n_reps"]): 73 | if args["setup_sem"] == "chain": 74 | sem = ChainEquationModel(args["dim"], 75 | hidden=args["setup_hidden"], 76 | scramble=args["setup_scramble"], 77 | hetero=args["setup_hetero"]) 78 | environments = [sem(args["n_samples"], 'train'), 79 | 
sem(args["n_samples"], 'test')] 80 | else: 81 | raise NotImplementedError 82 | 83 | all_sems.append(sem) 84 | all_environments.append(environments) 85 | 86 | for sem, environments in zip(all_sems, all_environments): 87 | solutions = [ 88 | "{} SEM {} {:.5f} {:.5f}".format(setup_str, 89 | pretty(sem.solution()), 0, 0) 90 | ] 91 | 92 | # solutions = [] 93 | 94 | for method_name, method_constructor in methods.items(): 95 | method = method_constructor(environments, args) 96 | msolution = method.solution() 97 | err_causal, err_noncausal = errors(sem.solution(), msolution) 98 | 99 | solutions.append("{} {} {} {:.5f} {:.5f}".format(setup_str, 100 | method_name, 101 | pretty(msolution), 102 | err_causal, 103 | err_noncausal)) 104 | 105 | # solutions.append("{:.5f} {:.5f}".format(err_causal, err_noncausal)) 106 | 107 | 108 | all_solutions += solutions 109 | 110 | return all_solutions 111 | 112 | 113 | if __name__ == '__main__': 114 | parser = argparse.ArgumentParser(description='Invariant regression') 115 | parser.add_argument('--dim', type=int, default=5) 116 | parser.add_argument('--n_samples', type=int, default=1000) 117 | parser.add_argument('--n_reps', type=int, default=10) 118 | parser.add_argument('--skip_reps', type=int, default=0) 119 | parser.add_argument('--seed', type=int, default=0) # Negative is random 120 | parser.add_argument('--print_vectors', type=int, default=1) 121 | parser.add_argument('--n_iterations', type=int, default=100000) 122 | parser.add_argument('--lr', type=float, default=1e-3) 123 | parser.add_argument('--verbose', type=int, default=0) 124 | # parser.add_argument('--methods', type=str, default="ERM,ICP,IRM,ERMNoiseInterventionRight,ERMNoiseInterventionWrong,ERMConstantInterventionRight,ERMConstantInterventionWrong") 125 | parser.add_argument('--methods', type=str, default="ERM") 126 | parser.add_argument('--alpha', type=float, default=0.05) 127 | parser.add_argument('--setup_sem', type=str, default="chain") 128 | 
parser.add_argument('--setup_hidden', type=int, default=0) # F/P 129 | parser.add_argument('--setup_hetero', type=int, default=0) # O/E 130 | parser.add_argument('--setup_scramble', type=int, default=0) # U/S 131 | parser.add_argument('--num_no_noise', type=int, default=0) 132 | args = dict(vars(parser.parse_args())) 133 | 134 | all_solutions = run_experiment(args) 135 | print("\n".join(all_solutions)) 136 | -------------------------------------------------------------------------------- /synthetic_data/plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | feat_x2_0_inter = np.array([0.34256017, 0.31460512, 0.27144957, 0.24856855, 0.22719437]) 5 | feat_x2_0_inter_ste = np.array([0.004312400818, 0.003773893416, 0.002982698083, 0.00233306855, 0.002138502002]) 6 | feat_x2_1_inter = np.array([0.35977444, 0.3302486, 0.2858478, 0.26235148, 0.23809971]) 7 | feat_x2_1_inter_ste = np.array([0.004517972469, 0.004006842971, 0.003241877258, 0.002627826631, 0.002335606366]) 8 | feat_x2_2_inter = np.array([0.387529, 0.35519564, 0.30843478, 0.28350833, 0.2540046]) 9 | feat_x2_2_inter_ste = np.array([0.004769945741, 0.00426604718, 0.00358707428, 0.003047236502, 0.002536125779]) 10 | feat_x2_3_inter = np.array([0.4165158, 0.38317177, 0.33215567, 0.3049404, 0.27241272]) 11 | feat_x2_3_inter_ste = np.array([0.005206080675, 0.004714588821, 0.003986877203, 0.003410176337, 0.002820271552]) 12 | feat_x2_4_inter = np.array([0.44910964, 0.41258878, 0.35587674, 0.3253371, 0.29092044]) 13 | feat_x2_4_inter_ste = np.array([0.005780712962, 0.005148547292, 0.004204738736, 0.003574062288, 0.003055044413]) 14 | ERM = np.array([0.37845063, 0.37845063, 0.37845063, 0.37845063, 0.37845063]) 15 | ERM_ste = np.array([0.004980756044, 0.004980756044, 0.004980756044, 0.004980756044, 0.004980756044]) 16 | 17 | ERM_x2_only = np.array([0.22802237, 0.22802237, 0.22802237, 0.22802237, 0.22802237]) 18 | 
ERM_ste_x2_only = np.array([0.0021790754795074463, 0.0021790754795074463, 0.0021790754795074463, 0.0021790754795074463, 0.0021790754795074463]) 19 | 20 | x = [1, 2, 3, 4, 5] 21 | 22 | fig, ax = plt.subplots() 23 | plt.plot(x, ERM, label='ERM') 24 | # plt.fill_between(x, ERM - ERM_ste, ERM + ERM_ste, alpha=0.1) 25 | markers, caps, bars = ax.errorbar(x, feat_x2_0_inter, yerr=feat_x2_0_inter_ste, label='augmentation on 0 dims of $h_y$') 26 | # plt.fill_between(x, feat_x2_0_inter - feat_x2_0_inter_ste, feat_x2_0_inter + feat_x2_0_inter_ste, alpha=0.1) 27 | markers, caps, bars = ax.errorbar(x, feat_x2_1_inter, yerr=feat_x2_1_inter_ste, label='augmentation on 1 dim of $h_y$') 28 | # plt.fill_between(x, feat_x2_1_inter - feat_x2_1_inter_ste, feat_x2_1_inter + feat_x2_1_inter_ste, alpha=0.1) 29 | markers, caps, bars = ax.errorbar(x, feat_x2_2_inter, yerr=feat_x2_2_inter_ste, label='augmentation on 2 dims of $h_y$') 30 | # plt.fill_between(x, feat_x2_2_inter - feat_x2_2_inter_ste, feat_x2_2_inter + feat_x2_2_inter_ste, alpha=0.1) 31 | markers, caps, bars = ax.errorbar(x, feat_x2_3_inter, yerr=feat_x2_3_inter_ste, label='augmentation on 3 dims of $h_y$') 32 | # plt.fill_between(x, feat_x2_3_inter - feat_x2_3_inter_ste, feat_x2_3_inter + feat_x2_3_inter_ste, alpha=0.1) 33 | markers, caps, bars = ax.errorbar(x, feat_x2_4_inter, yerr=feat_x2_4_inter_ste, label='augmentation on 4 dims of $h_y$') 34 | plt.plot(x, ERM_x2_only, label='ERM using only $h_y$') 35 | # plt.fill_between(x, feat_x2_4_inter - feat_x2_4_inter_ste, feat_x2_4_inter + feat_x2_4_inter_ste, alpha=0.1) 36 | plt.xticks(x) # Set locations and labels 37 | plt.legend() 38 | plt.ylabel('$MSE$') 39 | plt.xlabel('num of dims of $h_d$ w/ augmentation') 40 | plt.savefig('toy_data_comparison.png', bbox_inches='tight', dpi=300) -------------------------------------------------------------------------------- /synthetic_data/run_erm.py: -------------------------------------------------------------------------------- 1 | 
import argparse 2 | import numpy as np 3 | from sklearn import linear_model 4 | from sklearn.metrics import mean_squared_error, r2_score 5 | import torch 6 | import matplotlib.pyplot as plt 7 | 8 | from paper_experiments.synthetic_data.domain_gen_sem import DomainGenToyData 9 | 10 | # Training settings 11 | parser = argparse.ArgumentParser(description='Tpy experiment') 12 | parser.add_argument('--samples', default=1000) 13 | parser.add_argument('--dim', default=10) 14 | parser.add_argument('--transform_c', default=True) 15 | parser.add_argument('--transform_y_and_d', default=True) 16 | parser.add_argument('--additional_noise_y_and_d', default=True) 17 | parser.add_argument('--additional_noise_x1_and_x2', default=True) 18 | parser.add_argument('--no_noise_d_features', default=5) 19 | parser.add_argument('--no_noise_y_features', default=5) 20 | args = parser.parse_args() 21 | 22 | result_list = [] 23 | for seed in range(50): 24 | np.random.seed(seed) 25 | torch.manual_seed(seed) 26 | torch.cuda.manual_seed(seed) 27 | 28 | # Load data 29 | data_loader = DomainGenToyData(args.dim, 30 | transform_c=args.transform_c, 31 | transform_y_and_d=args.transform_y_and_d, 32 | additional_noise_y_and_d=args.additional_noise_y_and_d, 33 | additional_noise_x1_and_x2=args.additional_noise_x1_and_x2) 34 | x_train, y_train, _ = data_loader(1000, train=True) 35 | 36 | # # Add uniform noise to causal features of d 37 | # # Generate uniform noise [-10, 10] 38 | # noise_x1 = torch.FloatTensor(x_train.shape[0], args.dim//2).uniform_(-10, 10).numpy() 39 | # 40 | # # Randomly select channels that get no noise 41 | # feature_index_no_noise_x1 = np.random.choice(args.dim//2, args.no_noise_d_features, replace=False) 42 | # 43 | # # set noise to zero 44 | # noise_x1[:, feature_index_no_noise_x1] = 0.0 45 | # 46 | # # Add to x_1 47 | # x_train[:, :5] += noise_x1 48 | # # x_train[:, :5] = 10 49 | # 50 | # # Add uniform noise to causal features of y 51 | # # Generate uniform noise [-20, 20] 52 | # 
noise_x2 = torch.FloatTensor(x_train.shape[0], args.dim//2).uniform_(-10, 10).numpy() 53 | # 54 | # # Randomly select channels that get no noise 55 | # feature_index_no_noise_x2 = np.random.choice(args.dim//2, args.no_noise_y_features, replace=False) 56 | # 57 | # # set noise to zero 58 | # noise_x2[:, feature_index_no_noise_x2] = 0.0 59 | # 60 | # # Add to x_2 61 | # x_train[:, 5:] += noise_x2 62 | 63 | x_test, y_test, _ = data_loader(1000, train=False) 64 | 65 | x_train = x_train[:, 5:] 66 | x_test = x_test[:, 5:] 67 | # Create linear regression object 68 | regr = linear_model.LinearRegression() 69 | 70 | # Train the model using the training sets 71 | regr.fit(x_train, y_train) 72 | 73 | # Make predictions using the testing set 74 | y_pred = regr.predict(x_test) 75 | 76 | # The coefficients 77 | print('Coefficients: \n', regr.coef_) 78 | # The mean squared error 79 | print('Mean squared error: %.2f' 80 | % mean_squared_error(y_test, y_pred)) 81 | # The coefficient of determination: 1 is perfect prediction 82 | print('Coefficient of determination: %.2f' 83 | % r2_score(y_test, y_pred)) 84 | 85 | # plt.figure() 86 | # plt.scatter(np.arange(len(y_test)), y_test) 87 | # plt.scatter(np.arange(len(y_test)), y_pred) 88 | # plt.xlabel('$x$') 89 | # plt.ylabel('$y$') 90 | # plt.savefig('compare_predictions.png', bbox_inches='tight') 91 | 92 | result_list.append(mean_squared_error(y_test, y_pred)) 93 | 94 | print(np.mean(result_list)) 95 | print(np.std(result_list)/np.sqrt(50)) -------------------------------------------------------------------------------- /synthetic_data/run_erm_domain.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | from sklearn import linear_model 4 | from sklearn.metrics import mean_squared_error, r2_score 5 | import torch 6 | import matplotlib.pyplot as plt 7 | 8 | from paper_experiments.synthetic_data.domain_gen_sem import DomainGenToyData 9 | 10 | # Training settings 
11 | parser = argparse.ArgumentParser(description='Tpy experiment') 12 | parser.add_argument('--samples', default=1000) 13 | parser.add_argument('--dim', default=10) 14 | parser.add_argument('--transform_c', default=True) 15 | parser.add_argument('--transform_y_and_d', default=True) 16 | parser.add_argument('--additional_noise_y_and_d', default=True) 17 | parser.add_argument('--additional_noise_x1_and_x2', default=True) 18 | parser.add_argument('--no_noise_d_features', default=0) 19 | parser.add_argument('--no_noise_y_features', default=5) 20 | args = parser.parse_args() 21 | 22 | result_list = [] 23 | for seed in range(50): 24 | np.random.seed(seed) 25 | torch.manual_seed(seed) 26 | torch.cuda.manual_seed(seed) 27 | 28 | # Load data 29 | data_loader = DomainGenToyData(args.dim, 30 | transform_c=args.transform_c, 31 | transform_y_and_d=args.transform_y_and_d, 32 | additional_noise_y_and_d=args.additional_noise_y_and_d, 33 | additional_noise_x1_and_x2=args.additional_noise_x1_and_x2) 34 | x_train, _, d_train = data_loader(1000, train=True) 35 | 36 | # Add uniform noise to causal features of d 37 | # Generate uniform noise [-10, 10] 38 | noise_x1 = torch.FloatTensor(x_train.shape[0], args.dim//2).uniform_(-10, 10).numpy() 39 | 40 | # Randomly select channels that get no noise 41 | feature_index_no_noise_x1 = np.random.choice(args.dim//2, args.no_noise_d_features, replace=False) 42 | 43 | # set noise to zero 44 | noise_x1[:, feature_index_no_noise_x1] = 0.0 45 | 46 | # Add to x_1 47 | x_train[:, :5] += noise_x1 48 | # x_train[:, :5] = 10 49 | 50 | # Add uniform noise to causal features of y 51 | # Generate uniform noise [-20, 20] 52 | noise_x2 = torch.FloatTensor(x_train.shape[0], args.dim//2).uniform_(-10, 10).numpy() 53 | 54 | # Randomly select channels that get no noise 55 | feature_index_no_noise_x2 = np.random.choice(args.dim//2, args.no_noise_y_features, replace=False) 56 | 57 | # set noise to zero 58 | noise_x2[:, feature_index_no_noise_x2] = 0.0 59 | 60 | # 
Add to x_2 61 | x_train[:, 5:] += noise_x2 62 | 63 | x_test, _, d_test = data_loader(1000, train=False) 64 | 65 | x_train = x_train[:, 5:] 66 | x_test = x_test[:, 5:] 67 | # Create linear regression object 68 | regr = linear_model.LinearRegression() 69 | 70 | # Train the model using the training sets 71 | regr.fit(x_train, d_train) 72 | 73 | # Make predictions using the testing set 74 | d_pred = regr.predict(x_test) 75 | 76 | # The coefficients 77 | print('Coefficients: \n', regr.coef_) 78 | # The mean squared error 79 | print('Mean squared error: %.2f' 80 | % mean_squared_error(d_test, d_pred)) 81 | # The coefficient of determination: 1 is perfect prediction 82 | print('Coefficient of determination: %.2f' 83 | % r2_score(d_test, d_pred)) 84 | 85 | # plt.figure() 86 | # plt.scatter(np.arange(len(y_test)), y_test) 87 | # plt.scatter(np.arange(len(y_test)), y_pred) 88 | # plt.xlabel('$x$') 89 | # plt.ylabel('$y$') 90 | # plt.savefig('compare_predictions.png', bbox_inches='tight') 91 | 92 | result_list.append(mean_squared_error(d_test, d_pred)) 93 | 94 | print(np.mean(result_list)) 95 | print(np.std(result_list)/np.sqrt(50)) -------------------------------------------------------------------------------- /synthetic_data/run_erm_with_data_augmentation_random_noise.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | from sklearn import linear_model 4 | from sklearn.metrics import mean_squared_error, r2_score 5 | import torch 6 | import matplotlib.pyplot as plt 7 | 8 | from paper_experiments.synthetic_data.domain_gen_sem import DomainGenToyData 9 | 10 | # Training settings 11 | parser = argparse.ArgumentParser(description='Tpy experiment') 12 | parser.add_argument('--samples', default=1000) 13 | parser.add_argument('--dim', default=10) 14 | parser.add_argument('--transform_c', default=True) 15 | parser.add_argument('--transform_y_and_d', default=True) 16 | 
parser.add_argument('--additional_noise_y_and_d', default=True) 17 | parser.add_argument('--additional_noise_x1_and_x2', default=True) 18 | parser.add_argument('--no_noise_d_features', default=4) 19 | parser.add_argument('--no_noise_y_features', default=5) 20 | args = parser.parse_args() 21 | 22 | result_list = [] 23 | for seed in range(50): 24 | np.random.seed(seed) 25 | torch.manual_seed(seed) 26 | torch.cuda.manual_seed(seed) 27 | 28 | # Load data 29 | data_loader = DomainGenToyData(args.dim, 30 | transform_c=args.transform_c, 31 | transform_y_and_d=args.transform_y_and_d, 32 | additional_noise_y_and_d=args.additional_noise_y_and_d, 33 | additional_noise_x1_and_x2=args.additional_noise_x1_and_x2) 34 | x_train, y_train, _ = data_loader(1000, train=True) 35 | 36 | # Add uniform noise to causal features of d 37 | # Generate uniform noise [-10, 10] 38 | noise_x1 = torch.FloatTensor(x_train.shape[0], args.dim//2).uniform_(-10, 10).numpy() 39 | 40 | # Randomly select channels that get no noise 41 | feature_index_no_noise_x1 = np.random.choice(args.dim//2, args.no_noise_d_features, replace=False) 42 | 43 | # set noise to zero 44 | noise_x1[:, feature_index_no_noise_x1] = 0.0 45 | 46 | # Add to x_1 47 | x_train[:, :5] += noise_x1 48 | # x_train[:, :5] = 10 49 | 50 | # Add uniform noise to causal features of y 51 | # Generate uniform noise [-20, 20] 52 | noise_x2 = torch.FloatTensor(x_train.shape[0], args.dim//2).uniform_(-10, 10).numpy() 53 | 54 | # Randomly select channels that get no noise 55 | feature_index_no_noise_x2 = np.random.choice(args.dim//2, args.no_noise_y_features, replace=False) 56 | 57 | # set noise to zero 58 | noise_x2[:, feature_index_no_noise_x2] = 0.0 59 | 60 | # Add to x_2 61 | x_train[:, 5:] += noise_x2 62 | 63 | x_test, y_test, _ = data_loader(1000, train=False) 64 | 65 | # Create linear regression object 66 | regr = linear_model.LinearRegression() 67 | 68 | # Train the model using the training sets 69 | regr.fit(x_train, y_train) 70 | 71 | # 
Make predictions using the testing set 72 | y_pred = regr.predict(x_test) 73 | 74 | # The coefficients 75 | print('Coefficients: \n', regr.coef_) 76 | # The mean squared error 77 | print('Mean squared error: %.2f' 78 | % mean_squared_error(y_test, y_pred)) 79 | # The coefficient of determination: 1 is perfect prediction 80 | print('Coefficient of determination: %.2f' 81 | % r2_score(y_test, y_pred)) 82 | 83 | # plt.figure() 84 | # plt.scatter(np.arange(len(y_test)), y_test) 85 | # plt.scatter(np.arange(len(y_test)), y_pred) 86 | # plt.xlabel('$x$') 87 | # plt.ylabel('$y$') 88 | # plt.savefig('compare_predictions.png', bbox_inches='tight') 89 | 90 | result_list.append(mean_squared_error(y_test, y_pred)) 91 | 92 | print(np.mean(result_list)) 93 | print(np.std(result_list)/np.sqrt(50)) -------------------------------------------------------------------------------- /synthetic_data/toy_data_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AMLab-Amsterdam/DataAugmentationInterventions/78ce67174db487e9697b9a842e69818305bb41ef/synthetic_data/toy_data_comparison.png --------------------------------------------------------------------------------