├── models ├── __init__.py ├── net_utils │ ├── __init__.py │ ├── torch_utils.py │ ├── misc.py │ └── torch_ckpts.py ├── ensemble.py ├── MLP.py └── mnets_classifier_interface.py ├── .gitignore ├── utils ├── kernel.py ├── SSGE_squeeze.py ├── SSGE.py ├── generate_results.py ├── config.py └── distributions.py ├── README.md ├── data ├── toy_regression │ ├── generete_regression.py │ └── regression1d_data.py ├── toy_classification │ ├── generate_classification.py │ ├── oned_gaussian.py │ ├── twod_gaussian.py │ ├── donuts.py │ └── moons.py ├── generate_dataset.py ├── fashion_mnist │ └── fashion_data.py ├── mnist │ └── split_mnist.py └── svhn │ └── data_svhn_data.py ├── experiments ├── exp_regr.py ├── exp_2d_class.py ├── exp_cifar.py └── exp_fmnist.py ├── requirements.txt ├── methods ├── method_utils.py ├── f_SVGD.py ├── SVGD.py └── WGD.py └── training ├── training_1d_regre.py ├── training_1dreg.py └── training_mnist.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | -------------------------------------------------------------------------------- /models/net_utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/kernel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | class RBF(torch.nn.Module): 5 | def __init__(self, sigma=None): 6 | super(RBF, self).__init__() 7 | 8 | self.sigma = sigma 9 | 10 | def median(self, tensor): 11 | tensor = tensor.flatten().sort()[0] 12 | length = tensor.shape[0] 13 | 14 | if length % 2 == 0: 15 | szh = length // 2 16 | kth = [szh - 1, szh] 17 | else: 18 | kth = [(length - 1) // 2] 19 | return tensor[kth].mean() 20 | 21 | def forward(self, X, Y): 22 | XX = X.matmul(X.t()) 23 | XY = X.matmul(Y.t()) 24 | YY = Y.matmul(Y.t()) 25 | 26 | dnorm2 = -2 * XY + XX.diag().unsqueeze(1) + YY.diag().unsqueeze(0) 27 | 28 | # Apply the median heuristic (PyTorch does not give true median) 29 | if self.sigma is None: 30 | sigma = self.median(dnorm2.detach()) / (2 * torch.tensor(math.log(X.size(0) + 1))) 31 | else: 32 | sigma = self.sigma ** 2 33 | 34 | gamma = 1.0 / (2 * sigma) 35 | K_XY = (-gamma * dnorm2).exp() 36 | 37 | return K_XY -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Repulsive Deep Ensembles are Bayesian 2 | 3 | This repo contains the code of the paper [Repulsive deep ensembles are Bayesian](https://proceedings.neurips.cc/paper/2021/hash/1c63926ebcabda26b5cdb31b5cc91efb-Abstract.html). In the following some usage examples can be found 4 | 5 | ## Sampling from synthetic distributions experiments 6 | The experiment for the synthetic distributions can be found in 'notebooks/WGD_synthetic.ipynb' 7 | 8 | ## 1d regression experiments 9 | 10 | The 1d toy regression problem can be explored. 
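The `--method` flag selects the particle update rule; see `methods/method_utils.py` for the available options (`SGD`, `SGLD`, `SVGD`, `SVGLD`, `f_s_SVGD`, and the `kde`/`sge`/`ssge` variants of `WGD` and `f_WGD`).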
Example run: 11 | 12 | ```console 13 | $ python3 experiments/exp_regr.py --epochs 5000 --lr 1e-2 --n_particles 100 --size_hidden 10 --num_hidden 2 --method SGD --prior_variance 1 --annealing_steps 1000 --batch_size 32 --dataset toy_reg --ann_sch None 14 | 15 | ``` 16 | 17 | ## 2d classification experiments 18 | 19 | The 2d classification problem can be explored. Example run: 20 | 21 | ```console 22 | $ python3 experiments/exp_2d_class.py --epochs 10000 --lr 1e-2 --n_particles 100 --size_hidden 10 --num_hidden 2 --method SVGD --prior_variance 1 --annealing_steps 1000 --batch_size 128 --dataset twod_gaussian --ann_sch None 23 | 24 | ``` 25 | ## Citation 26 | 27 | If you use our code or consider our ideas in your research project, please consider citing our paper. 28 | ``` 29 | @article{d2021repulsive, 30 | title={Repulsive Deep Ensembles are Bayesian}, 31 | author={D'Angelo, Francesco and Fortuin, Vincent}, 32 | journal={arXiv preprint arXiv:2106.11642}, 33 | year={2021} 34 | } 35 | ``` 36 | -------------------------------------------------------------------------------- /data/toy_regression/generete_regression.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from data.toy_regression.regression1d_data import ToyRegression 3 | 4 | def generate_1d_dataset(show_plots=True, task_set=0, data_random_seed=42): 5 | """Generate a set of tasks for 1D regression. 6 | 7 | Args: 8 | show_plots: Visualize the generated datasets. 9 | data_random_seed: Random seed that should be applied to the 10 | synthetic data generation. 11 | task_set: int for the regression task 12 | 13 | Returns: 14 | data_handlers: A data handler 15 | """ 16 | 17 | map_funcs = [lambda x: (x) ** 3., 18 | lambda x: (3.*x), 19 | lambda x: 2. 
* np.power(x, 2) - 1, 20 | lambda x: np.power(x - 3., 3), 21 | lambda x: x*np.sin(x), 22 | lambda x: x*(1+np.sin(x))] 23 | x_domains = [[-3.5, 3.5],[-2, 2], [-1, 1], [2, 4],[2,6],[3,12]] 24 | test_domains = [[-5.0,5.0],[-3, +3], [-2.5, 2.5], [.5, 4.1],[0,7],[2,13]] 25 | std = [3,0.05,0.05,0.05,0.25,0.6] 26 | num_train = 90 27 | blob = [None,None,None,None,[[1.5,2.5],[4.5,6.0]],[[4.5,5],[7.5,8.5],[10,11]]] 28 | 29 | 30 | data = ToyRegression(train_inter=x_domains[task_set], 31 | num_train=num_train, test_inter=test_domains[task_set], blob = blob[task_set], num_test=100, 32 | val_inter=x_domains[task_set], num_val=100, 33 | map_function=map_funcs[task_set], std=std[task_set], rseed=data_random_seed) 34 | return data 35 | 36 | -------------------------------------------------------------------------------- /data/toy_classification/generate_classification.py: -------------------------------------------------------------------------------- 1 | from data.toy_classification.oned_gaussian import oned_gaussian 2 | from data.toy_classification.twod_gaussian import twod_gaussian 3 | from data.toy_classification.donuts import Donuts 4 | from data.toy_classification.moons import Moons 5 | from data.mnist.mnist_data import MNISTData 6 | from data.mnist.split_mnist import get_split_mnist_handlers 7 | from data.fashion_mnist.fashion_data import FashionMNISTData 8 | from data.cifar.cifar10_data import CIFAR10Data 9 | from data.svhn.data_svhn_data import SVHNData 10 | import os 11 | 12 | def generate_moons(config): 13 | 14 | data = Moons(n_train=1500, n_test=500, noise=0.1, rseed = config.data_random_seed) 15 | 16 | return data 17 | 18 | def generate_oned_gaussian(): 19 | 20 | data = oned_gaussian() 21 | 22 | return data 23 | 24 | def generate_twod_gaussian(config): 25 | 26 | data = twod_gaussian(rseed = 42,mu=[], sigma=[[1.,1.]for i in range(5)],n_train = 40, n_test = 20) 27 | 28 | return data 29 | 30 | def generate_donuts(config): 31 | 32 | data = Donuts(r_1 = (9,10),r_2 = (3,4), c_outer_1 = (0,0), c_outer_2 = (0,0), n_train=100, n_test = 80, rseed = config.data_random_seed) 33 | 34 | return data 35 | 36 | def generate_mnist(): 37 | data = MNISTData(os.getcwd()+'/mnist', True) 38 | 39 | return data 40 | 41 | def generate_split_mnist(): 42 | data = get_split_mnist_handlers(os.getcwd()+'/mnist', True,num_classes_per_task = 5) 43 | 44 | return data 45 | 46 | def generate_f_mnist(): 47 | 48 | data = FashionMNISTData(os.getcwd()+'/fashion_mnist', True) 49 | 50 | return data 51 | 52 | def generate_cifar(): 53 | 54 | data = CIFAR10Data(os.getcwd()+'/cifar', True) 55 | 56 | return data 57 | 58 | def generate_svhn(): 59 | 60 | data = SVHNData(os.getcwd()+'/svhn', True) 61 | 62 | return data -------------------------------------------------------------------------------- /models/ensemble.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | class Ensemble(): 5 | """Implementation of an ensemble of models 6 | 7 | This is a simple class to manage and make predictions using an ensemble with or without particles 8 | Args: 9 | device: Torch device (cpu or gpu). 
10 | net: pytorch model to create the ensemble 11 | particles(Tensor): Tensor (n_particles, n_params) containing squeezed parameter value of the specified model, 12 | if None, particles will be sample from a gaussian N(0,1) 13 | n_particles(int): if no particles are provided the ensemble is initialized and the number of members is required 14 | 15 | """ 16 | 17 | def __init__(self, device, net, particles=None, n_particles=1): 18 | self.net = net 19 | if particles is None: 20 | self.particles = (1*torch.randn(n_particles, *torch.Size([self.net.num_params]))).to(device) 21 | #self.particles =torch.FloatTensor(n_particles, self.net.num_params).uniform_(-0.1, 0.1).to(device) 22 | else: 23 | self.particles = particles 24 | 25 | self.weighs_split = [np.prod(w) for w in net.param_shapes] 26 | 27 | def reshape_particles(self, z): 28 | reshaped_weights = [] 29 | z_splitted = torch.split(z, self.weighs_split, 1) 30 | for j in range(z.shape[0]): 31 | l = [] 32 | for i, shape in enumerate(self.net.param_shapes): 33 | l.append(z_splitted[i][j].reshape(shape)) 34 | reshaped_weights.append(l) 35 | return reshaped_weights 36 | 37 | def forward(self, x, W=None): 38 | if W is None: 39 | W = self.particles 40 | models = self.reshape_particles(W) 41 | if self.net.out_act is None: 42 | pred = [self.net.forward(x, w) for w in models] 43 | return [torch.stack(pred)] #.unsqueeze(0) 44 | else: 45 | pred,hidden = zip(*(list(self.net.forward(x,w)) for w in models)) 46 | return torch.stack(pred), torch.stack(hidden) -------------------------------------------------------------------------------- /data/generate_dataset.py: -------------------------------------------------------------------------------- 1 | from data.toy_classification.generate_classification import generate_moons, generate_oned_gaussian, generate_twod_gaussian, generate_mnist, generate_f_mnist, generate_donuts, generate_split_mnist, generate_cifar, generate_svhn 2 | from data.toy_regression.generete_regression import generate_1d_dataset 3 | 4 | 5 | def generate_dataset(config): 6 | """Generate a dataset. 7 | 8 | Args: 9 | config: Command-line arguments. 10 | 11 | Returns: 12 | data_handlers(DATASET): A data handlers. 
13 | classification(bool): Whether the dataset is a classification task or not 14 | """ 15 | if config.dataset == 'toy_reg': 16 | classification = False 17 | return generate_1d_dataset(show_plots=True, task_set=4, 18 | data_random_seed=config.data_random_seed), classification 19 | elif config.dataset == 'moons': 20 | classification = True 21 | return generate_moons(config), classification 22 | 23 | elif config.dataset == 'oned_gaussian': 24 | classification = True 25 | return generate_oned_gaussian(), classification 26 | 27 | elif config.dataset == 'twod_gaussian': 28 | classification = True 29 | return generate_twod_gaussian(config), classification 30 | 31 | elif config.dataset == 'mnist': 32 | classification = True 33 | return generate_mnist(), classification 34 | 35 | elif config.dataset == 's_mnist': 36 | classification = True 37 | return generate_split_mnist(), classification 38 | 39 | elif config.dataset == 'f_mnist': 40 | classification = True 41 | return generate_f_mnist(), classification 42 | 43 | elif config.dataset == 'donuts': 44 | classification = True 45 | return generate_donuts(config), classification 46 | 47 | elif config.dataset == 'cifar': 48 | classification = True 49 | return generate_cifar(), classification 50 | 51 | elif config.dataset == 'svhn': 52 | classification = True 53 | return generate_svhn(), classification -------------------------------------------------------------------------------- /experiments/exp_regr.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | import torch 4 | import torch.nn.functional as F 5 | from tensorboardX import SummaryWriter 6 | from utils.config import configuration 7 | from data.generate_dataset import generate_dataset 8 | from models.MLP import Net 9 | from models.ensemble import Ensemble 10 | from training.training_1d_regre import train 11 | 12 | 13 | def run(): 14 | """Run the script. 
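    Builds a fully connected Net, wraps it into an Ensemble of config.n_particles particles and trains it on the configured 1d regression dataset with the selected method, logging summaries to TensorBoard under config.out_dir.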
15 | """ 16 | 17 | config = configuration() 18 | date = datetime.now().strftime('%H-%M-%S') 19 | exp_dir = 'exp_'+datetime.now().strftime('%m-%d-%H-%M') 20 | torch.manual_seed(config.random_seed) 21 | 22 | if config.noise: 23 | alg = 'sto' 24 | else: 25 | alg = 'det' 26 | 27 | f_date = datetime.now().strftime('%Y-%m-%d') 28 | 29 | dout_dir = './out/'+ f_date + '/'+ config.dataset +'_'+ config.exp_dir +'/'+config.method +'/'+ config.ann_sch+str(config.annealing_steps) +'/'+config.where_repulsive + 'part_'+str(config.n_particles)+'/lr_'+str(config.lr)+'/seed_'+str(config.random_seed)+'_run_' + date 30 | config.out_dir = dout_dir 31 | 32 | 33 | writer = SummaryWriter(log_dir=os.path.join(config.out_dir, 'summary')) 34 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 35 | 36 | writer.add_text('Comment', config.comment, 0) 37 | configu = dict(config.__dict__) 38 | del configu['comment'] 39 | 40 | writer.add_text('Hparams', str(configu), 0) 41 | data, classification = generate_dataset(config) 42 | layer_sizes = [data.in_shape[0], data.out_shape[0]] 43 | 44 | for i in range(config.num_hidden): 45 | layer_sizes.insert(-1, config.size_hidden) 46 | 47 | mnet = Net(layer_sizes, classification = classification, act=F.relu,out_act = None).to(device) 48 | 49 | #l = [] 50 | #for _ in range(config.n_particles): 51 | # l.append(torch.cat([p.flatten() for p in Net(layer_sizes, classification = True, act=F.relu,out_act = F.softmax, bias = True, no_weights=False).parameters()][len(mnet.param_shapes):]).detach()) 52 | 53 | #initial_particles = torch.stack(l).to(device) 54 | 55 | ensemble = Ensemble(device = device, net=mnet, n_particles = config.n_particles) 56 | #ensemble = Ensemble(device = device, net=mnet,particles=initial_particles) 57 | 58 | #ensemble = Ensemble(device = device, net=mnet, n_particles = config.n_particles) 59 | 60 | train(data, ensemble, device, config,writer) 61 | 62 | #particles = ensemble.particles.detach().numpy() 63 | 64 | #np.save(date+'.np', particles) 65 | 66 | if __name__ == '__main__': 67 | run() -------------------------------------------------------------------------------- /experiments/exp_2d_class.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.chdir('..') 3 | from utils.config import configuration 4 | import torch 5 | from data.generate_dataset import generate_dataset 6 | from models.MLP import Net 7 | from models.ensemble import Ensemble 8 | from tensorboardX import SummaryWriter 9 | import os 10 | import torch.nn.functional as F 11 | import numpy as np 12 | import warnings 13 | warnings.filterwarnings('ignore') 14 | 15 | from training.training_2d_class import train 16 | 17 | from datetime import datetime 18 | import sys 19 | sys.stdout.flush() 20 | import yaml 21 | 22 | def run(): 23 | """Run the script. 
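    Builds an MLP ensemble for the selected 2d classification dataset, trains it with the configured method and saves the resulting particles, metrics and hyperparameters to config.out_dir.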
24 | """ 25 | config = configuration() 26 | date = datetime.now().strftime('%H-%M-%S') 27 | exp_dir = 'exp_'+datetime.now().strftime('%m-%d-%H-%M') 28 | torch.manual_seed(config.random_seed) 29 | 30 | if config.logit_soft == 0: 31 | l_s='softmax' 32 | else: 33 | l_s='logit' 34 | 35 | if config.noise: 36 | alg = 'sto' 37 | else: 38 | alg = 'det' 39 | 40 | f_date = datetime.now().strftime('%Y-%m-%d') 41 | 42 | dout_dir = './out/'+ f_date + '/'+ config.dataset +'_'+ config.exp_dir +'/'+config.method +'/'+ config.ann_sch+str(config.annealing_steps) +'/'+ l_s +'/'+alg+'/'+config.where_repulsive + 'part_'+str(config.n_particles)+'/lr_'+str(config.lr)+'/seed_'+str(config.random_seed)+'_run_' + date 43 | config.out_dir = dout_dir 44 | 45 | 46 | writer = SummaryWriter(log_dir=os.path.join(config.out_dir, 'summary')) 47 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 48 | 49 | writer.add_text('Comment', config.comment, 0) 50 | configu = dict(config.__dict__) 51 | del configu['comment'] 52 | 53 | writer.add_text('Hparams', str(configu), 0) 54 | 55 | #writer.add_hparams(configu,{}) 56 | 57 | data, classification = generate_dataset(config) 58 | 59 | #data_ood = data[1] 60 | 61 | #data_ood = generate_mnist() 62 | 63 | layer_sizes = [data.in_shape[0], data.out_shape[0]] 64 | 65 | for i in range(config.num_hidden): 66 | layer_sizes.insert(-1, config.size_hidden) 67 | 68 | mnet = Net(layer_sizes, classification = classification, act=F.relu,out_act = F.softmax, bias = True ).to(device) 69 | 70 | ensemble = Ensemble(device = device, net=mnet, n_particles = config.n_particles) 71 | 72 | metrics = train(data, ensemble, device, config,writer) 73 | 74 | results.append(metrics) 75 | 76 | particles = ensemble.particles.cpu().detach().numpy() 77 | 78 | np.save(dout_dir+'/'+date+'particles', particles) 79 | 80 | np.save(dout_dir+'/'+date+'results', np.array(metrics)) 81 | 82 | dictionary_parameters = vars(config) 83 | 84 | with open(dout_dir+'/'+date+ 'parameters.yml', 'w') as yaml_file: 85 | yaml.dump(dictionary_parameters, stream=yaml_file, default_flow_style=False) 86 | 87 | np.save('./out/'+ f_date + '/'+ config.dataset +'_'+ config.exp_dir +'/'+config.method +'_'+ config.ann_sch +'_'+ l_s +'_'+config.where_repulsive+'results', np.array(results)) 88 | if __name__ == '__main__': 89 | run() -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.8.1 2 | alabaster==0.7.12 3 | altair==4.1.0 4 | argcomplete==1.12.0 5 | asn1crypto==1.2.0 6 | astor==0.8.0 7 | astroid==2.3.3 8 | atomicwrites==1.3.0 9 | attrs==19.3.0 10 | autograd==1.3 11 | Babel==2.7.0 12 | backcall==0.1.0 13 | bleach==3.1.0 14 | blinker==1.4 15 | bokeh==2.1.1 16 | celluloid==0.2.0 17 | certifi==2020.12.5 18 | cffi==1.13.0 19 | click==7.1.2 20 | cryptography==2.8 21 | cycler==0.10.0 22 | decorator==4.4.1 23 | defusedxml==0.6.0 24 | distro==1.5.0 25 | docutils==0.15.2 26 | entrypoints==0.3 27 | fastprogress==1.0.0 28 | future==0.18.2 29 | gast==0.2.2 30 | gogo-gadget==0.2.3 31 | google-auth-oauthlib==0.4.1 32 | google-pasta==0.1.8 33 | h5py==2.9.0 34 | idna==2.8 35 | imagecorruptions==1.1.0 36 | imageio==2.6.1 37 | imagesize==1.1.0 38 | importlib-metadata==0.23 39 | invoke==1.4.1 40 | ipykernel==5.1.3 41 | ipython==7.9.0 42 | ipython-genutils==0.2.0 43 | isort==4.3.21 44 | jedi==0.13.3 45 | Jinja2==2.10.3 46 | joblib==0.17.0 47 | jsonschema==3.2.0 48 | jupyter-client==5.3.4 49 | 
jupyter-console==6.0.0 50 | jupyter-core==4.6.1 51 | jupyter-tensorboard==0.2.0 52 | jupyterlab==2.2.0 53 | Keras-Applications==1.0.8 54 | Keras-Preprocessing==1.1.0 55 | kiwisolver==1.1.0 56 | lazy-object-proxy==1.4.3 57 | littlemcmc==0.2.2 58 | Markdown==3.1.1 59 | MarkupSafe==1.1.1 60 | mccabe==0.6.1 61 | mistune==0.8.4 62 | mkl-fft==1.0.14 63 | mkl-random==1.1.0 64 | mkl-service==2.3.0 65 | more-itertools==7.2.0 66 | natsort==7.0.1 67 | nbconvert==5.6.1 68 | nbformat==4.4.0 69 | networkx==2.5 70 | notebook==6.0.2 71 | npm==0.1.1 72 | numpy==1.17.2 73 | oauthlib==3.1.0 74 | olefile==0.46 75 | opencv-python==4.4.0.44 76 | opt-einsum==3.3.0 77 | optional-django==0.1.0 78 | packaging==19.2 79 | pandas==0.25.3 80 | pandocfilters==1.4.2 81 | parso==0.5.1 82 | patsy==0.5.1 83 | pexpect==4.7.0 84 | pickleshare==0.7.5 85 | Pillow==6.2.0 86 | pipx==0.15.5.1 87 | plotly==4.9.0 88 | pluggy==0.13.0 89 | prometheus-client==0.7.1 90 | prompt-toolkit==2.0.10 91 | protobuf==3.11.1 92 | ptyprocess==0.6.0 93 | py==1.8.0 94 | pyasn1==0.4.8 95 | pyasn1-modules==0.2.8 96 | pycparser==2.19 97 | Pygments==2.4.2 98 | pygpu==0.7.6 99 | PyJWT==1.7.1 100 | pylint==2.4.4 101 | pymc3==3.7 102 | pyOpenSSL==19.1.0 103 | pyparsing==2.4.2 104 | pyro-api==0.1.1 105 | pyro-ppl==1.1.0 106 | pyrsistent==0.15.5 107 | PySocks==1.7.1 108 | pytest==5.0.1 109 | python-dateutil==2.8.0 110 | pytz==2019.3 111 | PyWavelets==1.1.1 112 | PyYAML==5.3.1 113 | pyzmq==18.1.0 114 | qtconsole==4.6.0 115 | requests==2.22.0 116 | retrying==1.3.3 117 | sampyl-mcmc==0.3 118 | scikit-image==0.17.2 119 | scikit-learn==0.22 120 | scipy==1.3.1 121 | selenium==3.141.0 122 | Send2Trash==1.5.0 123 | six==1.12.0 124 | sklearn==0.0 125 | snowballstemmer==2.0.0 126 | Sphinx==2.2.1 127 | sphinxcontrib-applehelp==1.0.1 128 | sphinxcontrib-devhelp==1.0.1 129 | sphinxcontrib-htmlhelp==1.0.2 130 | sphinxcontrib-jsmath==1.0.1 131 | sphinxcontrib-qthelp==1.0.2 132 | sphinxcontrib-serializinghtml==1.1.3 133 | statsmodels==0.10.1 134 | style==1.1.0 135 | tensorboard-plugin-wit==1.7.0 136 | tensorboardX==2.1 137 | tensorflow==2.0.0 138 | tensorflow-estimator==2.0.0 139 | termcolor==1.1.0 140 | terminado==0.8.3 141 | testpath==0.4.4 142 | tifffile==2020.9.22 143 | toolz==0.10.0 144 | torch==1.8.1 145 | torch-two-sample==0.1 146 | torchvision==0.9.1 147 | tornado==6.0.3 148 | tqdm==4.36.1 149 | traitlets==4.3.3 150 | typed-ast==1.4.1 151 | typing-extensions==3.7.4.2 152 | uncertainty-metrics==0.0.81 153 | update==0.0.1 154 | urllib3==1.24.2 155 | userpath==1.4.1 156 | wcwidth==0.1.7 157 | webencodings==0.5.1 158 | Werkzeug==0.16.0 159 | wrapt==1.11.2 160 | zipp==0.6.0 161 | -------------------------------------------------------------------------------- /experiments/exp_cifar.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.chdir('..') 3 | from utils.config import configuration 4 | import torch 5 | from data.generate_dataset import generate_dataset 6 | from models.ensemble import Ensemble 7 | from tensorboardX import SummaryWriter 8 | import os 9 | import numpy as np 10 | from data.toy_classification.generate_classification import generate_svhn 11 | import torch.nn.functional as F 12 | import warnings 13 | warnings.filterwarnings('ignore') 14 | from training.training_mnist_corruption import train 15 | from datetime import datetime 16 | from models.mnets_resnet import ResNet 17 | import warnings 18 | warnings.filterwarnings("ignore", category=DeprecationWarning) 19 | 20 | import sys 21 | sys.stdout.flush() 22 
| import yaml 23 | 24 | def run(): 25 | """Run the script. 26 | """ 27 | n_runs = 1 28 | results = [] 29 | exp_dir = 'exp_'+datetime.now().strftime('%m-%d-%H-%M') 30 | config = configuration() 31 | torch.manual_seed(config.random_seed) 32 | for i in range(n_runs): 33 | date = datetime.now().strftime('%H-%M-%S') 34 | if config.logit_soft == 0: 35 | l_s='softmax' 36 | else: 37 | l_s='logit' 38 | 39 | if config.noise: 40 | alg = 'sto' 41 | else: 42 | alg = 'det' 43 | 44 | f_date = datetime.now().strftime('%Y-%m-%d') 45 | 46 | dout_dir = './out/'+ f_date + '/'+ config.dataset +'_'+ config.exp_dir +'/'+config.method +'/'+ config.ann_sch+str(config.annealing_steps) +'/'+ l_s +'/'+alg+'/'+config.where_repulsive + 'part_'+str(config.n_particles)+'/hidden_'+str(config.size_hidden) +'/lr_'+str(config.lr)+'/l2_'+str(config.prior_variance)+'/'+'/seed_'+str(config.random_seed)+'_run_' + date 47 | config.out_dir = dout_dir 48 | 49 | writer = SummaryWriter(log_dir=os.path.join(config.out_dir, 'summary')) 50 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 51 | 52 | writer.add_text('Comment', config.comment, 0) 53 | configu = dict(config.__dict__) 54 | del configu['comment'] 55 | 56 | writer.add_text('Hparams', str(configu), 0) 57 | print(str(configu)) 58 | dictionary_parameters = vars(config) 59 | 60 | with open(dout_dir+'/'+date+ 'parameters.yml', 'w') as yaml_file: 61 | yaml.dump(dictionary_parameters, stream=yaml_file, default_flow_style=False) 62 | #writer.add_hparams(configu,{}) 63 | 64 | data, classification = generate_dataset(config) 65 | 66 | #data_ood = data[1] 67 | 68 | data_ood = generate_svhn() 69 | 70 | mnet = ResNet(out_act = F.softmax, use_batch_norm = False).to(device) 71 | 72 | l = [] 73 | for _ in range(config.n_particles): 74 | l.append(torch.cat([p.flatten() for p in ResNet(out_act=F.softmax, no_weights=False, 75 | use_batch_norm=False).parameters()]).detach()) 76 | 77 | initial_particles = torch.stack(l).to(device) 78 | 79 | #ensemble = Ensemble(device = device, net=mnet, n_particles = config.n_particles) 80 | 81 | ensemble = Ensemble(device=device, net=mnet, particles=initial_particles) 82 | 83 | metrics = train(data, data_ood, ensemble, device, config,writer) 84 | 85 | results.append(metrics) 86 | 87 | particles = ensemble.particles.cpu().detach().numpy() 88 | 89 | np.save(dout_dir+'/'+date+'particles', particles) 90 | 91 | np.save(dout_dir+'/'+date+'results', np.array(metrics)) 92 | 93 | 94 | 95 | #np.save('./out/'+ f_date + '/'+ config.dataset +'_'+ config.exp_dir +'/'+config.method +'_'+ config.ann_sch +'_'+ l_s +'_'+config.where_repulsive+'results', np.array(results)) 96 | 97 | if __name__ == '__main__': 98 | run() -------------------------------------------------------------------------------- /utils/SSGE_squeeze.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.autograd as autograd 3 | 4 | """ 5 | 6 | The original implementation can be found here https://github.com/AntixK/Spectral-Stein-Gradient/blob/master/score_estimator/spectral_stein.py 7 | 8 | """ 9 | 10 | 11 | class SpectralSteinEstimator(): 12 | def __init__(self, eta=None, num_eigs=None, K=None, xm = None, device = None): 13 | self.eta = eta 14 | self.num_eigs = num_eigs 15 | self.K = K 16 | self.xm = xm 17 | if xm is not None: 18 | self.beta, self.eigen_vals, self.eigen_vec = self.compute_beta(xm) 19 | self.device = device 20 | 21 | def nystrom_method(self, x, eval_points, eigen_vecs, eigen_vals): 22 | """ 23 | Implements the 
Nystrom method for approximating the 24 | eigenfunction (generalized eigenvectors) for the kernel 25 | at x using the M eval_points (x_m). It is given 26 | by - 27 | .. math:: 28 | phi_j(x) = \frac{M}{\lambda_j} \sum_{m=1}^M u_{jm} k(x, x_m) 29 | :param x: (Tensor) Point at which the eigenfunction is evaluated [N x D] 30 | :param eval_points: (Tensor) Sample points from the data of ize M [M x D] 31 | :param eigen_vecs: (Tensor) Eigenvectors of the gram matrix [M x M] 32 | :param eigen_vals: (Tensor) Eigenvalues of the gram matrix [M x 2] 33 | :return: Eigenfunction at x [N x M] 34 | """ 35 | M = torch.tensor(eval_points.size(-2), dtype=torch.float) 36 | 37 | Kxxm = self.K(x, eval_points) 38 | phi_x = torch.sqrt(M) * Kxxm @ eigen_vecs 39 | 40 | phi_x *= 1. / eigen_vals[:, 0] # Take only the real part of the eigenvals 41 | # as the Im is 0 (Symmetric matrix) 42 | return phi_x 43 | 44 | def compute_beta(self,xm): 45 | 46 | M = torch.tensor(xm.size(-2), dtype=torch.float) 47 | 48 | xm = xm.detach().requires_grad_(True) 49 | 50 | Kxx = self.K(xm, xm.detach()) 51 | 52 | dKxx_dx = autograd.grad(Kxx.sum(), xm)[0] 53 | 54 | # Kxx = Kxx + eta * I 55 | if self.eta is not None: 56 | Kxx += self.eta * torch.eye(xm.size(-2)).to(self.device) 57 | 58 | eigen_vals, eigen_vecs = torch.eig(Kxx, eigenvectors=True) 59 | 60 | if self.num_eigs is not None: 61 | eigen_vals = eigen_vals[:self.num_eigs] 62 | eigen_vecs = eigen_vecs[:, :self.num_eigs] 63 | 64 | 65 | 66 | # Compute the Monte Carlo estimate of the gradient of 67 | # the eigenfunction at x 68 | dKxx_dx_avg = -dKxx_dx/xm.shape[0] # [M x D] 69 | 70 | beta = - torch.sqrt(M) * eigen_vecs.t() @ dKxx_dx_avg 71 | beta *= (1. / eigen_vals[:, 0].unsqueeze(-1)) 72 | 73 | return beta, eigen_vals, eigen_vecs 74 | 75 | def compute_score_gradients(self, x, xm = None): 76 | """ 77 | Computes the Spectral Stein Gradient Estimate (SSGE) for the 78 | score function. The SSGE is given by 79 | .. 
math:: 80 | \nabla_{xi} phi_j(x) = \frac{1}{\mu_j M} \sum_{m=1}^M \nabla_{xi}k(x,x^m) \phi_j(x^m) 81 | \beta_{ij} = -\frac{1}{M} \sum_{m=1}^M \nabla_{xi} phi_j (x^m) 82 | \g_i(x) = \sum_{j=1}^J \beta_{ij} \phi_j(x) 83 | :param x: (Tensor) Point at which the gradient is evaluated [N x D] 84 | :param xm: (Tensor) Samples for the kernel [M x D] 85 | :return: gradient estimate [N x D] 86 | """ 87 | if xm is None: 88 | xm = self.xm 89 | beta = self.beta 90 | eigen_vecs = self.eigen_vecs 91 | eigen_vals = self.eigen_vals 92 | else: 93 | beta,eigen_vals,eigen_vecs = self.compute_beta(xm) 94 | 95 | phi_x = self.nystrom_method(x, xm, eigen_vecs, eigen_vals) # [N x M] 96 | # assert beta.allclose(beta1), f"incorrect computation {beta - beta1}" 97 | g = phi_x @ beta # [N x D] 98 | return g 99 | -------------------------------------------------------------------------------- /methods/method_utils.py: -------------------------------------------------------------------------------- 1 | from math import floor 2 | 3 | import torch 4 | 5 | from utils.SSGE_squeeze import SpectralSteinEstimator 6 | from f_SVGD import f_s_SVGD 7 | from utils.kernel import RBF 8 | from methods.SVGD import SGLD, SGD, SVGD, SVGLD 9 | from methods.WGD import WGD, f_WGD 10 | 11 | 12 | def create_method(config,P,optimizer,K=None, device = None): 13 | """ 14 | Utils for the creation of the SVGD method 15 | """ 16 | ann_sch = create_ann(config) 17 | pred_idx = config.logit_soft #1 logit, 0 softmax 18 | num_train = config.num_train 19 | 20 | if K is None: 21 | K = RBF() 22 | 23 | if config.method == 'SGLD': 24 | method = SGLD(P, K, optimizer, device = device ) 25 | elif config.method == 'SGD': 26 | method = SGD(P,optimizer) 27 | elif config.method == 'SVGD': 28 | K = RBF() 29 | method = SVGD(P,K,optimizer,config, ann_sch, num_train = num_train, noise = config.noise ) 30 | elif config.method == 'SVGLD': 31 | K = RBF() 32 | method = SVGLD(P,K,optimizer,config,ann_sch, beta = 100) 33 | elif config.method == 'f_s_SVGD': 34 | ssge_k = RBF() 35 | ssge = SpectralSteinEstimator(0.01,None,ssge_k, device = device) 36 | method = f_s_SVGD(P, K, optimizer,ssge,config,ann_sch,pred_idx,num_train,noise = config.noise) 37 | elif config.method == 'kde_WGD': 38 | K = RBF() 39 | method = WGD(P,K,optimizer,config, ann_sch,grad_estim=None, num_train = num_train, method = 'kde' ) 40 | elif config.method == 'sge_WGD': 41 | K = RBF() 42 | method = WGD(P,K,optimizer,config, ann_sch,grad_estim=None, num_train = num_train, method = 'sge', device = device ) 43 | elif config.method == 'ssge_WGD': 44 | ssge_k = RBF() 45 | K = RBF() 46 | ssge = SpectralSteinEstimator(0.01,None,ssge_k, device = device) 47 | method = WGD(P, K, optimizer,config,ann_sch,grad_estim = ssge, num_train = num_train, method = 'ssge') 48 | elif config.method == 'kde_f_WGD': 49 | ssge_k = RBF() 50 | K = RBF() 51 | ssge = SpectralSteinEstimator(0.01,None,ssge_k, device = device) 52 | method = f_WGD(P, K, optimizer,config,ann_sch,grad_estim = ssge, pred_idx = pred_idx,num_train = num_train, method = 'kde') 53 | elif config.method == 'sge_f_WGD': 54 | ssge_k = RBF() 55 | K = RBF() 56 | ssge = SpectralSteinEstimator(0.01,None,ssge_k, device = device) 57 | method = f_WGD(P, K, optimizer,config,ann_sch,grad_estim = ssge, pred_idx = pred_idx,num_train = num_train, method = 'sge', device = device) 58 | elif config.method == 'ssge_f_WGD': 59 | ssge_k = RBF() 60 | K = RBF() 61 | ssge = SpectralSteinEstimator(0.01,None,ssge_k, device = device) 62 | method = f_WGD(P, K, optimizer,config,ann_sch,grad_estim = 
ssge, pred_idx = pred_idx,num_train = num_train, method = 'ssge') 63 | 64 | 65 | return method 66 | 67 | # cosine annealing learning rate schedule 68 | def cosine_annealing(epoch, n_epochs, n_cycles, lrate_max): 69 | epochs_per_cycle = floor(n_epochs/n_cycles) 70 | cos_inner =(epoch % epochs_per_cycle)/ (epochs_per_cycle) 71 | return (cos_inner) 72 | 73 | 74 | def create_ann(config): 75 | if config.ann_sch == 'linear': 76 | ann_sch = torch.cat([torch.linspace(0,config.gamma,config.annealing_steps),config.gamma*torch.ones(config.epochs-config.annealing_steps)]) 77 | elif config.ann_sch == 'hyper': 78 | ann_sch =torch.cat([torch.tanh((torch.linspace(0,config.annealing_steps,config.annealing_steps)*1.3/config.annealing_steps)**10),config.gamma*torch.ones(config.epochs-config.annealing_steps)]) 79 | elif config.ann_sch == 'cyclic': 80 | ann_sch = torch.cat([torch.tensor([cosine_annealing(a,config.annealing_steps,5,1)**10 for a in range(config.annealing_steps)]),config.gamma*torch.ones(config.epochs-config.annealing_steps)]) 81 | elif config.ann_sch == 'None': 82 | ann_sch = config.gamma*torch.ones(config.epochs) 83 | return ann_sch -------------------------------------------------------------------------------- /experiments/exp_fmnist.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.chdir('..') 3 | from utils.config import configuration 4 | import torch 5 | from data.generate_dataset import generate_dataset 6 | from models.MLP import Net 7 | from models.ensemble import Ensemble 8 | from tensorboardX import SummaryWriter 9 | import os 10 | import torch.nn.functional as F 11 | import numpy as np 12 | from data.toy_classification.generate_classification import generate_mnist 13 | import warnings 14 | warnings.filterwarnings('ignore') 15 | 16 | from training.training_mnist_hyper import train_hyper 17 | 18 | from datetime import datetime 19 | import sys 20 | sys.stdout.flush() 21 | import yaml 22 | 23 | def run(): 24 | """Run the script. 
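    Builds an MLP ensemble on the configured dataset (with MNIST used as out-of-distribution data), trains it via train_hyper and saves particles, metrics and hyperparameters to config.out_dir.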
25 | """ 26 | n_runs = 1 27 | results = [] 28 | exp_dir = 'exp_'+datetime.now().strftime('%m-%d-%H-%M') 29 | config = configuration() 30 | torch.manual_seed(config.random_seed) 31 | for i in range(n_runs): 32 | date = datetime.now().strftime('%H-%M-%S') 33 | 34 | if config.logit_soft == 0: 35 | l_s='softmax' 36 | else: 37 | l_s='logit' 38 | 39 | if config.noise: 40 | alg = 'sto' 41 | else: 42 | alg = 'det' 43 | 44 | f_date = datetime.now().strftime('%Y-%m-%d') 45 | 46 | dout_dir = './out/'+ f_date + '/'+ config.dataset +'_'+ config.exp_dir +'/'+config.method +'/'+ config.ann_sch+str(config.annealing_steps) +'/'+ l_s +'/'+alg+'/'+config.where_repulsive + 'part_'+str(config.n_particles)+'/hidden_'+str(config.size_hidden) +'/lr_'+str(config.lr)+'/l2_'+str(config.prior_variance)+'/'+'/seed_'+str(config.random_seed)+'_run_' + date 47 | config.out_dir = dout_dir 48 | 49 | writer = SummaryWriter(log_dir=os.path.join(config.out_dir, 'summary')) 50 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 51 | 52 | writer.add_text('Comment', config.comment, 0) 53 | configu = dict(config.__dict__) 54 | del configu['comment'] 55 | 56 | writer.add_text('Hparams', str(configu), 0) 57 | print(str(configu)) 58 | #writer.add_hparams(configu,{}) 59 | 60 | data, classification = generate_dataset(config) 61 | 62 | #data_ood = data[1] 63 | 64 | data_ood = generate_mnist() 65 | 66 | layer_sizes = [data_ood.in_shape[0]**2, data_ood.out_shape[0]] 67 | 68 | for i in range(config.num_hidden): 69 | layer_sizes.insert(-1, config.size_hidden) 70 | 71 | mnet = Net(layer_sizes, classification = classification, act=F.relu,out_act = F.softmax, bias = True ).to(device) 72 | 73 | l = [] 74 | for _ in range(config.n_particles): 75 | l.append(torch.cat([p.flatten() for p in Net(layer_sizes, classification = True, act=F.relu,out_act = F.softmax, bias = True, no_weights=False).parameters()][len(mnet.param_shapes):]).detach()) 76 | 77 | initial_particles = torch.stack(l).to(device) 78 | 79 | #ensemble = Ensemble(device = device, net=mnet, n_particles = config.n_particles) 80 | ensemble = Ensemble(device = device, net=mnet,particles=initial_particles) 81 | 82 | metrics = train_hyper(data, data_ood, ensemble, device, config,writer) 83 | 84 | results.append(metrics) 85 | 86 | particles = ensemble.particles.cpu().detach().numpy() 87 | 88 | np.save(dout_dir+'/'+date+'particles', particles) 89 | 90 | np.save(dout_dir+'/'+date+'results', np.array(metrics)) 91 | 92 | dictionary_parameters = vars(config) 93 | 94 | with open(dout_dir+'/'+date+ 'parameters.yml', 'w') as yaml_file: 95 | yaml.dump(dictionary_parameters, stream=yaml_file, default_flow_style=False) 96 | 97 | #np.save('./out/'+ f_date + '/'+ config.dataset +'_'+ config.exp_dir +'/'+config.method +'_'+ config.ann_sch +'_'+ l_s +'_'+config.where_repulsive+'results', np.array(results)) 98 | 99 | if __name__ == '__main__': 100 | run() -------------------------------------------------------------------------------- /utils/SSGE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.autograd as autograd 3 | 4 | """ 5 | The original implementation can be found here https://github.com/AntixK/Spectral-Stein-Gradient/blob/master/score_estimator/spectral_stein.py 6 | """ 7 | 8 | class SpectralSteinEstimator(): 9 | def __init__(self, eta=None, num_eigs=None, K=None, xm = None): 10 | self.eta = eta 11 | self.num_eigs = num_eigs 12 | self.K = K 13 | self.xm = xm 14 | if xm is not None: 15 | self.beta, 
self.eigen_vals, self.eigen_vecs = self.compute_beta(xm) 16 | 17 | def nystrom_method(self, x, eval_points, eigen_vecs, eigen_vals): 18 | """ 19 | Implements the Nystrom method for approximating the 20 | eigenfunction (generalized eigenvectors) for the kernel 21 | at x using the M eval_points (x_m). It is given 22 | by - 23 | .. math:: 24 | phi_j(x) = \frac{M}{\lambda_j} \sum_{m=1}^M u_{jm} k(x, x_m) 25 | :param x: (Tensor) Point at which the eigenfunction is evaluated [N x D] 26 | :param eval_points: (Tensor) Sample points from the data of ize M [M x D] 27 | :param eigen_vecs: (Tensor) Eigenvectors of the gram matrix [M x M] 28 | :param eigen_vals: (Tensor) Eigenvalues of the gram matrix [M x 2] 29 | :return: Eigenfunction at x [N x M] 30 | """ 31 | M = torch.tensor(eval_points.size(-2), dtype=torch.float) 32 | 33 | Kxxm = self.K(x, eval_points) 34 | phi_x = torch.sqrt(M) * Kxxm @ eigen_vecs 35 | 36 | phi_x *= 1. / eigen_vals[:,:,0].unsqueeze(-1) # Take only the real part of the eigenvals 37 | # as the Im is 0 (Symmetric matrix) 38 | return phi_x 39 | 40 | def compute_beta(self,xm): 41 | 42 | M = torch.tensor(xm.shape[0], dtype=torch.float) 43 | 44 | xm = xm.detach().requires_grad_(True) 45 | 46 | Kxx = self.K(xm, xm.detach()) 47 | 48 | dKxx_dx = autograd.grad(Kxx.sum(), xm)[0] #TODO: is this doing the right thing ? 49 | 50 | # Kxx = Kxx + eta * I 51 | if self.eta is not None: 52 | Kxx += self.eta * torch.eye(Kxx.shape[1]) 53 | 54 | eigen_vals, eigen_vecs = zip(*[torch.eig(x,eigenvectors = True) for x in Kxx]) 55 | 56 | eigen_vals = torch.stack(eigen_vals) 57 | eigen_vecs = torch.stack(eigen_vecs) 58 | 59 | if self.num_eigs is not None: 60 | eigen_vals = eigen_vals[:,self.num_eigs] 61 | eigen_vecs = eigen_vecs[:,:, :self.num_eigs] 62 | 63 | 64 | 65 | # Compute the Monte Carlo estimate of the gradient of 66 | # the eigenfunction at x 67 | dKxx_dx_avg = -dKxx_dx/xm.shape[1] # [M x D] 68 | 69 | #beta = - torch.sqrt(M) * eigen_vecs.t() @ dKxx_dx_avg 70 | beta = - torch.sqrt(M) * torch.einsum('jki,kjm->jim',(eigen_vecs,dKxx_dx_avg)) 71 | beta *= 1/eigen_vals[:,:,0].unsqueeze(-1) 72 | 73 | return beta, eigen_vals, eigen_vecs 74 | 75 | def compute_score_gradients(self, x, xm = None): 76 | """ 77 | Computes the Spectral Stein Gradient Estimate (SSGE) for the 78 | score function. The SSGE is given by 79 | .. 
math:: 80 | \nabla_{xi} phi_j(x) = \frac{1}{\mu_j M} \sum_{m=1}^M \nabla_{xi}k(x,x^m) \phi_j(x^m) 81 | \beta_{ij} = -\frac{1}{M} \sum_{m=1}^M \nabla_{xi} phi_j (x^m) 82 | \g_i(x) = \sum_{j=1}^J \beta_{ij} \phi_j(x) 83 | :param x: (Tensor) Point at which the gradient is evaluated [N x D] 84 | :param xm: (Tensor) Samples for the kernel [M x D] 85 | :return: gradient estimate [N x D] 86 | """ 87 | if xm is None: 88 | xm = self.xm 89 | beta = self.beta 90 | eigen_vecs = self.eigen_vecs 91 | eigen_vals = self.eigen_vals 92 | else: 93 | beta,eigen_vals,eigen_vecs = self.compute_beta(xm) 94 | 95 | phi_x = self.nystrom_method(x, xm, eigen_vecs, eigen_vals) # [N x M] 96 | # assert beta.allclose(beta1), f"incorrect computation {beta - beta1}" 97 | g = phi_x @ beta # [N x D] 98 | return g 99 | 100 | -------------------------------------------------------------------------------- /data/toy_classification/oned_gaussian.py: -------------------------------------------------------------------------------- 1 | from data.dataset import Dataset 2 | from sklearn import datasets 3 | import numpy as np 4 | from matplotlib.colors import ListedColormap 5 | import matplotlib.pyplot as plt 6 | import torch 7 | import numpy as np 8 | import sklearn 9 | 10 | class oned_gaussian(Dataset): 11 | """An instance of this class shall represent a regression task where the 12 | input samples :math:`x` are drawn from a Gaussian with given mean and 13 | variance. 14 | 15 | Due to plotting functionalities, this class only supports 2D inputs and 16 | 1D outputs. 17 | 18 | Attributes: 19 | mean: Mean vector. 20 | cov: Covariance matrix. 21 | """ 22 | 23 | def __init__(self, rseed=1234, use_one_hot=True, noise=0.1, n_points = 300): 24 | """Generate a new dataset. 25 | 26 | The input data x for train and test samples will be drawn iid from the 27 | given Gaussian. Per default, the map function is the probability 28 | density of the given Gaussian: y = f(x) = p(x). 29 | 30 | Args: 31 | mean: The mean of the Gaussian. 32 | cov: The covariance of the Gaussian. 33 | num_train: Number of training samples. 34 | num_test: Number of test samples. 35 | map_function (optional): A function handle that receives input 36 | samples and maps them to output samples. If not specified, the 37 | density function will be used as map function. 38 | rseed (int): If ``None``, the current random state of numpy is used 39 | to generate the data. Otherwise, a new random state with the 40 | given seed is generated. 41 | """ 42 | super().__init__() 43 | 44 | if rseed is None: 45 | rand = np.random 46 | else: 47 | rand = np.random.RandomState(rseed) 48 | 49 | g_1 = torch.distributions.normal.Normal(torch.tensor([0.0]), torch.tensor([1.0])) 50 | g_2 = torch.distributions.normal.Normal(torch.tensor([10.0]), torch.tensor([1.0])) 51 | 52 | x = torch.cat([g_1.sample(torch.Size([n_points])), g_2.sample(torch.Size([n_points]))]) 53 | labels = np.repeat([0,1],n_points) 54 | 55 | train_x, test_x, train_y, test_y = sklearn.model_selection.train_test_split(x, labels, test_size = 0.33, shuffle=True, random_state = 42) 56 | 57 | in_data = np.vstack([train_x, test_x]) 58 | out_data = np.vstack([np.expand_dims(train_y, 1), np.expand_dims(test_y, 1)]) 59 | 60 | # Specify internal data structure. 
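        # The handler stores a binary classification problem: 1d inputs drawn from
        # two unit-variance Gaussians (means 0 and 10), labels optionally one-hot
        # encoded, plus index arrays defining the train/test split.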
61 | self._data['classification'] = True 62 | self._data['sequence'] = False 63 | self._data['in_data'] = in_data 64 | self._data['in_shape'] = [1] 65 | self._data['num_classes'] = 2 66 | if use_one_hot: 67 | out_data = self._to_one_hot(out_data) 68 | self._data['out_data'] = out_data 69 | self._data['out_shape'] = [2] 70 | self._data['train_inds'] = np.arange(train_x.shape[0]) 71 | self._data['test_inds'] = np.arange(train_x.shape[0], train_x.shape[0] + test_x.shape[0]) 72 | 73 | def get_identifier(self): 74 | """Returns the name of the dataset.""" 75 | return 'Moons_dataset' 76 | 77 | def get_input_mesh(self, x1_range=None, grid_size=1000): 78 | 79 | x1 = np.linspace(start=x1_range[0], stop=x1_range[1], num=grid_size) 80 | 81 | 82 | return x1 83 | 84 | def _plot_sample(self, fig, inner_grid, num_inner_plots, ind, inputs, 85 | outputs=None, predictions=None): 86 | colors = ListedColormap(['#FF0000', '#0000FF']) 87 | 88 | # Create plot 89 | fig = plt.figure(figsize=(15, 10)) 90 | ax = fig.add_subplot(111) 91 | 92 | x_train_0 = self.get_train_inputs() 93 | y_train_0 = self.get_train_outputs() 94 | x_test_0 = self.get_test_inputs() 95 | y_test_0 = self.get_test_outputs() 96 | 97 | ax.scatter(x_train_0[:, 0], x_train_0[:, 1], alpha=1, marker='o', c=np.argmax(y_train_0, 1), cmap=colors, 98 | edgecolors='k', s=50, label='Train') 99 | ax.scatter(x_test_0[:, 0], x_test_0[:, 1], alpha=0.6, marker='s', c=np.argmax(y_test_0, 1), cmap=colors, 100 | edgecolors='k', s=50, label='test') 101 | plt.title('Data', fontsize=30) 102 | plt.legend(loc=2, fontsize=30) 103 | plt.show() 104 | -------------------------------------------------------------------------------- /models/MLP.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import numpy as np 4 | import math 5 | import torch.nn.functional as F 6 | from utils.utils import dnorm2 7 | 8 | class Net(nn.Module): 9 | """ 10 | Implementation of Fully connected neural network 11 | 12 | Args: 13 | layer_sizes(list): list containing the layer sizes 14 | classification(bool): if the net is used for a classification task 15 | act: activation function in the hidden layers 16 | out_act: activation function in the output layer, if None then linear 17 | bias(Bool): whether or not the net has biases 18 | """ 19 | 20 | def __init__(self, layer_sizes, classification = False, act=F.sigmoid,d_logits = False, out_act=None, bias=True, no_weights = True): 21 | super(Net, self).__init__() 22 | self.layer_sizes = layer_sizes 23 | self.layer_list = [] 24 | self.classification = classification 25 | self.bias = bias 26 | self.d_logits = d_logits 27 | self.ac = act 28 | self.out_act = out_act 29 | for l in range(len(layer_sizes[:-1])): 30 | layer_l = nn.Linear(layer_sizes[l], layer_sizes[l+1], bias=self.bias) 31 | self.add_module('layer_' + str(l), layer_l) 32 | 33 | self.num_params = sum(p.numel() for p in self.parameters()) 34 | 35 | self.param_shapes = [list(i.shape) for i in self.parameters()] 36 | 37 | self._weights = None 38 | if no_weights: 39 | return 40 | 41 | ### Define and initialize network weights. 42 | # Each odd entry of this list will contain a weight Tensor and each 43 | # even entry a bias vector. 
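        # num_params and param_shapes (computed above) are what Ensemble.reshape_particles
        # uses to split a flat particle vector back into per-layer weight/bias tensors.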
44 | self._weights = nn.ParameterList() 45 | 46 | for i, dims in enumerate(self.param_shapes): 47 | self._weights.append(nn.Parameter(torch.Tensor(*dims), 48 | requires_grad=True)) 49 | 50 | 51 | for i in range(0, len(self._weights), 2 if bias else 1): 52 | if bias: 53 | self.init_params(self._weights[i], self._weights[i + 1]) 54 | else: 55 | self.init_params(self._weights[i]) 56 | 57 | 58 | def init_params(self,weights, bias=None): 59 | """Initialize the weights and biases of a linear or (transpose) conv layer. 60 | 61 | Note, the implementation is based on the method "reset_parameters()", 62 | that defines the original PyTorch initialization for a linear or 63 | convolutional layer, resp. The implementations can be found here: 64 | 65 | https://git.io/fhnxV 66 | Args: 67 | weights: The weight tensor to be initialized. 68 | bias (optional): The bias tensor to be initialized. 69 | """ 70 | nn.init.kaiming_uniform_(weights, a=math.sqrt(5)) 71 | if bias is not None: 72 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights) 73 | bound = 1 / math.sqrt(fan_in) 74 | nn.init.uniform_(bias, -bound, bound) 75 | 76 | 77 | def forward(self, x, weights=None, ret_pre = False): 78 | """Can be used to make the forward step and make predictions. 79 | 80 | Args: 81 | x(torch tensor): The input batch to feed the network. 82 | weights(list): A reshaped particle 83 | Returns: 84 | (tuple): Tuple containing: 85 | 86 | - **y**: The output of the network 87 | - **hidden** (optional): if out_act is not None also the linear output before activation is returned 88 | """ 89 | 90 | if weights is None: 91 | weights = self._weights 92 | else: 93 | shapes = self.param_shapes 94 | assert (len(weights) == len(shapes)) 95 | for i, s in enumerate(shapes): 96 | assert (np.all(np.equal(s, list(weights[i].shape)))) 97 | 98 | hidden = x 99 | 100 | if self.bias: 101 | num_layers = len(weights) // 2 102 | step_size = 2 103 | else: 104 | num_layers = len(weights) 105 | step_size = 1 106 | 107 | for l in range(0, len(weights), step_size): 108 | W = weights[l] 109 | if self.bias: 110 | b = weights[l + 1] 111 | else: 112 | b = None 113 | 114 | if l==len(weights)-2 and self.d_logits: 115 | pre_out = hidden 116 | distance_logits = dnorm2(pre_out, W) 117 | 118 | hidden = F.linear(hidden, W, bias=b) 119 | 120 | # Only for hidden layers. 121 | if l / step_size + 1 < num_layers: 122 | if self.ac is not None: 123 | hidden = self.ac(hidden) 124 | 125 | if self.d_logits: 126 | hidden = -distance_logits 127 | if self.out_act is not None: 128 | return self.out_act(hidden), hidden #needed so that i can use second output for training first for predict 129 | else: 130 | return hidden -------------------------------------------------------------------------------- /methods/f_SVGD.py: -------------------------------------------------------------------------------- 1 | import torch.autograd as autograd 2 | import torch 3 | """ 4 | In this file a lot of different SVGD implementations are collected, the basic structure of the class is the same of the 5 | standard SVGD. 
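The main class here is the functional-space variant f_s_SVGD: the kernel and the repulsion act
on the ensemble predictions f_i = f(X; w_i) rather than on the particles themselves.
Roughly, the functional update is

    df_i ~ 1/n * sum_j [ k(f_j, f_i) * (grad_{f_j} log p(T | f_j) + grad_{f_j} log p(f_j)) + grad_{f_j} k(f_j, f_i) ],

where the prior score grad log p(f) is estimated with the spectral Stein gradient
estimator, and the weight update is obtained by back-propagating df_i through the
network Jacobian (autograd.grad(pred, W, grad_outputs=...) in phi below).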
6 | """ 7 | import time 8 | class Timer(object): 9 | def __init__(self, name=None, print_t = False): 10 | self.name = name 11 | self.print_t = print_t 12 | 13 | def __enter__(self): 14 | self.tstart = time.time() 15 | 16 | def __exit__(self, type, value, traceback): 17 | if self.print_t: 18 | print('Elapsed '+self.name+': %s' % (time.time() - self.tstart)) 19 | 20 | class f_s_SVGD: 21 | """ 22 | Implementation of functional space SVGD 23 | 24 | Args: 25 | P: instance of a distribution returning the log_prob, see distributions.py for examples 26 | K: kernel instance, see kernel.py for examples 27 | optimizer: instance of an optimizer SGD,Adam 28 | """ 29 | def __init__(self, P, K, optimizer,prior_grad_estim,config,ann_sch,pred_idx = 1,num_train = False, noise = False): 30 | self.P = P 31 | self.K = K 32 | self.optim = optimizer 33 | self.pge = prior_grad_estim 34 | self.gamma = config.gamma 35 | self.ann_schedule=ann_sch 36 | self.pred_idx = pred_idx 37 | self.num_train = num_train 38 | self.noise = noise 39 | 40 | 41 | def phi(self, W,X,T,step,X_add=None): 42 | """ 43 | Computes the update of the f-SVGD rule as: 44 | 45 | 46 | Args: 47 | W: particles 48 | X: input training batch 49 | T: label training batch 50 | 51 | Return: 52 | phi: the update to feed the optimizer 53 | driving force: first term in the update rule 54 | repulsive force: second term in the update rule 55 | """ 56 | if self.num_train: 57 | num_t = self.P.num_train 58 | else: 59 | num_t = 1 60 | 61 | 62 | W = W.detach().requires_grad_(True) 63 | 64 | # Score function 65 | log_prob, pred = self.P.log_prob(W, X, T, return_pred=True,pred_idx=self.pred_idx) 66 | score_func = autograd.grad(log_prob.sum(), pred[self.pred_idx], retain_graph=True)[0] 67 | 68 | if X_add is not None: 69 | pred_add = (self.P.ensemble.forward(X_add,W)[self.pred_idx]).view(W.shape[0],-1) 70 | else: 71 | pred_add = pred[self.pred_idx].view(W.shape[0],-1) #[n_part, classesxB] 72 | 73 | 74 | ############## Repulsive force ############## 75 | 76 | with Timer('Repulsive force:'): 77 | pred_k = pred_add 78 | K_f = self.K(pred_k, pred_k.detach()) 79 | grad_K = -autograd.grad(K_f.sum(), pred_k)[0] 80 | 81 | grad_K = grad_K.view(W.shape[0],-1) #needed only for weird kernels 82 | score_func = score_func.view(W.shape[0],-1) 83 | #pred = pred[0].view(W.shape[0],-1) #[n_part, classesxB] 84 | 85 | ############## Gradient functional prior ############## 86 | 87 | with Timer('Gradient prior:'): 88 | #pred = pred[self.pred_idx].view(W.shape[0],-1) #[n_part, classesxB] 89 | #pred_j = pred[self.pred_idx].view(W.shape[0],-1) #[n_part, classesxB] 90 | 91 | pred = pred[self.pred_idx].view(W.shape[0],-1) #[n_part, classesxB] 92 | 93 | 94 | w_prior = self.P.prior.sample(torch.Size([W.shape[0]])) 95 | 96 | prior_pred = self.P.ensemble.forward(X, w_prior)[self.pred_idx].reshape(W.shape[0],-1) # changed index here 97 | 98 | grad_prior = self.pge.compute_score_gradients(pred, prior_pred) # .mean(0) 99 | 100 | ############## Update rule ############## 101 | 102 | with Timer('Driving force:'): 103 | driv = K_f.matmul(score_func + grad_prior) 104 | 105 | 106 | if self.noise: 107 | lr = self.optim.state_dict()['param_groups'][0]['lr'] 108 | K_W_exp = torch.sqrt(2*K_f.repeat(pred.shape[1],1,1)/(W.size(0)*lr)) #K_XX.repeat(X.shape[1],1,1)*2 109 | langevin_noise = torch.randn_like(K_W_exp)*K_W_exp 110 | f_phi = (self.ann_schedule[step]*driv + num_t*grad_K) / W.size(0) + langevin_noise.sum(2).T 111 | else: 112 | f_phi = (self.ann_schedule[step]*driv + num_t*grad_K) / W.size(0) 113 | #f_phi 
= score_func + grad_prior 114 | with Timer('function to weight + Jacobian :'): 115 | w_phi = autograd.grad(pred,W,grad_outputs=f_phi,retain_graph=False)[0] 116 | #w_phi = autograd.grad(pred,W,grad_outputs=f_phi,retain_graph=False)[0] 117 | 118 | #with Timer('function to weight :'): 119 | # w_phi = torch.einsum('mbw,mb->mw', [jacob,f_phi]) 120 | return w_phi, self.ann_schedule[step]*driv, num_t*grad_K 121 | 122 | def step(self, W,X,T,step,X_add=None): 123 | """ 124 | Customization of the optimizer step where I am forcing the gradient to be the SVGD update rule 125 | 126 | Args: 127 | W: particles 128 | X: input training batch 129 | T: label training batch 130 | Return: 131 | driving force: first term in the update rule 132 | repulsive force: second term in the update rule 133 | """ 134 | self.optim.zero_grad() 135 | update = self.phi(W,X,T,step,X_add) 136 | W.grad = -update[0] 137 | #torch.nn.utils.clip_grad_norm_(W,0.1,2) 138 | self.optim.step() 139 | return update[1], update[2] -------------------------------------------------------------------------------- /utils/generate_results.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | import torch 5 | from natsort import natsorted, index_natsorted 6 | from random import shuffle 7 | import seaborn as sns 8 | import matplotlib 9 | import matplotlib.pyplot as plt 10 | import argparse 11 | 12 | 13 | parser = argparse.ArgumentParser(description='results') 14 | parser.add_argument('--path', type=str, default=None, 15 | help='Folder where to find the results.') 16 | parser.add_argument('--seeds', type=int, default=5, 17 | help='Number of seeds used.') 18 | parser.add_argument('--subf', action='store_true', default=False, 19 | help='If the folder contains subfolders on which to interate') 20 | args = parser.parse_args() 21 | 22 | if os.path.exists(args.path+"/results.csv"): 23 | print('Old results found, removing...') 24 | os.remove(args.path+"/results.csv") 25 | 26 | if os.path.exists(args.path+"/params.csv"): 27 | os.remove(args.path+"/params.csv") 28 | 29 | 30 | cols = ['AUROC','AUR_IN','AUR_OUT', 'train_acc', 'test_acc', 'H_test', 'H_train', 'H_ood', 'std_ood', 'std_test', 'ECE', 'NLL', 'AUROC_std'] 31 | paths_l = [] 32 | name_l = [] 33 | for root, dirs, files in os.walk(args.path+"/", topdown=False): 34 | for name in files: 35 | if 'results.npy' in name and len(name)<31: 36 | print(name) 37 | path = os.path.join(root, name) 38 | paths_l.append(path) 39 | name = path.replace(args.path,'') 40 | name_l.append(name[:]) 41 | 42 | idx = index_natsorted(name_l) 43 | name_l = [name_l[i] for i in idx] 44 | paths_l = [paths_l[j] for j in idx] 45 | mean_res = [np.load(i) for i in paths_l] 46 | std_res = [np.load(i) for i in paths_l] 47 | name_cpopy = [] 48 | for n in name_l: 49 | index = n.find('seed') #stores the index of a substring or char 50 | name_cpopy.append(n[:index]) 51 | my_dict = {i:name_cpopy.count(i) for i in name_cpopy} 52 | name_cpopy = list(set(name_cpopy)) 53 | name_cpopy = natsorted(name_cpopy) 54 | 55 | g = [] 56 | for n in natsorted(name_cpopy): 57 | l=[] 58 | for p in paths_l: 59 | if n in p: 60 | #print(n) 61 | r = np.load(p,allow_pickle=True) 62 | l.append(r) 63 | g.append(l) 64 | 65 | mean_res = [np.mean(resu,0) for resu in g] 66 | res = [np.array(resu) for resu in g] 67 | 68 | std_res = [np.std(resu,0)/len(resu) for resu in g] 69 | results_20p = pd.DataFrame.from_dict(dict(zip(name_cpopy, mean_res)), orient='index') 70 | 
results_20p.columns = cols 71 | results_20p['ratio_entr'] = results_20p['H_ood']/results_20p['H_test'] 72 | results_20p['ratio_MD'] = results_20p['std_ood']/results_20p['std_test'] 73 | results_20p = results_20p.sort_values(by =['test_acc'] , ascending=False).round(3) 74 | 75 | ############# Computing error of entropy ratios ############## 76 | mean_ratios = [m_r[8]/m_r[9] for m_r in mean_res] 77 | cov_l = [np.cov(np.squeeze(np.stack([np.expand_dims(r[:,9],1),np.expand_dims(r[:,8],1)]),2))[0,1] for r in res] 78 | errors=[] 79 | for i in range(len(mean_ratios)): 80 | errors.append(1/mean_res[i][9]**2*(std_res[i][8]**2 - 2*mean_ratios[i]*cov_l[i] + mean_ratios[i]**2*std_res[i][9]**2 )) 81 | md_ratio_error = pd.DataFrame.from_dict(dict(zip(natsorted(name_cpopy) , np.sqrt(np.abs(errors))/5)), orient='index') 82 | md_ratio_error.columns = ['ratio_MD'] 83 | md_ratio_error = md_ratio_error.reindex(results_20p.index) 84 | 85 | ############# Computing error of std ratios ############## 86 | mean_ratios = [m_r[7]/m_r[5] for m_r in mean_res] 87 | cov_l = [np.cov(np.squeeze(np.stack([np.expand_dims(r[:,5],1),np.expand_dims(r[:,7],1)]),2))[0,1] for r in res] 88 | errors=[] 89 | for i in range(len(mean_ratios)): 90 | errors.append(1/mean_res[i][5]**2*(std_res[i][7]**2 - 2*mean_ratios[i]*cov_l[i] + mean_ratios[i]**2*std_res[i][5]**2 )) 91 | h_ratio_error = pd.DataFrame.from_dict(dict(zip(natsorted(name_cpopy) , np.sqrt(np.abs(errors))/5)), orient='index') 92 | h_ratio_error.columns = ['ratio_entr'] 93 | h_ratio_error = h_ratio_error.reindex(results_20p.index) 94 | 95 | std_20p = pd.DataFrame.from_dict(dict(zip(name_cpopy, std_res)), orient='index') 96 | std_20p.columns = cols 97 | std_20p = std_20p.reindex(results_20p.index) 98 | std_20p = pd.concat([std_20p, h_ratio_error,md_ratio_error], axis=1) 99 | #std_20p = std_20p.rename({'md_ratio_entr':'ratio_MD'},axis='columns') 100 | std_20p_array = std_20p.round(3).astype(str).values 101 | 102 | df3 = results_20p.round(3).applymap(str) + 'pm'+std_20p.round(3).applymap(str) 103 | df3 = df3[['AUROC','AUROC_std', 'test_acc','ratio_entr','ratio_MD','ECE','NLL']] 104 | 105 | df3.to_csv(args.path+'/results.csv') 106 | print('Results saved in:', args.path+'/results.csv') 107 | 108 | name_cpopy_res = name_cpopy 109 | 110 | paths_l = [] 111 | name_l = [] 112 | for root, dirs, files in os.walk(args.path+"/", topdown=False): 113 | for name in files: 114 | if 'parameters' in name and len(name)<31: 115 | #print(name) 116 | path = os.path.join(root, name) 117 | paths_l.append(path) 118 | name = path.replace(args.path,'') 119 | name_l.append(name[:]) 120 | #print(natsorted(name_l)) 121 | 122 | idx = index_natsorted(name_l) 123 | name_l = [name_l[i] for i in idx] 124 | paths_l = [paths_l[j] for j in idx] 125 | name_cpopy = [] 126 | for n in name_l: 127 | index = n.find('seed') #stores the index of a substring or char 128 | name_cpopy.append(n[:index]) 129 | #name_cpopy = list(set(name_cpopy)) 130 | name_cpopy = natsorted(name_cpopy) 131 | 132 | import pandas as pd 133 | import yaml 134 | dfs = [] 135 | for p in paths_l: 136 | with open(p, 'r') as f: 137 | d = pd.io.json.json_normalize(yaml.load(f)) 138 | #d.index(name_cpopy[0]) 139 | dfs.append(d) 140 | 141 | df = pd.concat(dfs) 142 | df = df.reset_index(drop = True) 143 | df_new = pd.DataFrame(data=df.values, index=name_cpopy, columns = list(df.columns)).drop(['comment','out_dir'],axis=1) 144 | print(df_new) 145 | # I take one every 5 rows because only the random seed is changing 146 | df_new.iloc[::args.seeds, 
:].reindex(name_cpopy_res).to_csv(args.path+'/params.csv') 147 | print('Params saved in:', args.path+'/params.csv') -------------------------------------------------------------------------------- /methods/SVGD.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.autograd as autograd 3 | 4 | """ 5 | In this file a lot of different SVGD implementations are collected, the basic structure of the class is the same of the 6 | standard SVGD. 7 | """ 8 | 9 | class SVGD: 10 | """ 11 | Implementation of SVGD 12 | 13 | Args: 14 | P: instance of a distribution returning the log_prob, see distributions.py for examples 15 | K: kernel instance, see kernel.py for examples 16 | optimizer: instance of an optimizer SGD,Adam 17 | """ 18 | def __init__(self, P, K, optimizer,config, ann_sch,num_train = False, noise = False): 19 | self.P = P 20 | self.K = K 21 | self.optim = optimizer 22 | self.gamma = config.gamma 23 | self.ann_schedule = ann_sch 24 | self.num_train = num_train 25 | self.noise = noise 26 | 27 | 28 | def phi(self, W,X,T,step): 29 | """ 30 | Computes the update of the SVGD rule 31 | 32 | Args: 33 | W: particles 34 | X: inputs training batch 35 | T: labels training batch 36 | 37 | Return: 38 | phi: the update to feed the optimizer 39 | driving force: first term in the update rule 40 | repulsive force: second term in the update rule 41 | """ 42 | 43 | if self.num_train: 44 | num_t = self.P.num_train 45 | else: 46 | num_t = 1 47 | 48 | W = W.detach().requires_grad_(True) 49 | 50 | #computing the driving force 51 | log_prob = self.P.log_prob(W,X,T) 52 | score_func = autograd.grad(log_prob.sum(), W)[0] 53 | 54 | #computing the repusive force 55 | K_W = self.K(W, W.detach()) 56 | grad_K = -autograd.grad(K_W.sum(), W)[0] 57 | 58 | if self.noise: 59 | lr = self.optim.state_dict()['param_groups'][0]['lr'] 60 | K_W_exp = torch.sqrt(2*K_W.repeat(W.shape[1],1,1)/(W.size(0)*lr)) #K_XX.repeat(X.shape[1],1,1)*2 61 | langevin_noise = torch.randn_like(K_W_exp)*K_W_exp 62 | phi = (self.ann_schedule[step]*K_W.detach().matmul(score_func) +num_t*grad_K) / W.size(0) + langevin_noise.sum(2).T 63 | else: 64 | phi = (self.ann_schedule[step]*K_W.detach().matmul(score_func) +num_t*grad_K) / W.size(0) 65 | 66 | 67 | return phi, self.ann_schedule[step]*K_W.detach().matmul(score_func), num_t*grad_K 68 | 69 | def step(self, W,X,T,step): 70 | """ 71 | Customization of the optimizer step where I am forcing the gradients to be instead the SVGD update rule 72 | 73 | Args: 74 | W: particles 75 | X: input training batch 76 | T: label training batch 77 | Return: 78 | driving force: first term in the update rule 79 | repulsive force: second term in the update rule 80 | """ 81 | self.optim.zero_grad() 82 | update = self.phi(W,X,T,step) 83 | W.grad = -update[0] 84 | self.optim.step() 85 | return update[1], update[2] 86 | 87 | class SVGLD: 88 | def __init__(self, P, K, optimizer,config, ann_sch, beta = 1.0, alpha = 1.0): 89 | self.P = P 90 | self.K = K 91 | self.optim = optimizer 92 | self.gamma = config.gamma 93 | self.beta = beta 94 | self.alpha = alpha #useful to remove the first additional score 95 | self.ann_schedule=ann_sch 96 | 97 | 98 | def phi(self, W,X,T,step): 99 | W = W.detach().requires_grad_(True) 100 | 101 | log_prob = self.P.log_prob(W,X,T) 102 | score_func = autograd.grad(log_prob.sum(), W)[0] 103 | 104 | K_W = self.K(W, W.detach()) 105 | grad_K = -autograd.grad(K_W.sum(), W)[0] 106 | 107 | driv = self.alpha/self.beta*score_func + 
K_W.detach().matmul(score_func)/W.size(0) 108 | rep = self.P.num_train*grad_K / W.size(0) 109 | 110 | phi = self.ann_schedule[step]*driv + rep 111 | lr = self.optim.state_dict()['param_groups'][0]['lr'] 112 | 113 | langevin_noise = torch.distributions.Normal(torch.zeros(W.shape[0]),torch.ones(W.shape[0])/torch.sqrt(self.beta*torch.tensor(lr))) 114 | noise = langevin_noise.sample().unsqueeze(1) 115 | phi += -noise 116 | 117 | return phi,self.ann_schedule[step]*driv,rep,-noise 118 | 119 | def step(self, W,X,T,step): 120 | self.optim.zero_grad() 121 | update = self.phi(W,X,T,step) 122 | W.grad = -update[0] 123 | self.optim.step() 124 | return update[1], update[2], update[3] 125 | 126 | class SGD: 127 | def __init__(self, P, optimizer): 128 | self.P = P 129 | self.optim = optimizer 130 | 131 | def phi(self, W,X,T): 132 | W = W.detach().requires_grad_(True) 133 | 134 | log_prob = self.P.log_prob(W,X,T) 135 | score_func = autograd.grad(log_prob.sum(), W)[0] 136 | 137 | phi = score_func 138 | 139 | return phi 140 | 141 | def step(self, W,X,T): 142 | self.optim.zero_grad() 143 | W.grad = -self.phi(W,X,T) 144 | self.optim.step() 145 | 146 | 147 | def step(self, W,X,T,step,X_add = None): 148 | self.optim.zero_grad() 149 | update = self.phi(W,X,T,step, X_add) 150 | W.grad = -update[0] 151 | self.optim.step() 152 | return update[1], update[2] 153 | 154 | class SGLD: 155 | def __init__(self, P, K, optimizer, device): 156 | self.P = P 157 | self.K = K 158 | self.optim = optimizer 159 | self.device = device 160 | 161 | def phi(self, W,X,T): 162 | W = W.detach().requires_grad_(True) 163 | 164 | log_prob = self.P.log_prob(W,X,T) 165 | score_func = autograd.grad(log_prob.sum(), W)[0] 166 | lr = self.optim.state_dict()['param_groups'][0]['lr'] 167 | langevin_noise = torch.distributions.Normal(torch.zeros(W.shape[1]).to(self.device),(torch.ones(W.shape[1])/torch.sqrt(torch.tensor(lr))).to(self.device)) 168 | phi = score_func + langevin_noise.sample(torch.Size([W.size(0)])) 169 | 170 | return phi 171 | 172 | def step(self, W,X,T): 173 | self.optim.zero_grad() 174 | W.grad = -self.phi(W,X,T) 175 | self.optim.step() 176 | -------------------------------------------------------------------------------- /training/training_1d_regre.py: -------------------------------------------------------------------------------- 1 | import seaborn as sns 2 | import torch 3 | 4 | from utils.distributions import Unorm_post 5 | from methods.method_utils import create_method 6 | from utils.utils import plot_predictive_distributions 7 | 8 | sns.set() 9 | 10 | def train(data, ensemble, device, config,writer): 11 | """Train the particles using a specific ensemble. 12 | 13 | Args: 14 | data: A DATASET instance. 15 | mnet: The model of the main network. 16 | device: Torch device (cpu or gpu). 17 | config: The command line arguments. 
18 | writer: The tensorboard summary writer.s 19 | """ 20 | 21 | # -------------------------------------------------------------------------------- 22 | # CREATING TESTING 23 | # -------------------------------------------------------------------------------- 24 | x_test_0 = torch.tensor(data.get_test_inputs(), dtype=torch.float).to(device) 25 | y_test_0 = torch.tensor(data.get_test_outputs(), dtype=torch.float).to(device) 26 | 27 | # -------------------------------------------------------------------------------- 28 | # SETTING PROBABILISTIC and UTILS 29 | # -------------------------------------------------------------------------------- 30 | names = {'f_s_SVGD': 'f-SVGD', 31 | 'mixed_f_p_SVGD':'h-SVGD', 32 | 'f_p_SVGD':'fw-SVGD', 33 | 'ssge_WGD': 'ssge-WGD', 34 | 'kde_WGD':'kde-WGD', 35 | 'sge_WGD':'sge-WGD', 36 | 'SGD':'Deep Ensemble', 37 | 'SGLD': 'pSGLD', 38 | 'SVGD':'w-SVGD', 39 | 'ssge_f_WGD':'ssge-fWGD', 40 | 'sge_f_WGD':'sge-fWGD', 41 | 'kde_f_WGD':'kde-fWGD'} 42 | 43 | W = ensemble.particles 44 | samples = [] 45 | optimizer = torch.optim.Adam([W], config.lr, weight_decay=config.weight_decay, 46 | betas=[config.adam_beta1, 0.999]) 47 | #optimizer = torch.optim.SGD([W], config.lr) 48 | prior = torch.distributions.normal.Normal(torch.zeros(ensemble.net.num_params).to(device), 49 | torch.ones(ensemble.net.num_params).to(device) * config.prior_variance) 50 | if config.method == 'f_s_SVGD': 51 | add_prior = False 52 | else: 53 | add_prior = True 54 | 55 | 56 | #w_prior = prior.sample(torch.Size([config.n_particles])) 57 | 58 | 59 | #prior_pred = ensemble.forward(x_test_0[:4],w_prior)[0].reshape(config.n_particles,-1) 60 | 61 | #ssge_k = RBF(sigma = 1) 62 | 63 | #ssge = SpectralSteinEstimator(0.9,None,ssge_k,prior_pred) 64 | 65 | P = Unorm_post(ensemble, prior, config, data.num_train_samples,add_prior) 66 | #log_scale = torch.log2(torch.tensor(data_train.out_shape[0], dtype=torch.float)) 67 | variance_noise = 0.25 68 | noise = torch.distributions.normal.Normal(torch.zeros(data.in_shape[0]).to(device), 69 | torch.ones(data.in_shape[0]).to(device)*variance_noise) 70 | 71 | # -------------------------------------------------------------------------------- 72 | # SVGD ALGORITHM SPECIFICATIONS 73 | # -------------------------------------------------------------------------------- 74 | 75 | method = create_method(config, P, optimizer, device = device) 76 | #K = RBF() 77 | #method = SVGD(P, ssge_k, optimizer,ssge, prior) 78 | 79 | # -------------------------------------------------------------------------------- 80 | # SVGD TRAINING 81 | # -------------------------------------------------------------------------------- 82 | 83 | driving_l = [] 84 | repulsive_l = [] 85 | print('-------------------------'+'Start training'+'-------------------------') 86 | for i in range(config.epochs): 87 | optimizer.zero_grad() 88 | 89 | batch_train = data.next_train_batch(config.batch_size) 90 | batch_test = data.next_test_batch(config.batch_size) 91 | X = data.input_to_torch_tensor(batch_train[0], device, mode='train') 92 | T = data.output_to_torch_tensor(batch_train[1], device, mode='train') 93 | X_t = data.input_to_torch_tensor(batch_test[0], device, mode='train') 94 | T_t = data.output_to_torch_tensor(batch_test[1], device, mode='train') 95 | 96 | 97 | if config.method == 'SGD' or config.method == 'SGLD': 98 | method.step(W, X, T) 99 | elif config.method == 'f_p_SVGD' or config.method == 'mixed_f_p_SVGD' or config.method == 'f_s_SVGD' or config.method == 'f_SGD': 100 | #noise_samples = 
noise.sample(torch.Size([config.batch_size])) 101 | if config.where_repulsive == 'train': 102 | driving,repulsive = method.step(W, X, T,i, None) 103 | elif config.where_repulsive == 'noise': 104 | blurred_train = X+noise.sample((X.shape[0],)) 105 | driving,repulsive = method.step(W,X,T,i,blurred_train) 106 | elif config.where_repulsive == 'test': 107 | driving,repulsive = method.step(W,X,T,i,X_t) 108 | 109 | #driving,repulsive = method.step(W, X, T,i, None) 110 | elif config.method == 'SVGLD': 111 | driving,repulsive,langevin_noise = method.step(W, X, T,i) 112 | 113 | else: 114 | driving,repulsive = method.step(W, X, T,i) 115 | 116 | if hasattr(method, 'ann_schedule'): 117 | writer.add_scalar('train/annealing', method.ann_schedule[i], i) 118 | 119 | if i % 1000 == 0: 120 | train_loss, train_pred = P.log_prob(W, X, T, return_loss=True) 121 | test_loss, test_pred = P.log_prob(W, x_test_0, y_test_0, return_loss=True) 122 | writer.add_scalar('train/train_mse', train_loss, i) 123 | writer.add_scalar('test/test_loss', test_loss, i) 124 | if 'driving' in locals(): 125 | writer.add_scalar('train/driving_force', torch.mean(driving.abs()), i) 126 | writer.add_scalar('train/repulsive_force', torch.mean(repulsive.abs()), i) 127 | writer.add_scalar('train/forces_ratio', torch.mean(repulsive.abs())/torch.mean(driving.abs()), i) 128 | if config.method == 'SVGLD': 129 | writer.add_scalar('train/langevin_noise', torch.mean(langevin_noise.abs()), i) 130 | #writer.add_scalar('train/bandwith', K.h, i) 131 | 132 | pred_tensor = ensemble.forward(torch.tensor(x_test_0, dtype=torch.float))[0] 133 | plot_predictive_distributions(config,writer,i,data, x_test_0.cpu().squeeze(), pred_tensor.cpu().mean(0).squeeze(), 134 | pred_tensor.cpu().std(0).squeeze(), save_fig=True, 135 | name=names[config.method]) 136 | print('Train iter:',i, ' train mse:', train_loss, 'test mse', test_loss, flush = True) 137 | -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import argparse 3 | 4 | def configuration(args = None): 5 | parser = argparse.ArgumentParser(description='SVGD ' + 6 | 'Bayesian neural networks.') 7 | 8 | tgroup = parser.add_argument_group('Training options') 9 | tgroup.add_argument('--epochs', type=int, metavar='N', default=5000, 10 | help='Number of training epochs. '+ 11 | 'Default: %(default)s.') 12 | tgroup.add_argument('--batch_size', type=int, metavar='N', default=128, 13 | help='Training batch size. Default: %(default)s.') 14 | tgroup.add_argument('--lr', type=float, default=1e-2, 15 | help='Learning rate of optimizer. Default: ' + 16 | '%(default)s.') 17 | tgroup.add_argument('--momentum', type=float, default=0.0, 18 | help='Momentum of the optimizer. ' + 19 | 'Default: %(default)s.') 20 | tgroup.add_argument('--adam_beta1', type=float, default=0.9, 21 | help='The "beta1" parameter when using torch.optim.' + 22 | 'Adam as optimizer. Default: %(default)s.') 23 | tgroup.add_argument('--weight_decay', type=float, default=0, 24 | help='Weight decay of the optimizer(s). Default: ' + 25 | '%(default)s.') 26 | tgroup.add_argument('--method', type=str, default='SVGD', 27 | help='Method for optimization, options: SVGD,SGD,SGLD,SVGD_debug. 
Default: ' + 28 | '%(default)s.') 29 | tgroup.add_argument('--noise', action='store_true', default=False, 30 | help='Flag to enable noise injected in the method') 31 | tgroup.add_argument('--dataset', type=str, default='moons', 32 | help='Dataset, options: toy_regression, moons, twod_gaussian, oned_gaussian. Default: ' + 33 | '%(default)s.') 34 | tgroup.add_argument('--optim', type=str, default='Adam', 35 | help='Otimizer, options: Adam,SGD. Default: ' + 36 | '%(default)s.') 37 | tgroup.add_argument('--clip_grad_value', type=float, default=-1, 38 | help='If not "-1", gradients will be clipped using ' + 39 | '"torch.nn.utils.clip_grad_value_". Default: ' + 40 | '%(default)s.') 41 | tgroup.add_argument('--n_particles', type=int, default=5, 42 | help='Number of particles used for the approximation of the gradient flow') 43 | 44 | sgroup = parser.add_argument_group('Network options') 45 | sgroup.add_argument('--num_hidden', type=int, metavar='N', default=1, 46 | help='Number of hidden layer in the (student) ' + 47 | 'network. Default: %(default)s.') 48 | sgroup.add_argument('--size_hidden', type=int, metavar='N', default=10, 49 | help='Number of units in each hidden layer of the ' + 50 | '(student) network. Default: %(default)s.') 51 | sgroup.add_argument('--num_train_samples', type=int, default=20, 52 | help='Number of data training points.') 53 | sgroup.add_argument('--prior_variance', type=float, default=1., 54 | help='Variance of the Gaussian prior. ' + 55 | 'Default: %(default)s.') 56 | sgroup.add_argument('--pred_dist_std', type=float, default=1., 57 | help='The standard deviation of the predictive ' + 58 | 'distribution. Note, this value should be ' + 59 | 'fixed and reasonable for a given dataset.' + 60 | 'Default: %(default)s.') 61 | 62 | mgroup = parser.add_argument_group('Miscellaneous options') 63 | mgroup.add_argument('--use_cuda', action='store_true', 64 | help='Flag to enable GPU usage.') 65 | mgroup.add_argument('--random_seed', type=int, metavar='N', default=42, 66 | help='Random seed. Default: %(default)s.') 67 | mgroup.add_argument('--data_random_seed', type=int, metavar='N', default=42, 68 | help='Data random seed. Default: %(default)s.') 69 | mgroup.add_argument('--dont_show_plot', action='store_false', 70 | help='Dont show the final regression results as plot.' + 71 | 'Note, only applies to 1D regression tasks.') 72 | 73 | dout_dir = './out/'+datetime.now().strftime('%Y-%m-%d')+'/run_' + datetime.now().strftime('%Y-%m-%d_%H-%M-%S') 74 | exp_dir = 'exp_'+datetime.now().strftime('%Y-%m-%d_%H-%M') 75 | 76 | mgroup.add_argument('--out_dir', type=str, default=dout_dir, 77 | help='Where to store the outputs of this simulation.') 78 | mgroup.add_argument('--comment', type=str, default=dout_dir, 79 | help='Comment for the running experiment.') 80 | mgroup.add_argument('--show_plots', action='store_true', 81 | help='Whether plots should be shown.') 82 | 83 | expgroup = parser.add_argument_group('Experiments options') 84 | expgroup.add_argument('--annealing_steps', type=int, default=0, 85 | help='Annealing steps. 
' + 86 | 'Default: %(default)s.') 87 | expgroup.add_argument('--keep_samples', type=float, default=0, 88 | help='Keep samples during training ' + 89 | 'Default: %(default)s.') 90 | expgroup.add_argument('--save_particles', type=float, default=0, 91 | help='Save particles in the end or during training ' + 92 | 'Default: %(default)s.') 93 | expgroup.add_argument('--gamma', type=float, default=1., 94 | help='SVGD scaling forces ' + 95 | 'Default: %(default)s.') 96 | expgroup.add_argument('--ann_sch', type=str, default='linear', 97 | help='Annealing schedule options. Default: ' + 98 | '%(default)s.') 99 | expgroup.add_argument('--logit_soft', type=int, default=0, 100 | help='If 1 functional SVGD on logit, if 0 softmax. ' + 101 | 'Default: %(default)s.') 102 | expgroup.add_argument('--where_repulsive', type=str, default='train', 103 | help='If train functional repulsion on train, if test on test, if noise on noisy training points. ' + 104 | 'Default: %(default)s.') 105 | expgroup.add_argument('--num_train', action='store_true', default=False, 106 | help='Flag to enable GPU usage.') 107 | expgroup.add_argument('--exp_dir', type=str, default=exp_dir, 108 | help='directory for all run same experiment') 109 | 110 | 111 | if args is None: 112 | return parser.parse_args() 113 | else: 114 | return parser.parse_args(args = args) 115 | -------------------------------------------------------------------------------- /data/toy_regression/regression1d_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from matplotlib.pyplot import cm 4 | from warnings import warn 5 | 6 | from data.dataset import Dataset 7 | 8 | class ToyRegression(Dataset): 9 | """An instance of this class represents a simple regression task. 10 | 11 | Attributes: (additional to baseclass) 12 | train_inter(tuple): tuple representing the interval in which sample training datapoints. 13 | num_train(int): number of training datapoints. 14 | test_inter(tuple): tuple representing the interval in which sample testing datapoints. 15 | blob(list): used to create blob regression task, list containing 4 extremes of two intervals in which sample 16 | training datapoints 17 | num_test(int) number of tasting datapoints 18 | val_inter(tuple): 19 | num_val(int): 20 | map_function: function that maps inputs to output 21 | std(float): std of gaussian noise to add to the outputs 22 | rseed(int): random state 23 | 24 | """ 25 | def __init__(self, train_inter=[-10, 10], num_train=20, 26 | test_inter=[-10, 10], blob = None, num_test=80, val_inter=None, 27 | num_val=None, map_function=lambda x : x, std=0, rseed=None): 28 | 29 | super().__init__() 30 | 31 | assert(val_inter is None and num_val is None or \ 32 | val_inter is not None and num_val is not None) 33 | 34 | if rseed is None: 35 | rand = np.random 36 | else: 37 | rand = np.random.seed(rseed) 38 | if blob is None: 39 | train_x = rand.uniform(low=train_inter[0], high=train_inter[1], 40 | size=(num_train, 1)) 41 | else: 42 | train_x = np.vstack([np.random.uniform(low=i[0], high=i[1],size=(int(num_train/len(blob)), 1)) for i in blob]) 43 | #train_x = np.asarray([[-0.8, -0.1, 0.02, 0.2, 0.6, 0.8]]).T 44 | #num_train = train_x.shape[0] 45 | test_x = np.linspace(start=test_inter[0], stop=test_inter[1], 46 | num=num_test).reshape((num_test, 1)) 47 | 48 | train_y = map_function(train_x) 49 | test_y = map_function(test_x) 50 | 51 | # Perturb training outputs. 
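        # Noise is additive i.i.d. Gaussian with standard deviation ``std`` and is
        # applied to the training targets only; the test targets above remain
        # noise-free, so plots against the ground-truth function stay exact.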
52 | if std > 0: 53 | train_eps = np.random.normal(loc=0.0, scale=std, size=(num_train, 1)) 54 | train_y += train_eps 55 | 56 | # Create validation data if requested. 57 | if num_val is not None: 58 | val_x = np.linspace(start=val_inter[0], stop=val_inter[1], 59 | num=num_val).reshape((num_val, 1)) 60 | val_y = map_function(val_x) 61 | 62 | in_data = np.vstack([train_x, test_x, val_x]) 63 | out_data = np.vstack([train_y, test_y, val_y]) 64 | else: 65 | in_data = np.vstack([train_x, test_x]) 66 | out_data = np.vstack([train_y, test_y]) 67 | 68 | # Specify internal data structure. 69 | self._data['classification'] = False 70 | self._data['sequence'] = False 71 | self._data['in_data'] = in_data 72 | self._data['in_shape'] = [1] 73 | self._data['out_data'] = out_data 74 | self._data['out_shape'] = [1] 75 | self._data['train_inds'] = np.arange(num_train) 76 | self._data['test_inds'] = np.arange(num_train, num_train + num_test) 77 | 78 | if num_val is not None: 79 | n_start = num_train + num_test 80 | self._data['val_inds'] = np.arange(n_start, n_start + num_val) 81 | 82 | self._map = map_function 83 | self._train_inter = train_inter 84 | self._test_inter = test_inter 85 | self._val_inter = val_inter 86 | 87 | @property 88 | def train_x_range(self): 89 | """Getter for read-only attribute train_x_range.""" 90 | return self._train_inter 91 | 92 | @property 93 | def test_x_range(self): 94 | """Getter for read-only attribute test_x_range.""" 95 | return self._test_inter 96 | 97 | @property 98 | def val_x_range(self): 99 | """Getter for read-only attribute val_x_range.""" 100 | return self._val_inter 101 | 102 | def _get_function_vals(self, num_samples=100, x_range=None): 103 | """Get real function values for x values in a range that 104 | covers the test and training data. These values can be used to plot the 105 | ground truth function. 106 | 107 | Args: 108 | num_samples: Number of samples to be produced. 109 | x_range: If a specific range should be used to gather function 110 | values. 111 | 112 | Returns: 113 | x, y: Two numpy arrays containing the corresponding x and y values. 114 | """ 115 | if x_range is None: 116 | min_x = min(self._train_inter[0], self._test_inter[0]) 117 | max_x = max(self._train_inter[1], self._test_inter[1]) 118 | if self.num_val_samples > 0: 119 | min_x = min(min_x, self._val_inter[0]) 120 | max_x = max(max_x, self._val_inter[1]) 121 | else: 122 | min_x = x_range[0] 123 | max_x = x_range[1] 124 | 125 | slack_x = 0.05 * (max_x - min_x) 126 | 127 | sample_x = np.linspace(start=min_x-slack_x, stop=max_x+slack_x, 128 | num=num_samples).reshape((num_samples, 1)) 129 | sample_y = self._map(sample_x) 130 | 131 | return sample_x, sample_y 132 | 133 | def plot_dataset(self, show=True): 134 | """Plot the whole dataset. 135 | 136 | Args: 137 | show: Whether the plot should be shown. 138 | """ 139 | 140 | train_x = self.get_train_inputs().squeeze() 141 | train_y = self.get_train_outputs().squeeze() 142 | 143 | test_x = self.get_test_inputs().squeeze() 144 | test_y = self.get_test_outputs().squeeze() 145 | 146 | if self.num_val_samples > 0: 147 | val_x = self.get_val_inputs().squeeze() 148 | val_y = self.get_val_outputs().squeeze() 149 | 150 | sample_x, sample_y = self._get_function_vals() 151 | 152 | # The default matplotlib setting is usually too high for most plots. 
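        # ("Too high" refers to the default number of axis ticks; capping nbins
        # keeps the 1D regression plot readable.)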
153 | plt.locator_params(axis='y', nbins=2) 154 | plt.locator_params(axis='x', nbins=6) 155 | 156 | plt.plot(sample_x, sample_y, color='k', label='f(x)', 157 | linestyle='dashed', linewidth=.5) 158 | plt.scatter(train_x, train_y, color='r', label='Train') 159 | plt.scatter(test_x, test_y, color='b', label='Test', alpha=0.8) 160 | if self.num_val_samples > 0: 161 | plt.scatter(val_x, val_y, color='g', label='Val', alpha=0.5) 162 | plt.legend() 163 | plt.title('1D-Regression Dataset') 164 | plt.xlabel('$x$') 165 | plt.ylabel('$y$') 166 | 167 | if show: 168 | plt.show() 169 | 170 | 171 | def get_identifier(self): 172 | """Returns the name of the dataset.""" 173 | return '1DRegression' 174 | 175 | def _plot_sample(self, fig, inner_grid, num_inner_plots, ind, inputs, 176 | outputs=None, predictions=None): 177 | 178 | raise NotImplementedError('TODO implement') 179 | 180 | if __name__ == '__main__': 181 | pass 182 | 183 | 184 | -------------------------------------------------------------------------------- /utils/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | from torch.distributions.multivariate_normal import MultivariateNormal 5 | 6 | class Unorm_post(): 7 | """ 8 | Implementation of unnormalized posterior for a neural network model. It assume gaussian likelihood with variance 9 | config.pred_dist_std. The prior can be freely specified, the only requirement is a log_prob method to return the 10 | log probability of the particles. 11 | 12 | Args: 13 | ensemble: ensemble instance from MLP.py 14 | prior: prior instance from torch.distributions or custom, .log_prob() method is required 15 | config: Command-line arguments. 16 | n_train: number of training datapoints to rescale the likelihood 17 | 18 | """ 19 | def __init__(self, ensemble, prior, config, n_train,add_prior = True): 20 | self.prior = prior 21 | self.ensemble = ensemble 22 | self.config = config 23 | self.num_train = n_train 24 | self.add_prior = add_prior 25 | 26 | 27 | def log_prob(self, particles, X, T, return_loss=False, return_pred = False, pred_idx = 1): 28 | pred = self.ensemble.forward(X, particles) 29 | 30 | if self.ensemble.net.classification: 31 | if pred_idx == 1: 32 | #loss = -(T.expand_as(pred[1])*F.log_softmax(pred[1],2)).sum((1,2))/X.shape[0] 33 | #loss = (-(T.expand_as(pred[1])*F.log_softmax(pred[1],2))).max(2)[0].sum(1)/X.shape[0] 34 | loss = torch.stack([F.nll_loss(F.log_softmax(p), T.argmax(1)) for p in pred[1]]) 35 | else: 36 | #loss = -(T.expand_as(pred[0]) * torch.log(pred[0]+1e-15)).sum((1, 2)) / X.shape[0] 37 | #loss = -(torch.log(pred[0]+1e-15)[T.expand_as(pred[0]).type(torch.ByteTensor)].reshape(pred[0].shape[:-1])).sum(1)/ X.shape[0] 38 | loss = (-(T.expand_as(pred[1])*torch.log(pred[0]+1e-15))).max(2)[0].sum(1)/X.shape[0] 39 | 40 | #pred = F.softmax(pred[1],2) #I have to do this to allow derivative and to not have nans 41 | else: 42 | #loss = 0.5*torch.mean(F.mse_loss(prebpd[0], T, reduction='none'), 1) 43 | loss = 0.5*torch.mean((T.expand_as(pred[0])-pred[0])**2,1) 44 | 45 | 46 | ll = -loss*self.num_train / self.config.pred_dist_std ** 2 47 | 48 | if particles is None: 49 | particles = self.ensemble.particles 50 | 51 | if self.add_prior: 52 | log_prob = torch.add(self.prior.log_prob(particles).sum(1), ll) 53 | else: 54 | log_prob = ll 55 | # log_prob = ll 56 | if return_loss: 57 | return torch.mean(loss),pred[0] 58 | elif return_pred: 59 | return log_prob,pred #0 softmax, 1 
is logit 60 | else: 61 | return log_prob 62 | 63 | class Unorm_post_hyper(): 64 | """ 65 | Implementation of unnormalized posterior for a neural network model. It assume gaussian likelihood with variance 66 | config.pred_dist_std. The prior can be freely specified, the only requirement is a log_prob method to return the 67 | log probability of the particles. 68 | 69 | Args: 70 | ensemble: ensemble instance from MLP.py 71 | prior: prior instance from torch.distributions or custom, .log_prob() method is required 72 | config: Command-line arguments. 73 | n_train: number of training datapoints to rescale the likelihood 74 | 75 | """ 76 | def __init__(self, ensemble, prior, config, n_train,add_prior = True): 77 | self.priors = prior 78 | self.ensemble = ensemble 79 | self.config = config 80 | self.num_train = n_train 81 | self.add_prior = add_prior 82 | 83 | 84 | def log_prob(self, particles, X, T, return_loss=False, return_pred = False, pred_idx = 1): 85 | pred = self.ensemble.forward(X, particles) 86 | 87 | if self.ensemble.net.classification: 88 | if pred_idx == 1: 89 | #loss = -(T.expand_as(pred[1])*F.log_softmax(pred[1],2)).sum((1,2))/X.shape[0] 90 | #loss = (-(T.expand_as(pred[1])*F.log_softmax(pred[1],2))).max(2)[0].sum(1)/X.shape[0] 91 | loss = torch.stack([F.nll_loss(F.log_softmax(p), T.argmax(1)) for p in pred[1]]) 92 | else: 93 | #loss = -(T.expand_as(pred[0]) * torch.log(pred[0]+1e-15)).sum((1, 2)) / X.shape[0] 94 | #loss = -(torch.log(pred[0]+1e-15)[T.expand_as(pred[0]).type(torch.ByteTensor)].reshape(pred[0].shape[:-1])).sum(1)/ X.shape[0] 95 | loss = (-(T.expand_as(pred[1])*torch.log(pred[0]+1e-15))).max(2)[0].sum(1)/X.shape[0] 96 | 97 | #pred = F.softmax(pred[1],2) #I have to do this to allow derivative and to not have nans 98 | else: 99 | #loss = 0.5*torch.mean(F.mse_loss(prebpd[0], T, reduction='none'), 1) 100 | loss = 0.5*torch.mean((T.expand_as(pred[0])-pred[0])**2,1) 101 | 102 | 103 | ll = -loss*self.num_train / self.config.pred_dist_std ** 2 104 | 105 | if particles is None: 106 | particles = self.ensemble.particles 107 | 108 | log_priors = [] 109 | for ind,p in enumerate(particles): 110 | log_priors.append(self.priors[ind].log_prob(p).sum()) 111 | 112 | log_prob = torch.add(torch.stack(log_priors), ll) 113 | 114 | # log_prob = ll 115 | if return_loss: 116 | return torch.mean(loss),pred[0] 117 | elif return_pred: 118 | return log_prob,pred #0 softmax, 1 is logit 119 | else: 120 | return log_prob 121 | 122 | 123 | class Gaus_mix_multi: 124 | 125 | def __init__(self, mu_1=-3., mu_2=3., sigma_1=1., sigma_2=1.): 126 | self.prior_1 = MultivariateNormal(torch.tensor([7., 0.]), sigma_1 * torch.eye(2)) 127 | self.prior_2 = MultivariateNormal(torch.tensor([-7., 0.]), sigma_2 * torch.eye(2)) 128 | self.prior_3 = MultivariateNormal(torch.tensor([0., 7.]), sigma_1 * torch.eye(2)) 129 | self.prior_4 = MultivariateNormal(torch.tensor([0., -7.]), sigma_1 * torch.eye(2)) 130 | 131 | def log_prob(self, z): 132 | log_prob = torch.log( 133 | 0.25 * torch.exp(self.prior_1.log_prob(z)) + 0.25 * torch.exp(self.prior_2.log_prob(z)) + 0.25 * torch.exp( 134 | self.prior_3.log_prob(z)) + 0.25 * torch.exp(self.prior_4.log_prob(z))) 135 | return log_prob 136 | 137 | def sample(self, n_samples): 138 | s = [] 139 | for i in range(n_samples): 140 | a = np.random.uniform() 141 | if a < 0.25: 142 | s.append(self.prior_1.sample(torch.Size([1])).detach().numpy()[0]) 143 | elif a > 0.25 and a < 0.5: 144 | s.append(self.prior_2.sample(torch.Size([1])).detach().numpy()[0]) 145 | elif a > 0.5 and a < 0.75: 146 
| s.append(self.prior_3.sample(torch.Size([1])).detach().numpy()[0]) 147 | else: 148 | s.append(self.prior_4.sample(torch.Size([1])).detach().numpy()[0]) 149 | return np.stack(s) 150 | 151 | 152 | class Gaus_mix_multi_2: 153 | 154 | def __init__(self, mu_1=-5., mu_2=3., sigma_1=1., sigma_2=1.): 155 | self.prior_1 = MultivariateNormal(mu_1 * torch.ones(2), sigma_1 * torch.eye(2)) 156 | self.prior_2 = MultivariateNormal(mu_2 * torch.ones(2), sigma_2 * torch.eye(2)) 157 | 158 | def log_prob(self, z): 159 | log_prob = torch.log(0.5 * torch.exp(self.prior_1.log_prob(z)) + 0.5 * torch.exp(self.prior_2.log_prob(z))) 160 | return log_prob 161 | 162 | def sample(self, n_samples): 163 | s = [] 164 | for i in range(n_samples): 165 | a = np.random.uniform() 166 | if a < 0.5: 167 | s.append(self.prior_1.sample(torch.Size([1])).detach().numpy()[0]) 168 | else: 169 | s.append(self.prior_2.sample(torch.Size([1])).detach().numpy()[0]) 170 | return np.stack(s) -------------------------------------------------------------------------------- /data/toy_classification/twod_gaussian.py: -------------------------------------------------------------------------------- 1 | from data.dataset import Dataset 2 | from sklearn import datasets 3 | import numpy as np 4 | from matplotlib.colors import ListedColormap 5 | import matplotlib.pyplot as plt 6 | import torch 7 | import numpy as np 8 | import sklearn 9 | from torch.distributions.multivariate_normal import MultivariateNormal 10 | import torch.distributions as D 11 | 12 | 13 | class twod_gaussian(Dataset): 14 | def __init__(self, rseed=1234, use_one_hot=True, n_train = 300, n_test = 100, mu=[], sigma=[]): 15 | """Generate a new dataset. 16 | 17 | The input data x for train and test samples will be drawn iid from the 18 | given Gaussian. 19 | 20 | Args: 21 | rseed (int): If ``None``, the current random state of numpy is used 22 | to generate the data. Otherwise, a new random state with the 23 | given seed is generated. 24 | use_one_hot(bool): If True one hot encoding is applied 25 | n_train (int): Number of points per component of the mixture. 26 | n_test (int): Number of points per component of the mixture. 27 | mu (list): The means of the Gaussians, in an empty list is given the means are equally spaced on a ring of radius 1. 28 | sigma (list): List of scalars for the variance of the diagonal of the covariance matrix 29 | 30 | """ 31 | super().__init__() 32 | 33 | if rseed is None: 34 | rand = np.random 35 | else: 36 | rand = np.random.seed(rseed) 37 | torch.manual_seed(rseed) 38 | 39 | 40 | components = [] 41 | 42 | if len(mu) == 0: 43 | mu = self.circle_points(n=len(sigma)) 44 | 45 | classes = len(sigma) 46 | for i in zip(mu,sigma): 47 | components.append(D.Normal( 48 | torch.tensor(i[0],dtype=torch.float), torch.tensor(i[1],dtype=torch.float))) 49 | 50 | train_x = torch.cat([g.sample(torch.Size([n_train])) for g in components]).numpy() 51 | train_y = np.repeat(range(classes),n_train) 52 | 53 | test_x = torch.cat([g.sample(torch.Size([n_test])) for g in components]).numpy() 54 | test_y = np.repeat(range(classes),n_test) 55 | 56 | 57 | in_data = np.vstack([train_x, test_x]) 58 | out_data = np.vstack([np.expand_dims(train_y, 1), np.expand_dims(test_y, 1)]) 59 | 60 | # Specify internal data structure. 
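        # The fields below follow the convention of the ``Dataset`` base class:
        # ``in_data``/``out_data`` stack train and test samples row-wise (with the
        # labels optionally one-hot encoded), while ``train_inds``/``test_inds``
        # index the corresponding rows of those arrays.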
61 | self._data['classification'] = True 62 | self._data['sequence'] = False 63 | self._data['in_data'] = in_data 64 | self._data['in_shape'] = [2] 65 | self._data['num_classes'] = classes 66 | if use_one_hot: 67 | out_data = self._to_one_hot(out_data) 68 | self._data['out_data'] = out_data 69 | self._data['out_shape'] = [classes] 70 | self._data['train_inds'] = np.arange(train_x.shape[0]) 71 | self._data['test_inds'] = np.arange(train_x.shape[0], train_x.shape[0] + test_x.shape[0]) 72 | 73 | def get_identifier(self): 74 | """Returns the name of the dataset.""" 75 | return '2D_gauss_dataset' 76 | 77 | def circle_points(self,r=5, n=6): 78 | t = np.linspace(np.pi/2, 5/2*np.pi, n,endpoint=False) 79 | x = r * np.cos(t) 80 | y = r * np.sin(t) 81 | 82 | return np.c_[x, y] 83 | 84 | def get_input_mesh(self, x1_range=None, x2_range=None, grid_size=1000): 85 | """Create a 2D grid of input values. 86 | 87 | The default grid returned by this method will also be the default grid 88 | used by the method :meth:`plot_uncertainty_map`. 89 | 90 | Note: 91 | This method is only implemented for 2D datasets. 92 | 93 | Args: 94 | x1_range (tuple, optional): The min and max value for the first 95 | input dimension. If not specified, the range will be 96 | automatically inferred. 97 | 98 | Automatical inference is based on the underlying data (train 99 | and test). The range will be set, such that all data can be 100 | drawn inside. 101 | x2_range (tuple, optional): Same as ``x1_range`` for the second 102 | input dimension. 103 | grid_size (int or tuple): How many input samples per dimension. 104 | If an integer is passed, then the same number grid size will be 105 | used for both dimension. The grid is build by equally spacing 106 | ``grid_size`` inside the ranges ``x1_range`` and ``x2_range``. 107 | 108 | Returns: 109 | (tuple): Tuple containing: 110 | 111 | - **x1_grid** (numpy.ndarray): A 2D array, containing the grid 112 | values of the first dimension. 113 | - **x2_grid** (numpy.ndarray): A 2D array, containing the grid 114 | values of the second dimension. 115 | - **flattended_grid** (numpy.ndarray): A 2D array, containing all 116 | samples from the first dimension in the first column and all 117 | values corresponding to the second dimension in the second column. 118 | This format correspond to the input format as, for instance, 119 | returned by methods such as 120 | :meth:`data.dataset.Dataset.get_train_inputs`. 
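        Example:
            A minimal sketch; the constructor arguments are only illustrative::

                data = twod_gaussian(sigma=[0.5, 0.5], n_train=100, n_test=50)
                X1, X2, grid = data.get_input_mesh(grid_size=50)
                # X1 and X2 have shape (50, 50); grid has shape (2500, 2) and can
                # be fed to a model in the same format as get_train_inputs().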
121 | """ 122 | if self.in_shape[0] != 2: 123 | raise ValueError('This method only applies to 2D datasets.') 124 | 125 | if isinstance(grid_size, int): 126 | grid_size = (grid_size, grid_size) 127 | else: 128 | assert len(grid_size) == 2 129 | 130 | if x1_range is None or x2_range is None: 131 | min_x1 = self._data['in_data'][:, 0].min() 132 | min_x2 = self._data['in_data'][:, 1].min() 133 | max_x1 = self._data['in_data'][:, 0].max() 134 | max_x2 = self._data['in_data'][:, 1].max() 135 | 136 | slack_1 = (max_x1 - min_x1) * 0.02 137 | slack_2 = (max_x2 - min_x2) * 0.02 138 | 139 | if x1_range is None: 140 | x1_range = (min_x1 - slack_1, max_x1 + slack_1) 141 | else: 142 | assert len(x1_range) == 2 143 | 144 | if x2_range is None: 145 | x2_range = (min_x2 - slack_2, max_x2 + slack_2) 146 | else: 147 | assert len(x2_range) == 2 148 | 149 | x1 = np.linspace(start=x1_range[0], stop=x1_range[1], num=grid_size[0]) 150 | x2 = np.linspace(start=x2_range[0], stop=x2_range[1], num=grid_size[1]) 151 | 152 | X1, X2 = np.meshgrid(x1, x2) 153 | X = np.vstack([X1.ravel(), X2.ravel()]).T 154 | 155 | return X1, X2, X 156 | 157 | def _plot_sample(self,writer=None): 158 | colors = ListedColormap(['#FF0000', '#0000FF']) 159 | 160 | # Create plot 161 | fig = plt.figure(figsize=(15, 10)) 162 | ax = fig.add_subplot(111) 163 | 164 | x_train_0 = self.get_train_inputs() 165 | y_train_0 = self.get_train_outputs() 166 | x_test_0 = self.get_test_inputs() 167 | y_test_0 = self.get_test_outputs() 168 | 169 | # define the colormap 170 | cmap = plt.cm.jet 171 | # extract all colors from the .jet map 172 | cmaplist = [cmap(i) for i in range(cmap.N)] 173 | # create the new map 174 | cmap = cmap.from_list('Custom cmap', cmaplist, cmap.N) 175 | 176 | ax.scatter(x_train_0[:, 0], x_train_0[:, 1], alpha=1, marker='o', c=np.argmax(y_train_0, 1), cmap=cmap, 177 | edgecolors='k', s=50, label='Train') 178 | ax.scatter(x_test_0[:, 0], x_test_0[:, 1], alpha=0.6, marker='s', c=np.argmax(y_test_0, 1), cmap=cmap, 179 | edgecolors='k', s=50, label='test') 180 | plt.title('Data', fontsize=30) 181 | plt.legend(loc=2, fontsize=30) 182 | if writer is not None: 183 | writer.add_figure('Data', plt.gcf(), 0, close=True) 184 | -------------------------------------------------------------------------------- /methods/WGD.py: -------------------------------------------------------------------------------- 1 | import torch.autograd as autograd 2 | import torch 3 | 4 | """ 5 | In this file WGD implementations in weights and function space are collected. 
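
Roughly, each particle w_i follows the update

    phi(w_i) = alpha_t * grad_w log p(w_i | D) - grad_w log rho(w_i),

where alpha_t is the annealing factor and grad_w log rho, the score of the current
particle density, is approximated below with a kernel density estimate ('kde'), the
Stein gradient estimator ('sge') or the spectral Stein gradient estimator ('ssge').
The functional variant f_WGD applies the same rule to the network outputs and maps
the resulting update back to weight space through a vector-Jacobian product.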
6 | """ 7 | 8 | 9 | class WGD: 10 | """ 11 | Implementation of WGD in weight space with KDE, SGE, SSGE approximation 12 | 13 | Args: 14 | P: instance of a distribution returning the log_prob, see distributions.py for examples 15 | K: kernel instance, see kernel.py for examples 16 | optimizer: instance of an optimizer SGD,Adam 17 | """ 18 | def __init__(self, P, K, optimizer,config,ann_sch,grad_estim=None,num_train = False, method = 'kde', device = None): 19 | self.P = P 20 | self.K = K 21 | self.optim = optimizer 22 | self.pge = grad_estim 23 | self.gamma = config.gamma 24 | self.ann_schedule=ann_sch 25 | self.num_train = num_train 26 | self.method = method 27 | self.device = device 28 | 29 | 30 | def phi(self, W,X,T,step): 31 | """ 32 | Computes the update of the WGD rule 33 | 34 | Args: 35 | W: particles 36 | X: inputs training batch 37 | T: labels training batch 38 | 39 | Return: 40 | phi: the update to feed the optimizer 41 | driving force: first term in the update rule 42 | repulsive force: second term in the update rule 43 | """ 44 | 45 | if self.num_train: 46 | num_t = self.P.num_train 47 | else: 48 | num_t = 1 49 | 50 | W = W.detach().requires_grad_(True) 51 | 52 | #computing the driving force 53 | log_prob = self.P.log_prob(W,X,T) 54 | score_func = autograd.grad(log_prob.sum(), W)[0] 55 | 56 | if self.method == 'kde': 57 | K_W = self.K(W, W.detach()) 58 | grad_K = autograd.grad(K_W.sum(), W)[0] 59 | grad_density = grad_K/ K_W.sum(1,keepdim = True) 60 | 61 | elif self.method == 'ssge': 62 | grad_density = self.pge.compute_score_gradients(W, W) 63 | 64 | elif self.method == 'sge': 65 | eta = 0.01 66 | K_W = self.K(W, W.detach()) 67 | grad_K = autograd.grad(K_W.sum(), W)[0] 68 | K_ = K_W+eta*torch.eye(K_W.shape[0]).to(self.device) 69 | grad_density = torch.linalg.solve(K_,grad_K) 70 | 71 | phi = ( self.ann_schedule[step]*score_func-grad_density) 72 | 73 | 74 | return phi, self.ann_schedule[step]*score_func, grad_density 75 | 76 | def step(self, W,X,T,step): 77 | """ 78 | Customization of the optimizer step where I am forcing the gradients to be instead the SVGD update rule 79 | 80 | Args: 81 | W: particles 82 | X: input training batch 83 | T: label training batch 84 | Return: 85 | driving force: first term in the update rule 86 | repulsive force: second term in the update rule 87 | """ 88 | self.optim.zero_grad() 89 | update = self.phi(W,X,T,step) 90 | W.grad = -update[0] 91 | self.optim.step() 92 | return update[1], update[2] 93 | 94 | 95 | class f_WGD: 96 | """ 97 | Implementation of WGD in weight space with KDE, SGE, SSGE approximation 98 | 99 | Args: 100 | P: instance of a distribution returning the log_prob, see distributions.py for examples 101 | K: kernel instance, see kernel.py for examples 102 | optimizer: instance of an optimizer SGD,Adam 103 | prior_grad_estim: SSGE estimator for functional prior 104 | config: configurator 105 | ann_sch: list of annealing steps 106 | pred_idx: if 1 the function space inference is on logits if 0 on softmax 107 | num_train: if True the repulsive force is multiplied by the number of datapoints 108 | """ 109 | def __init__(self, P, K, optimizer,config,ann_sch,grad_estim=None,pred_idx = 1,num_train = False, method='kde', device = None): 110 | self.P = P 111 | self.K = K 112 | self.optim = optimizer 113 | self.pge = grad_estim 114 | self.gamma = config.gamma 115 | self.ann_schedule=ann_sch 116 | self.pred_idx = pred_idx 117 | self.num_train = num_train 118 | self.method = method 119 | self.device = device 120 | 121 | 122 | def phi(self, 
W,X,T,step,X_add=None): 123 | """ 124 | Computes the update of the f-SVGD rule as: 125 | 126 | 127 | Args: 128 | W: particles 129 | X: input training batch 130 | T: label training batch 131 | 132 | Return: 133 | phi: the update to feed the optimizer 134 | driving force: first term in the update rule 135 | repulsive force: second term in the update rule 136 | """ 137 | if self.num_train: 138 | num_t = self.P.num_train 139 | else: 140 | num_t = 1 141 | 142 | 143 | W = W.detach().requires_grad_(True) 144 | 145 | ######### Score function ######### 146 | log_prob, pred = self.P.log_prob(W, X, T, return_pred=True,pred_idx=self.pred_idx) 147 | score_func = autograd.grad(log_prob.sum(), pred[self.pred_idx], retain_graph=True)[0] 148 | 149 | # Needed in case we want to compute the repulsion on additional points 150 | if X_add is not None: 151 | pred_add = (self.P.ensemble.forward(X_add,W)[self.pred_idx]).view(W.shape[0],-1) 152 | else: 153 | pred_add = pred[self.pred_idx].view(W.shape[0],-1) #[n_part, classesxB] 154 | 155 | 156 | ######### Repulsive force ######### 157 | pred_k = pred_add 158 | score_func = score_func.view(W.shape[0],-1) 159 | #pred = pred[0].view(W.shape[0],-1) #[n_part, classesxB] 160 | 161 | ######### Gradient functional prior ######### 162 | 163 | pred = pred[self.pred_idx].view(W.shape[0],-1) #[n_particles, classes x Batch] 164 | 165 | 166 | w_prior = self.P.prior.sample(torch.Size([W.shape[0]])) 167 | 168 | prior_pred = self.P.ensemble.forward(X, w_prior)[self.pred_idx].reshape(W.shape[0],-1) # changed index here 169 | 170 | grad_prior = self.pge.compute_score_gradients(pred, prior_pred) # .mean(0)1 171 | 172 | ######### Update rule ######### 173 | driv = score_func + grad_prior 174 | 175 | if self.method == 'kde': 176 | K_f = self.K(pred_k, pred_k.detach()) 177 | grad_K = autograd.grad(K_f.sum(), pred_k)[0] 178 | grad_K = grad_K.view(W.shape[0],-1) 179 | 180 | grad_density = grad_K/ K_f.sum(1,keepdim = True) 181 | 182 | elif self.method == 'ssge': 183 | grad_density = self.pge.compute_score_gradients(pred, pred) 184 | 185 | elif self.method == 'sge': 186 | eta = 0.01 187 | K_f = self.K(pred_k, pred_k.detach()) 188 | grad_K = autograd.grad(K_f.sum(), pred_k)[0] 189 | grad_K = grad_K.view(W.shape[0],-1) 190 | K_ = K_f+eta*torch.eye(K_f.shape[0]).to(self.device) 191 | grad_density = torch.linalg.solve(K_,grad_K) 192 | 193 | f_phi = (self.ann_schedule[step]*driv - grad_density) 194 | #f_phi = driv/W.size(0) 195 | 196 | w_phi = autograd.grad(pred,W,grad_outputs=f_phi,retain_graph=False)[0] 197 | #w_phi = autograd.grad(pred,W,grad_outputs=f_phi,retain_graph=False)[0] 198 | 199 | return w_phi, self.ann_schedule[step]*driv, -grad_density 200 | 201 | def step(self, W,X,T,step,X_add=None): 202 | """ 203 | Customization of the optimizer step where I am forcing the gradient to be the SVGD update rule 204 | 205 | Args: 206 | W: particles 207 | X: input training batch 208 | T: label training batch 209 | Return: 210 | driving force: first term in the update rule 211 | repulsive force: second term in the update rule 212 | """ 213 | self.optim.zero_grad() 214 | update = self.phi(W,X,T,step,X_add) 215 | W.grad = -update[0] 216 | #torch.nn.utils.clip_grad_norm_(W,0.1,2) 217 | self.optim.step() 218 | return update[1], update[2] -------------------------------------------------------------------------------- /models/mnets_classifier_interface.py: -------------------------------------------------------------------------------- 1 | """ 2 | Interface for Classifiers 3 | ------------------------- 4 
| 5 | The original implementation can be found `here `__. 6 | 7 | 8 | A general interface for main networks used in classification tasks. This 9 | abstract base class also provides a collection of static helper functions that 10 | are useful in classification problems. 11 | """ 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | from warnings import warn 17 | 18 | from mnets_mnet_interface import MainNetInterface 19 | 20 | class Classifier(nn.Module, MainNetInterface): 21 | """A general interface for classification networks. 22 | 23 | Attributes: 24 | num_classes: Number of output neurons. 25 | """ 26 | def __init__(self, num_classes, verbose): 27 | """Initialize the network. 28 | 29 | Args: 30 | num_classes: The number of output neurons. 31 | verbose: Allow printing of general information about the generated 32 | network (such as number of weights). 33 | """ 34 | # FIXME find a way using super to handle multiple inheritance. 35 | #super(Classifier, self).__init__() 36 | nn.Module.__init__(self) 37 | MainNetInterface.__init__(self) 38 | 39 | assert(num_classes > 0) 40 | self._num_classes = num_classes 41 | 42 | self._verbose = verbose 43 | 44 | @property 45 | def num_classes(self): 46 | """Getter for read-only attribute num_classes.""" 47 | return self._num_classes 48 | 49 | @staticmethod 50 | def logit_cross_entropy_loss(h, t, reduction='mean'): 51 | """Compute cross-entropy loss for given predictions and targets. 52 | Note, we assume that the argmax of the target vectors results in the 53 | correct label. 54 | 55 | Args: 56 | h: Unscaled outputs from the main network, i.e., activations of the 57 | last hidden layer (unscaled logits). 58 | t: Targets in form os soft labels or 1-hot encodings. 59 | reduction (str): The reduction method to be passed to 60 | :func:`torch.nn.functional.cross_entropy`. 61 | 62 | Returns: 63 | Cross-entropy loss computed on logits h and labels extracted 64 | from target vector t. 65 | """ 66 | assert(t.shape[1] == h.shape[1]) 67 | targets = t.argmax(dim=1, keepdim=False) 68 | return F.cross_entropy(h, targets, reduction=reduction) 69 | 70 | @staticmethod 71 | def knowledge_distillation_loss(logits, target_logits, target_mapping=None, 72 | device=None, T=2.): 73 | """Compute the knowledge distillation loss as proposed by 74 | 75 | Hinton et al., "Distilling the Knowledge in a Neural Network", 76 | NIPS Deep Learning and Representation Learning Workshop, 2015. 77 | http://arxiv.org/abs/1503.02531 78 | 79 | Args: 80 | logits: Unscaled outputs from the main network, i.e., activations of 81 | the last hidden layer (unscaled logits). 82 | target_logits: Target logits, i.e., activations of the last hidden 83 | layer (unscaled logits) from the target model. 84 | Note, we won't detach "target_logits" from the graph. Make sure, 85 | that you do this before calling this method. 86 | target_mapping: In continual learning, it might be that the output 87 | layer size of a model is growing. Thus, it could be that the 88 | model providing the ``target_logits`` has a smaller output size 89 | than the current model providing the ``logits``. Therefore, one 90 | has to provide a mapping, which is a list of indices for 91 | ``logits`` that state which activations in ``logits`` have a 92 | corresponding target in ``target_logits``. 
93 | For instance, if the output layer size just increased by 1 94 | through appending a new output neuron to the current model, the 95 | mapping would simply be: 96 | :code:`target_mapping = list(range(target_logits.shape[1]))`. 97 | device: Current PyTorch device. Only needs to be specified if 98 | "target_mapping" is given. 99 | T: Softmax temperature. 100 | 101 | Returns: 102 | Knowledge Distillation (KD) loss. 103 | """ 104 | assert target_mapping is None or device is not None 105 | targets = F.softmax(target_logits / T, dim=1) 106 | n_classes = logits.shape[1] 107 | n_targets = targets.shape[1] 108 | 109 | if target_mapping is None: 110 | if n_classes != n_targets: 111 | raise ValueError('If sizes of "logits" and "target_logits" ' + 112 | 'differ, "target_mapping" must be specified.') 113 | else: 114 | new_targets = torch.zeros_like(logits).to(device) 115 | new_targets[:, target_mapping] = targets 116 | targets = new_targets 117 | 118 | # Note, I think the multiplication with T^2 here is not necessary. The 119 | # original paper prescribes it, but on a gradient analysis where the 120 | # MSE between tempered softmax targets and predictions is minimized 121 | # (assuming the logits are zero-mean). Here, the gradient should consist 122 | # of two terms where the first is scaled by 1/T and the second can be 123 | # considered as scaled by 1/T^2 if making the same assumption as the 124 | # distillation paper. 125 | # Though, I wouldn't change any of this, since the loss function has 126 | # been used and I don't think it matter for reasonable temperature 127 | # choices. 128 | return -(targets * F.log_softmax(logits / T,dim=1)).sum(dim=1).mean()*\ 129 | T**2 130 | 131 | @staticmethod 132 | def softmax_and_cross_entropy(h, t, reduction_sum=False): 133 | """Compute the cross entropy from logits, allowing smoothed labels 134 | (i.e., this function does not require 1-hot targets). 135 | 136 | Args: 137 | h: Unscaled outputs from the main network, i.e., activations of the 138 | last hidden layer (unscaled logits). 139 | t: Targets in form os soft labels or 1-hot encodings. 140 | 141 | Returns: 142 | Cross-entropy loss computed on logits h and given targets t. 143 | """ 144 | assert(t.shape[1] == h.shape[1]) 145 | 146 | loss = -(t * torch.nn.functional.log_softmax(h, dim=1)).sum(dim=1) 147 | 148 | if reduction_sum: 149 | return loss.sum() 150 | else: 151 | return loss.mean() 152 | 153 | @staticmethod 154 | def accuracy(y, t): 155 | """Computing the accuracy between predictions y and targets t. We 156 | assume that the argmax of t results in labels as described in the 157 | docstring of method "cross_entropy_loss". 158 | 159 | Args: 160 | y: Outputs from the main network. 161 | t: Targets in form of soft labels or 1-hot encodings. 162 | 163 | Returns: 164 | Relative prediction accuracy on the given batch. 165 | """ 166 | assert(t.shape[1] == y.shape[1]) 167 | predictions = y.argmax(dim=1, keepdim=False) 168 | targets = t.argmax(dim=1, keepdim=False) 169 | 170 | return (predictions == targets).float().mean() 171 | 172 | @staticmethod 173 | def num_hyper_weights(dims): 174 | """The number of weights that have to be predicted by a hypernetwork. 175 | 176 | .. deprecated:: 1.0 177 | Please use method 178 | :meth:`mnets.mnet_interface.MainNetInterface.shapes_to_num_weights` 179 | instead. 180 | 181 | Args: 182 | dims: For instance, the attribute :attr:`hyper_shapes`. 183 | 184 | Returns: 185 | (int) 186 | """ 187 | warn('Please use class "mnets.mnet_interface.MainNetInterface.' 
+ 188 | 'shapes_to_num_weights" instead.', DeprecationWarning) 189 | 190 | return np.sum([np.prod(l) for l in dims]) 191 | 192 | if __name__ == '__main__': 193 | pass 194 | 195 | 196 | -------------------------------------------------------------------------------- /data/toy_classification/donuts.py: -------------------------------------------------------------------------------- 1 | from data.dataset import Dataset 2 | from sklearn import datasets 3 | import numpy as np 4 | from matplotlib.colors import ListedColormap 5 | import matplotlib.pyplot as plt 6 | from data.dataset import Dataset 7 | from sklearn import datasets 8 | import numpy as np 9 | from matplotlib.colors import ListedColormap 10 | import matplotlib.pyplot as plt 11 | import math 12 | import random 13 | import torch 14 | import itertools 15 | from sklearn.model_selection import train_test_split 16 | 17 | class Donuts(Dataset): 18 | """An instance of this class shall represent the two moons classification task 19 | """ 20 | 21 | def __init__(self,r_1, r_2, c_outer_1, c_outer_2, c_inner_1=None, c_inner_2=None, rseed=1234, use_one_hot=True, noise=0.1, n_train=100, n_test=80): 22 | 23 | """Generate a new dataset. 24 | 25 | The input data x for train and test samples are drawn from the moons dataset. 26 | Args: 27 | rseed (int): 28 | noise: Noise to inject in the input data 29 | n_train: Number of training samples. 30 | n_test: Number of test samples.. 31 | 32 | """ 33 | super().__init__() 34 | 35 | if rseed is None: 36 | rand = np.random 37 | else: 38 | rand = np.random.RandomState(rseed) 39 | random.seed(rseed) 40 | 41 | donut_1 = self.sample(r_outer=r_1[0], r_inner=r_1[1], x_outer=c_outer_1[0], y_outer=c_outer_1[1], size=n_train) 42 | donut_2 = self.sample(r_outer=r_2[0], r_inner=r_2[1], x_outer=c_outer_2[0], y_outer=c_outer_2[1], size=n_train) 43 | 44 | X_train = torch.tensor(np.vstack([donut_1, donut_2]), dtype=torch.float) 45 | 46 | class_labels_train = np.asarray( 47 | list(itertools.chain.from_iterable(itertools.repeat(x, int(X_train.shape[0] / 2)) for x in range(0, 2)))) 48 | 49 | train_x, test_x, train_y, test_y = train_test_split(X_train, class_labels_train, test_size=0.33, 50 | random_state=42) 51 | in_data = np.vstack([train_x, test_x]) 52 | out_data = np.vstack([np.expand_dims(train_y, 1), np.expand_dims(test_y, 1)]) 53 | 54 | # Specify internal data structure. 55 | # self._r_outer = r_outer 56 | # self._r_inner = r_inner 57 | # self._x_outer = x_outer 58 | # self._y_outer = y_outer 59 | # self._x_inner = x_inner 60 | # self._y_inner = y_inner 61 | self._data['classification'] = True 62 | self._data['sequence'] = False 63 | self._data['in_data'] = in_data 64 | self._data['in_shape'] = [2] 65 | self._data['num_classes'] = 2 66 | if use_one_hot: 67 | out_data = self._to_one_hot(out_data) 68 | self._data['out_data'] = out_data 69 | self._data['out_shape'] = [2] 70 | self._data['train_inds'] = np.arange(train_x.shape[0]) 71 | self._data['test_inds'] = np.arange(train_x.shape[0], train_x.shape[0] + test_x.shape[0]) 72 | 73 | def get_identifier(self): 74 | """Returns the name of the dataset.""" 75 | return 'Donuts_dataset' 76 | 77 | 78 | def sample(self,r_outer, r_inner, x_outer, y_outer, x_inner=None, y_inner=None,size = 1): 79 | """ 80 | Sample uniformly from (x, y) satisfiying: 81 | 82 | x**2 + y**2 <= r_outer**2 83 | 84 | (x-x_inner)**2 + (y-y_inner)**2 > r_inner**2 85 | 86 | Assumes that the inner circle lies inside the outer circle; 87 | i.e., that hypot(x_inner, y_inner) <= r_outer - r_inner. 
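
        Example:
            A rough sketch; ``d`` stands for an already constructed ``Donuts``
            instance::

                pts = d.sample(r_outer=3.0, r_inner=1.0, x_outer=0.0,
                               y_outer=0.0, size=200)
                # pts has shape (200, 2); with the default concentric hole every
                # point lies at a distance between 1.0 and 3.0 from the origin.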
88 | """ 89 | # Sample from a normal annulus with radii r_inner and r_outer. 90 | if x_inner is None: 91 | x_inner = x_outer 92 | y_inner = y_outer 93 | 94 | l = [] 95 | for _ in range(size): 96 | rad = math.sqrt(random.uniform(r_inner ** 2, r_outer ** 2)) 97 | angle = random.uniform(-math.pi, math.pi) 98 | x, y = rad * math.cos(angle) + x_outer, rad * math.sin(angle) + y_outer 99 | 100 | # If we're inside the forbidden hole, reflect. 101 | if math.hypot(x - x_inner, y - y_inner) < r_inner: 102 | x, y = x_inner - x, y_inner - y 103 | l.append([x, y]) 104 | 105 | return np.asarray(l) 106 | 107 | def get_input_mesh(self, x1_range=None, x2_range=None, grid_size=1000): 108 | """Create a 2D grid of input values to make useful plots like 109 | 110 | Note: 111 | This method is only implemented for 2D datasets. 112 | 113 | Args: 114 | x1_range (tuple, optional): The min and max value for the first 115 | input dimension. If not specified, the range will be 116 | automatically inferred. 117 | 118 | Automatical inference is based on the underlying data (train 119 | and test). The range will be set, such that all data can be 120 | drawn inside. 121 | x2_range (tuple, optional): Same as ``x1_range`` for the second 122 | input dimension. 123 | grid_size (int or tuple): How many input samples per dimension. 124 | If an integer is passed, then the same number grid size will be 125 | used for both dimension. The grid is build by equally spacing 126 | ``grid_size`` inside the ranges ``x1_range`` and ``x2_range``. 127 | 128 | Returns: 129 | (tuple): Tuple containing: 130 | 131 | - **x1_grid** (numpy.ndarray): A 2D array, containing the grid 132 | values of the first dimension. 133 | - **x2_grid** (numpy.ndarray): A 2D array, containing the grid 134 | values of the second dimension. 135 | - **flattended_grid** (numpy.ndarray): A 2D array, containing all 136 | samples from the first dimension in the first column and all 137 | values corresponding to the second dimension in the second column. 138 | This format correspond to the input format as, for instance, 139 | returned by methods such as 140 | :meth:`data.dataset.Dataset.get_train_inputs`. 
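        Example:
            A small sketch; ``donuts_data`` and the downstream model are assumed::

                X1, X2, grid = donuts_data.get_input_mesh(grid_size=200)
                # grid has shape (40000, 2); evaluating a model on it and reshaping
                # the predictions to X1.shape yields a map that can be drawn with
                # plt.contourf(X1, X2, ...).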
141 | """ 142 | if self.in_shape[0] != 2: 143 | raise ValueError('This method only applies to 2D datasets.') 144 | 145 | if isinstance(grid_size, int): 146 | grid_size = (grid_size, grid_size) 147 | else: 148 | assert len(grid_size) == 2 149 | 150 | if x1_range is None or x2_range is None: 151 | min_x1 = self._data['in_data'][:, 0].min() 152 | min_x2 = self._data['in_data'][:, 1].min() 153 | max_x1 = self._data['in_data'][:, 0].max() 154 | max_x2 = self._data['in_data'][:, 1].max() 155 | 156 | slack_1 = (max_x1 - min_x1) * 0.02 157 | slack_2 = (max_x2 - min_x2) * 0.02 158 | 159 | if x1_range is None: 160 | x1_range = (min_x1 - slack_1, max_x1 + slack_1) 161 | else: 162 | assert len(x1_range) == 2 163 | 164 | if x2_range is None: 165 | x2_range = (min_x2 - slack_2, max_x2 + slack_2) 166 | else: 167 | assert len(x2_range) == 2 168 | 169 | x1 = np.linspace(start=x1_range[0], stop=x1_range[1], num=grid_size[0]) 170 | x2 = np.linspace(start=x2_range[0], stop=x2_range[1], num=grid_size[1]) 171 | 172 | X1, X2 = np.meshgrid(x1, x2) 173 | X = np.vstack([X1.ravel(), X2.ravel()]).T 174 | 175 | return X1, X2, X 176 | 177 | def _plot_sample(self, fig, inner_grid, num_inner_plots, ind, inputs, 178 | outputs=None, predictions=None): 179 | colors = ListedColormap(['#FF0000', '#0000FF']) 180 | 181 | # Create plot 182 | fig = plt.figure(figsize=(15, 10)) 183 | ax = fig.add_subplot(111) 184 | 185 | x_train_0 = self.get_train_inputs() 186 | y_train_0 = self.get_train_outputs() 187 | x_test_0 = self.get_test_inputs() 188 | y_test_0 = self.get_test_outputs() 189 | 190 | ax.scatter(x_train_0[:, 0], x_train_0[:, 1], alpha=1, marker='o', c=np.argmax(y_train_0, 1), cmap=colors, 191 | edgecolors='k', s=50, label='Train') 192 | ax.scatter(x_test_0[:, 0], x_test_0[:, 1], alpha=0.6, marker='s', c=np.argmax(y_test_0, 1), cmap=colors, 193 | edgecolors='k', s=50, label='test') 194 | plt.title('Data', fontsize=30) 195 | plt.legend(loc=2, fontsize=30) 196 | plt.show() 197 | 198 | 199 | 200 | 201 | -------------------------------------------------------------------------------- /models/net_utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | A collection of helper functions that should capture common functionalities 3 | needed when working with PyTorch. 4 | 5 | The original implementation can be found `here `__. 6 | 7 | """ 8 | import math 9 | import numpy as np 10 | import torch 11 | from torch import nn 12 | import types 13 | 14 | def init_params(weights, bias=None): 15 | """Initialize the weights and biases of a linear or (transpose) conv layer. 16 | 17 | Note, the implementation is based on the method "reset_parameters()", 18 | that defines the original PyTorch initialization for a linear or 19 | convolutional layer, resp. The implementations can be found here: 20 | 21 | https://git.io/fhnxV 22 | 23 | https://git.io/fhnx2 24 | 25 | Args: 26 | weights: The weight tensor to be initialized. 27 | bias (optional): The bias tensor to be initialized. 28 | """ 29 | nn.init.kaiming_uniform_(weights, a=math.sqrt(5)) 30 | if bias is not None: 31 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights) 32 | bound = 1 / math.sqrt(fan_in) 33 | nn.init.uniform_(bias, -bound, bound) 34 | 35 | def get_optimizer(params, lr, momentum=0, weight_decay=0, use_adam=False, 36 | adam_beta1=0.9, use_rmsprop=False, use_adadelta=False, 37 | use_adagrad=False, pgroup_ids=None): 38 | """Create an optimizer instance for the given set of parameters. 
Default 39 | optimizer is :class:`torch.optim.SGD`. 40 | 41 | Args: 42 | params (list): The parameters passed to the optimizer. 43 | lr: Learning rate. 44 | momentum (optional): Momentum (only applicable to 45 | :class:`torch.optim.SGD` and :class:`torch.optim.RMSprop`. 46 | weight_decay (optional): L2 penalty. 47 | use_adam: Use :class:`torch.optim.Adam` optimizer. 48 | adam_beta1: First parameter in the `betas` tuple that is passed to the 49 | optimizer :class:`torch.optim.Adam`: 50 | :code:`betas=(adam_beta1, 0.999)`. 51 | use_rmsprop: Use :class:`torch.optim.RMSprop` optimizer. 52 | use_adadelta: Use :class:`torch.optim.Adadelta` optimizer. 53 | use_adagrad: Use :class:`torch.optim.Adagrad` optimizer. 54 | pgroup_ids (list, optional): If passed, a list of integers of the same 55 | length as params is expected. In this case, each integer states to 56 | which parameter group the corresponding parameter in ``params`` 57 | shall belong. Parameter groups may have different optimizer 58 | settings. Therefore, options like ``lr``, ``momentum``, 59 | ``weight_decay``, ``adam_beta1`` may be lists in this case that have 60 | a length corresponding to the number of parameter groups. 61 | 62 | Returns: 63 | Optimizer instance. 64 | """ 65 | use_sgd = not use_adam and not use_rmsprop and not use_adadelta \ 66 | and not use_adagrad 67 | 68 | if isinstance(params, types.GeneratorType): 69 | params = list(params) 70 | 71 | # Transform list of parameter tensors `params` into list of dictionaries. 72 | if pgroup_ids is None: 73 | pgroup_ids = [0] * len(params) 74 | npgroups = 1 75 | else: 76 | assert len(pgroup_ids) == len(params) 77 | npgroups = max(pgroup_ids) + 1 78 | 79 | plist = params 80 | params = [] 81 | 82 | # Initialize parameter group dictionaries. 83 | for i in range(npgroups): 84 | pdict = {} 85 | pdict['params'] = [] 86 | pdict['lr'] = lr[i] if isinstance(lr, (list, tuple)) else lr 87 | pdict['weight_decay'] = weight_decay[i] \ 88 | if isinstance(weight_decay, (list, tuple)) else weight_decay 89 | if use_adam: 90 | ab1 = adam_beta1[i] if isinstance(adam_beta1, (list, tuple)) \ 91 | else adam_beta1 92 | pdict['betas'] = [ab1, 0.999] 93 | if use_sgd or use_rmsprop: 94 | pdict['momentum'] = momentum[i] \ 95 | if isinstance(momentum, (list, tuple)) else momentum 96 | params.append(pdict) 97 | 98 | # Fill parameter groups. 99 | for pgid, p in zip(pgroup_ids, plist): 100 | params[pgid]['params'].append(p) 101 | 102 | 103 | if use_adam: 104 | optimizer = torch.optim.Adam(params) 105 | elif use_rmsprop: 106 | optimizer = torch.optim.RMSprop(params) 107 | elif use_adadelta: 108 | optimizer = torch.optim.Adadelta(params) 109 | elif use_adagrad: 110 | optimizer = torch.optim.Adagrad(params) 111 | else: 112 | assert use_sgd 113 | optimizer = torch.optim.SGD(params) 114 | 115 | return optimizer 116 | 117 | def lambda_lr_schedule(epoch): 118 | """Multiplicative Factor for Learning Rate Schedule. 119 | 120 | Computes a multiplicative factor for the initial learning rate based 121 | on the current epoch. This method can be used as argument 122 | ``lr_lambda`` of class :class:`torch.optim.lr_scheduler.LambdaLR`. 123 | 124 | The schedule is inspired by the Resnet CIFAR-10 schedule suggested 125 | here https://keras.io/examples/cifar10_resnet/. 126 | 127 | Args: 128 | epoch (int): The number of epochs 129 | 130 | Returns: 131 | lr_scale (float32): learning rate scale 132 | """ 133 | lr_scale = 1. 
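    # Piecewise-constant decay following the Keras ResNet CIFAR-10 schedule
    # referenced above: the base learning rate is kept for the first 80
    # epochs and is then scaled by 1e-1, 1e-2, 1e-3 and finally 0.5e-3 once
    # the epoch count exceeds 80, 120, 160 and 180, respectively.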
134 | if epoch > 180: 135 | lr_scale = 0.5e-3 136 | elif epoch > 160: 137 | lr_scale = 1e-3 138 | elif epoch > 120: 139 | lr_scale = 1e-2 140 | elif epoch > 80: 141 | lr_scale = 1e-1 142 | return lr_scale 143 | 144 | class CutoutTransform(object): 145 | """Randomly mask out one or more patches from an image. 146 | 147 | The cutout transformation as preprocessing step has been proposed by 148 | 149 | DeVries et al., `Improved Regularization of Convolutional Neural \ 150 | Networks with Cutout `__, 2017. 151 | 152 | The original implementation can be found `here `__. 154 | 155 | Args: 156 | n_holes (int): Number of patches to cut out of each image. 157 | length (int): The length (in pixels) of each square patch. 158 | """ 159 | # This code of this class has been copied from (accessed 04/08/2020): 160 | # https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py 161 | # 162 | # NOTE Our copyright and license does not apply for this function. 163 | # We use this code WITHOUT ANY WARRANTIES. 164 | # 165 | # The code is licensed according to 166 | # Educational Community License, Version 2.0 (ECL-2.0) 167 | # https://github.com/uoguelph-mlrg/Cutout/blob/master/LICENSE.md 168 | # 169 | # Copyright 2017 Terrance DeVries, Raeid Saqur 170 | # Licensed under the Educational Community License, Version 2.0 171 | # (the "License"); you may not use this file except in compliance with the 172 | # License. You may obtain a copy of the License at 173 | # 174 | # http://www.osedu.org/licenses /ECL-2.0 175 | # 176 | # Unless required by applicable law or agreed to in writing, software 177 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 178 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 179 | # License for the specific language governing permissions and limitations 180 | # under the License. 181 | def __init__(self, n_holes, length): 182 | self.n_holes = n_holes 183 | self.length = length 184 | 185 | def __call__(self, img): 186 | """Perform cutout to given image. 187 | 188 | Args: 189 | img (Tensor): Tensor image of size ``(C, H, W)``. 190 | 191 | Returns: 192 | (torch.Tensor): Image with ``n_holes`` of dimension 193 | ``length x length`` cut out of it. 194 | """ 195 | h = img.size(1) 196 | w = img.size(2) 197 | 198 | mask = np.ones((h, w), np.float32) 199 | 200 | for n in range(self.n_holes): 201 | y = np.random.randint(h) 202 | x = np.random.randint(w) 203 | 204 | y1 = np.clip(y - self.length // 2, 0, h) 205 | y2 = np.clip(y + self.length // 2, 0, h) 206 | x1 = np.clip(x - self.length // 2, 0, w) 207 | x2 = np.clip(x + self.length // 2, 0, w) 208 | 209 | mask[y1: y2, x1: x2] = 0. 210 | 211 | mask = torch.from_numpy(mask) 212 | mask = mask.expand_as(img) 213 | img = img * mask 214 | 215 | return img 216 | 217 | if __name__ == '__main__': 218 | pass 219 | 220 | 221 | -------------------------------------------------------------------------------- /models/net_utils/misc.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | The original implementation can be found `here `__. 4 | 5 | """ 6 | 7 | 8 | import inspect 9 | import matplotlib 10 | import matplotlib.pyplot as plt 11 | import math 12 | from torch import nn 13 | import torch 14 | from warnings import warn 15 | 16 | def init_params(weights, bias=None): 17 | """Initialize the weights and biases of a linear or (transpose) conv layer. 
18 | 19 | Note, the implementation is based on the method "reset_parameters()", 20 | that defines the original PyTorch initialization for a linear or 21 | convolutional layer, resp. The implementations can be found here: 22 | 23 | https://git.io/fhnxV 24 | 25 | https://git.io/fhnx2 26 | 27 | .. deprecated:: 1.0 28 | Please use function :func:`utils.torch_utils.init_params` instead. 29 | 30 | Args: 31 | weights: The weight tensor to be initialized. 32 | bias (optional): The bias tensor to be initialized. 33 | """ 34 | warn('Function is deprecated. Use "utils.torch_utils.init_params" instead.', 35 | DeprecationWarning) 36 | 37 | nn.init.kaiming_uniform_(weights, a=math.sqrt(5)) 38 | if bias is not None: 39 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights) 40 | bound = 1 / math.sqrt(fan_in) 41 | nn.init.uniform_(bias, -bound, bound) 42 | 43 | def str_to_ints(str_arg): 44 | """Helper function to convert a string which is a list of comma separated 45 | integers into an actual list of integers. 46 | 47 | Args: 48 | str_arg: String containing list of comma-separated ints. For convenience 49 | reasons, we allow the user to also pass single integers that a put 50 | into a list of length 1 by this function. 51 | 52 | Returns: 53 | (list): List of integers. 54 | """ 55 | if isinstance(str_arg, int): 56 | return [str_arg] 57 | 58 | if len(str_arg) > 0: 59 | return [int(s) for s in str_arg.split(',')] 60 | else: 61 | return [] 62 | 63 | def str_to_floats(str_arg): 64 | """Helper function to convert a string which is a list of comma separated 65 | floats into an actual list of floats. 66 | 67 | Args: 68 | str_arg: String containing list of comma-separated floats. For 69 | convenience reasons, we allow the user to also pass single float 70 | that a put into a list of length 1 by this function. 71 | 72 | Returns: 73 | (list): List of floats. 74 | """ 75 | if isinstance(str_arg, float): 76 | return [str_arg] 77 | 78 | if len(str_arg) > 0: 79 | return [float(s) for s in str_arg.split(',')] 80 | else: 81 | return [] 82 | 83 | def list_to_str(list_arg, delim=' '): 84 | """Convert a list of numbers into a string. 85 | 86 | Args: 87 | list_arg: List of numbers. 88 | delim (optional): Delimiter between numbers. 89 | 90 | Returns: 91 | (str): List converted to string. 92 | """ 93 | ret = '' 94 | for i, e in enumerate(list_arg): 95 | if i > 0: 96 | ret += delim 97 | ret += str(e) 98 | return ret 99 | 100 | def str_to_act(act_str): 101 | """Convert the name of an activation function into the actual PyTorch 102 | activation function. 103 | 104 | Args: 105 | act_str: Name of activation function (as defined by command-line 106 | arguments). 107 | 108 | Returns: 109 | Torch activation function instance or ``None``, if ``linear`` is given. 110 | """ 111 | if act_str == 'linear': 112 | act = None 113 | elif act_str == 'sigmoid': 114 | act = torch.nn.Sigmoid() 115 | elif act_str == 'relu': 116 | act = torch.nn.ReLU() 117 | elif act_str == 'elu': 118 | act = torch.nn.ELU() 119 | elif act_str == 'tanh': 120 | act = torch.nn.Tanh() 121 | else: 122 | raise Exception('Activation function %s unknown.' % act_str) 123 | return act 124 | 125 | def configure_matplotlib_params(fig_size = [6.4, 4.8], two_axes=True, 126 | font_size=8, usetex=False): 127 | """Helper function to configure default matplotlib parameters. 128 | 129 | Args: 130 | fig_size: Figure size (width, height) in inches. 131 | usetex (bool): Whether ``text.usetex`` should be set (leads to an 132 | error on systems that don't have latex installed). 
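    Example:
        An illustrative call (the values below are arbitrary, not defaults):

        .. code-block:: python

            configure_matplotlib_params(fig_size=[3.2, 2.4], font_size=8,
                                        usetex=False)
            # All figures created afterwards inherit these rcParams.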
133 | """ 134 | params = { 135 | 'axes.labelsize': font_size, 136 | 'font.size': font_size, 137 | 'font.sans-serif': ['Arial'], 138 | 'text.usetex': usetex, 139 | 'text.latex.preamble': [r'\usepackage[scaled]{helvet}', 140 | r'\usepackage{sfmath}'], 141 | 'font.family': 'sans-serif', 142 | 'legend.fontsize': font_size, 143 | 'xtick.labelsize': font_size, 144 | 'ytick.labelsize': font_size, 145 | 'axes.titlesize': font_size, 146 | 'axes.spines.right' : not two_axes, 147 | 'axes.spines.top' : not two_axes, 148 | 'figure.figsize': fig_size, 149 | 'legend.handlelength': 0.5 150 | } 151 | 152 | matplotlib.rcParams.update(params) 153 | 154 | def get_colorbrewer2_colors(family = 'Set2'): 155 | """Helper function that returns a list of color combinations 156 | extracted from colorbrewer2.org. 157 | 158 | Args: 159 | (list): the color family from colorbrewer2.org to use. 160 | """ 161 | # https://colorbrewer2.org/#type=qualitative&scheme=Set1&n=7 162 | if family == 'Set2': 163 | return [ 164 | '#e41a1c', 165 | '#377eb8', 166 | '#4daf4a', 167 | '#984ea3', 168 | '#ff7f00', 169 | '#ffff33', 170 | '#a65628', 171 | '#b3de69' 172 | ] 173 | # https://colorbrewer2.org/#type=qualitative&scheme=Set3&n=8 174 | elif family == 'Set3': 175 | return [ 176 | '#8dd3c7', 177 | '#ffffb3', 178 | '#bebada', 179 | '#fb8072', 180 | '#80b1d3', 181 | '#fdb462', 182 | '#b3de69', 183 | '#fccde5' 184 | ] 185 | # https://colorbrewer2.org/#type=qualitative&scheme=Dark2&n=8 186 | elif family == 'Dark2': 187 | return [ 188 | '#1b9e77', 189 | '#d95f02', 190 | '#7570b3', 191 | '#e7298a', 192 | '#66a61e', 193 | '#e6ab02', 194 | '#a6761d' 195 | ] 196 | # https://colorbrewer2.org/#type=qualitative&scheme=Pastel1&n=8 197 | elif family == 'Pastel': 198 | return [ 199 | '#fbb4ae', 200 | '#b3cde3', 201 | '#ccebc5', 202 | '#decbe4', 203 | '#fed9a6', 204 | '#ffffcc', 205 | '#e5d8bd' 206 | ] 207 | else: 208 | raise ValueError() 209 | 210 | def repair_canvas_and_show_fig(fig, close=True): 211 | """If writing a figure to tensorboard via "add_figure" it might change the 212 | canvas, such that our backend doesn't allow to show the figure anymore. 213 | This method will generate a new canvas and replace the old one of the 214 | given figure. 215 | 216 | Args: 217 | fig: The figure to be shown. 218 | close: Whether the figure should be closed after it has been shown. 219 | """ 220 | tmp_fig = plt.figure() 221 | tmp_manager = tmp_fig.canvas.manager 222 | tmp_manager.canvas.figure = fig 223 | fig.set_canvas(tmp_manager.canvas) 224 | plt.close(tmp_fig.number) 225 | plt.figure(fig.number) 226 | plt.show() 227 | if close: 228 | plt.close(fig.number) 229 | 230 | def get_default_args(func): 231 | """Get the default values of all keyword arguments for a given function. 232 | 233 | Args: 234 | func: A function handle. 235 | 236 | Returns: 237 | (dict): Dictionary with keyword argument names as keys and their 238 | default value as values. 239 | """ 240 | # The code from this function has been copied from (accessed: 02/28/2020): 241 | # https://stackoverflow.com/questions/12627118/get-a-function-arguments-default-value 242 | # 243 | # NOTE Our copyright and license does not apply for this function. 244 | # We use this code WITHOUT ANY WARRANTIES. 
245 | # 246 | # Instead, the code in this method is licensed under CC BY-SA 3.0: 247 | # https://creativecommons.org/licenses/by-sa/3.0/ 248 | # 249 | # The code stems from an answer by user "mgilson": 250 | # https://stackoverflow.com/users/748858/mgilson 251 | signature = inspect.signature(func) 252 | 253 | return { 254 | k: v.default for k, v in signature.parameters.items() \ 255 | if v.default is not inspect.Parameter.empty 256 | } 257 | 258 | if __name__ == '__main__': 259 | pass 260 | -------------------------------------------------------------------------------- /data/fashion_mnist/fashion_data.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | from torchvision.datasets import FashionMNIST 4 | 5 | from data.cifar.cifar10_data import CIFAR10Data 6 | from data.dataset import Dataset 7 | from data.mnist.mnist_data import MNISTData 8 | 9 | class FashionMNISTData(Dataset): 10 | """An instance of the class shall represent the Fashion-MNIST dataset. 11 | 12 | The original implementation can be found `here `__. 13 | 14 | Note: 15 | By default, input samples are provided in a range of ``[0, 1]``. 16 | 17 | Args: 18 | data_path (str): Where should the dataset be read from? If not existing, 19 | the dataset will be downloaded into this folder. 20 | use_one_hot (bool): Whether the class labels should be 21 | represented in a one-hot encoding. 22 | validation_size (int): The number of validation samples. Validation 23 | samples will be taking from the training set (the first :math:`n` 24 | samples). 25 | use_torch_augmentation (bool): Apply data augmentation to inputs when 26 | calling method :meth:`data.dataset.Dataset.input_to_torch_tensor`. 27 | 28 | The augmentation will be identical to the one provided by class 29 | :class:`data.mnist_data.MNISTData`, **except** that during training 30 | also random horizontal flips are applied. 31 | 32 | Note: 33 | If activated, the statistics of test samples are changed as 34 | a normalization is applied. 35 | """ 36 | def __init__(self, data_path, use_one_hot=False, validation_size=0, 37 | use_torch_augmentation=False): 38 | super().__init__() 39 | 40 | fmnist_train = FashionMNIST(data_path, train=True, download=True) 41 | fmnist_test = FashionMNIST(data_path, train=False, download=True) 42 | assert np.all(np.equal(fmnist_train.data.shape, [60000, 28, 28])) 43 | assert np.all(np.equal(fmnist_test.data.shape, [10000, 28, 28])) 44 | 45 | train_inputs = fmnist_train.data.numpy().reshape(60000, -1) 46 | test_inputs = fmnist_test.data.numpy().reshape(10000, -1) 47 | train_labels = fmnist_train.targets.numpy().reshape(60000, 1) 48 | test_labels = fmnist_test.targets.numpy().reshape(10000, 1) 49 | 50 | images = np.concatenate([train_inputs, test_inputs], axis=0) 51 | labels = np.concatenate([train_labels, test_labels], axis=0) 52 | 53 | # Scale images into a range between 0 and 1. Such that it is identical 54 | # to the default MNIST scale in `data.dataset.mnist_data`. 
55 | images = images / 255 56 | 57 | val_inds = None 58 | train_inds = np.arange(train_labels.size) 59 | test_inds = np.arange(train_labels.size, 60 | train_labels.size + test_labels.size) 61 | 62 | if validation_size > 0: 63 | if validation_size >= train_inds.size: 64 | raise ValueError('Validation set must contain less than %d ' \ 65 | % (train_inds.size) + 'samples!') 66 | 67 | val_inds = np.arange(validation_size) 68 | train_inds = np.arange(validation_size, train_inds.size) 69 | 70 | # Bring everything into the internal structure of the Dataset class. 71 | self._data['classification'] = True 72 | self._data['sequence'] = False 73 | self._data['num_classes'] = 10 74 | self._data['is_one_hot'] = use_one_hot 75 | self._data['in_data'] = images 76 | self._data['in_shape'] = [28, 28, 1] 77 | self._data['out_shape'] = [10 if use_one_hot else 1] 78 | self._data['val_inds'] = val_inds 79 | self._data['train_inds'] = train_inds 80 | self._data['test_inds'] = test_inds 81 | 82 | if use_one_hot: 83 | labels = self._to_one_hot(labels) 84 | 85 | self._data['out_data'] = labels 86 | 87 | # Information specific to this dataset. 88 | assert np.all([fmnist_train.classes[i] == c for i, c in \ 89 | enumerate(fmnist_test.classes)]) 90 | self._data['fmnist'] = dict() 91 | self._data['fmnist']['classes'] = fmnist_train.classes 92 | 93 | # Initialize PyTorch data augmentation. 94 | self._augment_inputs = False 95 | if use_torch_augmentation: 96 | self._augment_inputs = True 97 | self._train_transform, self._test_transform = \ 98 | MNISTData.torch_input_transforms(use_random_hflips=True) 99 | 100 | def get_identifier(self): 101 | """Returns the name of the dataset.""" 102 | return 'Fashion-MNIST' 103 | 104 | def input_to_torch_tensor(self, x, device, mode='inference', 105 | force_no_preprocessing=False, sample_ids=None): 106 | """This method can be used to map the internal numpy arrays to PyTorch 107 | tensors. 108 | 109 | Note, this method has been overwritten from the base class. 110 | 111 | If enabled via constructor option ``use_torch_augmentation``, input 112 | images are preprocessed. 113 | Preprocessing involves normalization and (for training mode) random 114 | perturbations. 115 | 116 | Args: 117 | (....): See docstring of method 118 | :meth:`data.dataset.Dataset.input_to_torch_tensor`. 119 | 120 | Returns: 121 | (torch.Tensor): The given input ``x`` as PyTorch tensor. 122 | """ 123 | # FIXME Method is identical to the one used by the MNIST dataset. 124 | if self._augment_inputs and not force_no_preprocessing: 125 | if mode == 'inference': 126 | transform = self._test_transform 127 | elif mode == 'train': 128 | transform = self._train_transform 129 | else: 130 | raise ValueError('"%s" not a valid value for argument "mode".' 131 | % mode) 132 | 133 | return CIFAR10Data.torch_augment_images(x, device, transform, 134 | img_shape=self.in_shape) 135 | 136 | else: 137 | return Dataset.input_to_torch_tensor(self, x, device, 138 | mode=mode, force_no_preprocessing=force_no_preprocessing, 139 | sample_ids=sample_ids) 140 | 141 | def _plot_sample(self, fig, inner_grid, num_inner_plots, ind, inputs, 142 | outputs=None, predictions=None): 143 | """Implementation of abstract method 144 | :meth:`data.dataset.Dataset._plot_sample`. 
145 | """ 146 | ax = plt.Subplot(fig, inner_grid[0]) 147 | 148 | if outputs is None: 149 | ax.set_title("Fashion-MNIST Sample") 150 | else: 151 | assert(np.size(outputs) == 1) 152 | label = np.asscalar(outputs) 153 | label_name = self._data['fmnist']['classes'][label] 154 | 155 | if predictions is None: 156 | ax.set_title('Sample with label:\n%s (%d)' % \ 157 | (label_name, label)) 158 | else: 159 | if np.size(predictions) == self.num_classes: 160 | pred_label = np.argmax(predictions) 161 | else: 162 | pred_label = np.asscalar(predictions) 163 | pred_label_name = self._data['fmnist']['classes'][pred_label] 164 | 165 | ax.set_title('Label: %s (%d)\n' % (label_name, label) + \ 166 | 'Prediction: %s (%d)' % (pred_label_name, 167 | pred_label)) 168 | 169 | #plt.subplots_adjust(wspace=0.5, hspace=0.4) 170 | 171 | ax.set_axis_off() 172 | ax.imshow(np.squeeze(np.reshape(inputs, self.in_shape))) 173 | fig.add_subplot(ax) 174 | 175 | if num_inner_plots == 2: 176 | ax = plt.Subplot(fig, inner_grid[1]) 177 | ax.set_title('Predictions') 178 | bars = ax.bar(range(self.num_classes), np.squeeze(predictions)) 179 | ax.set_xticks(range(self.num_classes)) 180 | if outputs is not None: 181 | bars[int(label)].set_color('r') 182 | fig.add_subplot(ax) 183 | 184 | def _plot_config(self, inputs, outputs=None, predictions=None): 185 | """Re-Implementation of method 186 | :meth:`data.dataset.Dataset._plot_config`. 187 | 188 | This method has been overriden to ensure, that there are 2 subplots, 189 | in case the predictions are given. 190 | """ 191 | # FIXME code copied from MNISTData. 192 | plot_configs = super()._plot_config(inputs, outputs=outputs, 193 | predictions=predictions) 194 | 195 | if predictions is not None and \ 196 | np.shape(predictions)[1] == self.num_classes: 197 | plot_configs['outer_hspace'] = 0.6 198 | plot_configs['inner_hspace'] = 0.4 199 | plot_configs['num_inner_rows'] = 2 200 | #plot_configs['num_inner_cols'] = 1 201 | plot_configs['num_inner_plots'] = 2 202 | 203 | return plot_configs 204 | 205 | if __name__ == '__main__': 206 | pass 207 | 208 | 209 | -------------------------------------------------------------------------------- /training/training_1dreg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from utils.distributions import Unorm_post 4 | from utils.utils import plot_predictive_distributions 5 | 6 | 7 | def train(data, ensemble, device, config,writer): 8 | """Train the particles using a specific ensemble. 9 | 10 | Args: 11 | data: A DATASET instance. 12 | mnet: The model of the main network. 13 | device: Torch device (cpu or gpu). 14 | config: The command line arguments. 15 | writer: The tensorboard summary writer. 
16 | """ 17 | 18 | # particles.train()grid_mesh[1].shape 19 | 20 | #W = ensemble.particles.clone() 21 | 22 | W = ensemble.particles 23 | samples = [] 24 | 25 | optimizer = torch.optim.Adam([W], config.lr, weight_decay=config.weight_decay, 26 | betas=[config.adam_beta1, 0.999]) 27 | # optimizer = torch.optim.SGD([W], config.lr) 28 | 29 | K_p = covariance_K() 30 | K = RBF_2() 31 | 32 | prior = torch.distributions.normal.Normal(torch.zeros(ensemble.net.num_params).to(device), 33 | torch.ones(ensemble.net.num_params).to(device) * config.prior_variance) 34 | 35 | P = Unorm_post(ensemble, prior, config, data.num_train_samples) 36 | 37 | if config.method == 'SGLD': 38 | method = SGLD(P, K, optimizer) 39 | elif config.method == 'SGD': 40 | method = SGD(P,K,optimizer) 41 | elif config.method == 'SVGD_annealing': 42 | method = SVGD_annealing(P, K, optimizer,config) 43 | elif config.method == 'SVGD_debug': 44 | method = SVGD_debug(P,K,optimizer) 45 | elif config.method == 'SVGD': 46 | method = SVGD(P,K,optimizer) 47 | elif config.method == 'f_SVGD': 48 | method = functional_SVGD(P,K,optimizer) 49 | elif config.method == 'r_SGD': 50 | method = repulsive_SGD(P,K,optimizer) 51 | elif config.method == 'f_SGD': 52 | method = functional_SGD(P,K,optimizer) 53 | elif config.method == 'f_p_SVGD': 54 | method = f_p_SVGD(P,K,optimizer) 55 | elif config.method == 'mixed_f_p_SVGD': 56 | method = mixed_f_p_SVGD(P,K,K_p,optimizer) 57 | elif config.method == 'log_p_SVGD': 58 | method = log_p_SVGD(P,K,optimizer) 59 | elif config.method == 'log_p_SGD': 60 | method = log_p_SGD(P,K,optimizer) 61 | elif config.method == 'fisher_SVGD': 62 | method = fisher_SVGD(P,K,optimizer) 63 | 64 | driving_l = [] 65 | repulsive_l = [] 66 | 67 | for i in range(config.epochs): 68 | 69 | optimizer.zero_grad() 70 | 71 | batch = data.next_train_batch(config.batch_size) 72 | X = data.input_to_torch_tensor(batch[0], device, mode='train') 73 | T = data.output_to_torch_tensor(batch[1], device, mode='train') 74 | x_test_0 = torch.tensor(data.get_test_inputs(), dtype=torch.float) 75 | y_test_0 = torch.tensor(data.get_test_outputs(), dtype=torch.float) 76 | 77 | # if config.clip_grad_value != -1: 78 | # torch.nn.utils.clip_grad_value_(optimizer.param_groups[0]['params'], 79 | # config.clip_grad_value) 80 | # elif config.clip_grad_norm != -1: 81 | # torch.nn.utils.clip_grad_norm_(optimizer.param_groups[0]['params'], 82 | # config.clip_grad_norm) 83 | if config.method == 'SVGD_annealing': 84 | driving,repulsive = method.step(W, X, T,i) 85 | elif config.method == 'SGD': 86 | method.step(W, X, T) 87 | driving = 0 88 | repulsive = 0 89 | elif config.method == 'f_p_SVGD' or config.method == 'mixed_f_p_SVGD': 90 | #grid_mesh_ood[torch.randint(0, grid_mesh_ood.shape[0], (200,))] 91 | #method.step(W,X,T,grid_mesh_ood[torch.randint(0,grid_mesh_ood.shape[0],(200,))]) 92 | #driving,repulsive = method.step(W,X,T,ood_donuts) 93 | driving,repulsive = method.step(W,X,T) 94 | #method.step(W, X, T, None) 95 | 96 | else: 97 | driving,repulsive = method.step(W, X, T) 98 | 99 | driving_l.append(torch.mean(driving.abs())) 100 | repulsive_l.append(torch.mean(repulsive.abs())) 101 | 102 | if i % 10 == 0: 103 | train_loss, train_pred = P.log_prob(W, X, T, return_loss=True) 104 | test_loss, test_pred = P.log_prob(W, x_test_0, y_test_0, return_loss=True) 105 | writer.add_scalar('train/train_mse', train_loss, i) 106 | writer.add_scalar('test/test_loss', test_loss, i) 107 | writer.add_scalar('train/driving_force', torch.mean(driving.abs()), i) 108 | 
writer.add_scalar('train/repulsive_force', torch.mean(repulsive.abs()), i) 109 | # writer.add_scalar('train/bandwith', K.h, i) 110 | 111 | # writer.add_scalar('train/task_%d/loss_nll' % task_id, loss_nll, i) 112 | # writer.add_scalar('train/task_%d/log_det_j' % task_id, torch.mean(log),i) 113 | # writer.add_scalar('train/task_%d/loss' % task_id, loss, i) 114 | # print('Train iter:', i, train_loss) 115 | # print('Test iter:', i, test_loss) 116 | 117 | if ensemble.net.classification: 118 | Y = torch.mean(train_pred, 0) 119 | Y_t = torch.mean(test_pred, 0) 120 | train_accuracy = (torch.argmax(Y, 1) == torch.argmax(T, 1)).sum().item() / Y.shape[0] * 100 121 | test_accuracy = (torch.argmax(Y_t, 1) == torch.argmax(y_test_0, 1)).sum().item() / Y_t.shape[0] * 100 122 | writer.add_scalar('train/accuracy', train_accuracy, i) 123 | writer.add_scalar('test/accuracy', test_accuracy, i) 124 | print('Train iter:',i, ' train acc:', train_accuracy, 'test_acc', test_accuracy) 125 | # else: 126 | # print('Train iter:', i, train_loss) 127 | # print('Test iter:', i, test_loss) 128 | 129 | 130 | 131 | # Plot distribution of mean and log-variance values. 132 | # mean_outputs = torch.cat([d.clone().view(-1) for d in flow._w_0_mu]) 133 | # logvar_outputs = torch.cat([d.clone().view(-1) for d in flow._w_0_logvar]) 134 | # writer.add_histogram('train/task_%d/input_flow_mean' % task_id, 135 | # mean_outputs, i) 136 | # writer.add_histogram('train/task_%d/input_flow_logvar' % task_id, 137 | # logvar_outputs, i) 138 | if i % 50 == 0: 139 | if ensemble.net.classification and data.in_shape[0]==2: 140 | pred_tensor = ensemble.forward(torch.tensor(grid_mesh[2], dtype=torch.float))[0] 141 | #pred_tensor = ensemble.forward(torch.tensor(np.expand_dims(grid_mesh[2].sum(1), 1), dtype=torch.float))[0] 142 | average_prob = pred_tensor.mean(0) 143 | decision_b = torch.argmax(average_prob,1) 144 | entropies = -torch.sum(torch.log(average_prob + 1e-20) * average_prob, 1) 145 | #contour_plot(grid_mesh,entropies,config,i,writer = writer,data = None, title = 'Entropy '+ config.method) 146 | #contour_plot(grid_mesh,average_prob[:,0],config,i,writer = writer, data = None, title = 'Softmax'+ config.method) 147 | #contour_plot(grid_mesh,decision_b,config,i,writer = writer, data = data, title = 'Decision Boundary'+ config.method) 148 | 149 | #ood diversity 150 | 151 | pred_ood = torch.argmax((ensemble.forward(torch.tensor(grid_mesh_ood, dtype=torch.float))[0]),2) 152 | 153 | diversity = torch.mean(torch.min(pred_ood.sum(0), torch.tensor(pred_ood.shape[0])-pred_ood.sum(0) )/(pred_ood.shape[0]/2)) 154 | writer.add_scalar('test/diversity', diversity, i) 155 | 156 | 157 | elif config.dataset == 'toy_reg': 158 | pred_tensor = ensemble.forward(torch.tensor(x_test_0, dtype=torch.float)) 159 | 160 | plot_predictive_distributions(config,writer,i,data, x_test_0.squeeze(), pred_tensor.mean(0).squeeze(), 161 | pred_tensor.std(0).squeeze(), save_fig=False, publication_style=False, 162 | name=config.method) 163 | # correlation matrix. 
164 | 165 | #dist= pairwise_distances(W.detach().numpy()) 166 | #plt.rcParams['figure.figsize'] = [10, 10] 167 | #plt.matshow(dist) 168 | #plt.colorbar() 169 | #writer.add_figure('distance_particles', plt.gcf(), 170 | # i, close=not config.show_plots) 171 | #plt.close() 172 | #embedded_particles = TSNE(n_components=2).fit_transform(W.detach().numpy()) 173 | #plt.figure(figsize=(10, 10)) 174 | #plt.ylim(-500, 500) 175 | #plt.xlim(-500, 500) 176 | #plt.scatter(embedded_particles[:, 0], embedded_particles[:, 1], s = 100, alpha=0.8) 177 | #writer.add_figure('2D_embedding', plt.gcf(), 178 | # i, close=not config.show_plots) 179 | 180 | if config.keep_samples != 0 and i % config.keep_samples == 0: 181 | samples.append(W.detach().clone()) 182 | samples_models = torch.cat(samples) 183 | # pred_tensor_samples = ensemble.forward(x_test_0, samples_models) 184 | # plot_predictive_distributions(config,writer,i,data, [x_test_0.squeeze()], [pred_tensor_samples.mean(0).squeeze()], 185 | # [pred_tensor_samples.std(0).squeeze()], save_fig=False, publication_style=False, 186 | # name=config.method+'smp') 187 | 188 | return driving_l, repulsive_l -------------------------------------------------------------------------------- /data/mnist/split_mnist.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from warnings import warn 3 | 4 | from data.mnist.mnist_data import MNISTData 5 | 6 | def _transform_split_outputs(data, outputs): 7 | """Actual implementation of method ``transform_outputs`` for split dataset 8 | handlers. 9 | 10 | Args: 11 | data: Data handler. 12 | outputs (numpy.ndarray): See docstring of method 13 | :meth:`data.special.split_mnist.SplitMNIST.transform_outputs` 14 | 15 | Returns: 16 | (numpy.ndarray) 17 | """ 18 | if not data._full_out_dim: 19 | # TODO implement reverse direction as well. 20 | raise NotImplementedError('This method is currently only ' + 21 | 'implemented if constructor argument "full_out_dim" was set.') 22 | 23 | labels = data._labels 24 | if data.is_one_hot: 25 | assert (outputs.shape[1] == data._data['num_classes']) 26 | mask = np.zeros(data._data['num_classes'], dtype=np.bool) 27 | mask[labels] = True 28 | 29 | return outputs[:, mask] 30 | else: 31 | assert (outputs.shape[1] == 1) 32 | ret = outputs.copy() 33 | for i, l in enumerate(labels): 34 | ret[ret == l] = i 35 | return ret 36 | 37 | 38 | def get_split_mnist_handlers(data_path, use_one_hot=True, validation_size=0, 39 | use_torch_augmentation=False, 40 | num_classes_per_task=2, num_tasks=None): 41 | """This function instantiates 5 objects of the class :class:`SplitMNIST` 42 | which will contain a disjoint set of labels. 43 | 44 | The SplitMNIST task consists of 5 tasks corresponding to the images with 45 | labels [0,1], [2,3], [4,5], [6,7], [8,9]. 46 | 47 | Args: 48 | data_path: Where should the MNIST dataset be read from? If not existing, 49 | the dataset will be downloaded into this folder. 50 | use_one_hot: Whether the class labels should be represented in a one-hot 51 | encoding. 52 | validation_size: The size of the validation set of each individual 53 | data handler. 54 | use_torch_augmentation (bool): See docstring of class 55 | :class:`data.mnist_data.MNISTData`. 56 | num_classes_per_task (int): Number of classes to put into one data 57 | handler. If ``2``, then every data handler will include 2 digits. 58 | num_tasks (int, optional): The number of data handlers that should be 59 | returned by this function. 
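    Example:
        A sketch of the default usage (the data path is illustrative):

        .. code-block:: python

            handlers = get_split_mnist_handlers('datasets/', use_one_hot=True,
                                                num_classes_per_task=2)
            assert len(handlers) == 5  # Tasks [0,1], [2,3], ..., [8,9].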
60 | 61 | Returns: 62 | (list): A list of data handlers, each corresponding to a 63 | :class:`SplitMNIST` object. 64 | """ 65 | assert num_tasks is None or num_tasks > 0 66 | if num_tasks is None: 67 | num_tasks = 10 // num_classes_per_task 68 | 69 | if not (num_tasks >= 1 and (num_tasks * num_classes_per_task) <= 10): 70 | raise ValueError('Cannot create SplitMNIST datasets for %d tasks ' \ 71 | % (num_tasks) + 'with %d classes per task.' \ 72 | % (num_classes_per_task)) 73 | 74 | print('Creating %d data handlers for SplitMNIST tasks ...' % num_tasks) 75 | 76 | handlers = [] 77 | steps = num_classes_per_task 78 | for i in range(0, 10, steps): 79 | handlers.append(SplitMNIST(data_path, use_one_hot=use_one_hot, 80 | use_torch_augmentation=use_torch_augmentation, 81 | validation_size=validation_size, labels=range(i, i+steps))) 82 | 83 | if len(handlers) == num_tasks: 84 | break 85 | 86 | print('Creating data handlers for SplitMNIST tasks ... Done') 87 | 88 | return handlers 89 | 90 | class SplitMNIST(MNISTData): 91 | """An instance of the class shall represent a SplitMNIST task. 92 | 93 | Args: 94 | data_path (str): Where should the dataset be read from? If not existing, 95 | the dataset will be downloaded into this folder. 96 | use_one_hot (bool): Whether the class labels should be represented in a 97 | one-hot encoding. 98 | validation_size (int): The number of validation samples. Validation 99 | samples will be taking from the training set (the first :math:`n` 100 | samples). 101 | use_torch_augmentation (bool): See docstring of class 102 | :class:`data.mnist_data.MNISTData`. 103 | labels (list): The labels that should be part of this task. 104 | full_out_dim (bool): Choose the original MNIST instead of the new 105 | task output dimension. This option will affect the attributes 106 | :attr:`data.dataset.Dataset.num_classes` and 107 | :attr:`data.dataset.Dataset.out_shape`. 108 | """ 109 | def __init__(self, data_path, use_one_hot=False, validation_size=1000, 110 | use_torch_augmentation=False, labels=[0, 1], 111 | full_out_dim=False): 112 | # Note, we build the validation set below! 113 | super().__init__(data_path, use_one_hot=use_one_hot, 114 | use_torch_augmentation=use_torch_augmentation, validation_size=0) 115 | 116 | self._full_out_dim = full_out_dim 117 | 118 | if isinstance(labels, range): 119 | labels = list(labels) 120 | assert np.all(np.array(labels) >= 0) and \ 121 | np.all(np.array(labels) < self.num_classes) and \ 122 | len(labels) == len(np.unique(labels)) 123 | K = len(labels) 124 | 125 | self._labels = labels 126 | 127 | train_ins = self.get_train_inputs() 128 | test_ins = self.get_test_inputs() 129 | 130 | train_outs = self.get_train_outputs() 131 | test_outs = self.get_test_outputs() 132 | 133 | # Get labels. 
134 | if self.is_one_hot: 135 | train_labels = self._to_one_hot(train_outs, reverse=True) 136 | test_labels = self._to_one_hot(test_outs, reverse=True) 137 | else: 138 | train_labels = train_outs 139 | test_labels = test_outs 140 | 141 | train_labels = train_labels.squeeze() 142 | test_labels = test_labels.squeeze() 143 | 144 | train_mask = train_labels == labels[0] 145 | test_mask = test_labels == labels[0] 146 | for k in range(1, K): 147 | train_mask = np.logical_or(train_mask, train_labels == labels[k]) 148 | test_mask = np.logical_or(test_mask, test_labels == labels[k]) 149 | 150 | train_ins = train_ins[train_mask, :] 151 | test_ins = test_ins[test_mask, :] 152 | 153 | train_outs = train_outs[train_mask, :] 154 | test_outs = test_outs[test_mask, :] 155 | 156 | if validation_size > 0: 157 | if validation_size >= train_outs.shape[0]: 158 | raise ValueError('Validation set size must be smaller than ' + 159 | '%d.' % train_outs.shape[0]) 160 | val_inds = np.arange(validation_size) 161 | train_inds = np.arange(validation_size, train_outs.shape[0]) 162 | 163 | else: 164 | train_inds = np.arange(train_outs.shape[0]) 165 | 166 | test_inds = np.arange(train_outs.shape[0], 167 | train_outs.shape[0] + test_outs.shape[0]) 168 | 169 | outputs = np.concatenate([train_outs, test_outs], axis=0) 170 | 171 | if not full_out_dim: 172 | # Transform outputs, e.g., if 1-hot [0,0,0,1,0,0,0,0,0,0] -> [0,1] 173 | 174 | # Note, the method assumes `full_out_dim` when later called by a 175 | # user. We just misuse the function to call it inside the 176 | # constructor. 177 | self._full_out_dim = True 178 | outputs = self.transform_outputs(outputs) 179 | self._full_out_dim = full_out_dim 180 | 181 | # Note, we may also have to adapt the output shape appropriately. 182 | if self.is_one_hot: 183 | self._data['out_shape'] = [len(labels)] 184 | 185 | images = np.concatenate([train_ins, test_ins], axis=0) 186 | 187 | ### Overwrite internal data structure. Only keep desired labels. 188 | 189 | # Note, we continue to pretend to be a 10 class problem, such that 190 | # the user has easy access to the correct labels and has the original 191 | # 1-hot encodings. 192 | if not full_out_dim: 193 | self._data['num_classes'] = len(labels) 194 | else: 195 | self._data['num_classes'] = 10 196 | self._data['in_data'] = images 197 | self._data['out_data'] = outputs 198 | self._data['train_inds'] = train_inds 199 | self._data['test_inds'] = test_inds 200 | if validation_size > 0: 201 | self._data['val_inds'] = val_inds 202 | 203 | n_val = 0 204 | if validation_size > 0: 205 | n_val = val_inds.size 206 | 207 | print('Created SplitMNIST task with labels %s and %d train, %d test ' 208 | % (str(labels), train_inds.size, test_inds.size) + 209 | 'and %d val samples.' % (n_val)) 210 | 211 | def transform_outputs(self, outputs): 212 | """Transform the outputs from the 10D MNIST dataset into proper labels 213 | based on the constructor argument ``labels``. 214 | 215 | I.e., the output will have ``len(labels)`` classes. 216 | 217 | Example: 218 | Split with labels [2,3] 219 | 220 | 1-hot encodings: [0,0,0,1,0,0,0,0,0,0] -> [0,1] 221 | 222 | labels: 3 -> 1 223 | 224 | Args: 225 | outputs: 2D numpy array of outputs. 226 | 227 | Returns: 228 | 2D numpy array of transformed outputs. 
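            A code sketch mirroring the example above (the data path is
            illustrative; note that ``full_out_dim=True`` must have been set):

            .. code-block:: python

                data = SplitMNIST('datasets/', use_one_hot=True,
                                  labels=[2, 3], full_out_dim=True)
                out = np.array([[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]])  # digit 3
                data.transform_outputs(out)  # -> array([[0, 1]])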
229 | """ 230 | return _transform_split_outputs(self, outputs) 231 | 232 | def get_identifier(self): 233 | """Returns the name of the dataset.""" 234 | return 'SplitMNIST' 235 | 236 | if __name__ == '__main__': 237 | pass 238 | -------------------------------------------------------------------------------- /data/toy_classification/moons.py: -------------------------------------------------------------------------------- 1 | from data.dataset import Dataset 2 | from sklearn import datasets 3 | import numpy as np 4 | from matplotlib.colors import ListedColormap 5 | import matplotlib.pyplot as plt 6 | from data.dataset import Dataset 7 | from sklearn import datasets 8 | import numpy as np 9 | from matplotlib.colors import ListedColormap 10 | import matplotlib.pyplot as plt 11 | 12 | class Moons(Dataset): 13 | """An instance of this class shall represent the two moons classification task 14 | """ 15 | 16 | def __init__(self, rseed=1234, use_one_hot=True, noise=0.1, n_train=1500, n_test=200): 17 | """Generate a new dataset. 18 | 19 | The input data x for train and test samples are drawn from the moons dataset. 20 | Args: 21 | rseed (int): 22 | noise: Noise to inject in the input data 23 | n_train: Number of training samples. 24 | n_test: Number of test samples.. 25 | 26 | """ 27 | super().__init__() 28 | 29 | if rseed is None: 30 | rand = np.random 31 | else: 32 | rand = np.random.RandomState(rseed) 33 | 34 | train_x, train_y = datasets.make_moons(n_samples=n_train, shuffle=True, noise=noise, random_state=rseed) 35 | test_x, test_y = datasets.make_moons(n_samples=n_test, shuffle=True, noise=noise, random_state=rseed) 36 | 37 | in_data = np.vstack([train_x, test_x]) 38 | out_data = np.vstack([np.expand_dims(train_y, 1), np.expand_dims(test_y, 1)]) 39 | 40 | # Specify internal data structure. 41 | self._data['classification'] = True 42 | self._data['sequence'] = False 43 | self._data['in_data'] = in_data 44 | self._data['in_shape'] = [2] 45 | self._data['num_classes'] = 2 46 | if use_one_hot: 47 | out_data = self._to_one_hot(out_data) 48 | self._data['out_data'] = out_data 49 | self._data['out_shape'] = [2] 50 | self._data['train_inds'] = np.arange(train_x.shape[0]) 51 | self._data['test_inds'] = np.arange(train_x.shape[0], train_x.shape[0] + test_x.shape[0]) 52 | 53 | def get_identifier(self): 54 | """Returns the name of the dataset.""" 55 | return 'Moons_dataset' 56 | 57 | def get_input_mesh(self, x1_range=None, x2_range=None, grid_size=1000): 58 | """Create a 2D grid of input values to make useful plots like 59 | 60 | Note: 61 | This method is only implemented for 2D datasets. 62 | 63 | Args: 64 | x1_range (tuple, optional): The min and max value for the first 65 | input dimension. If not specified, the range will be 66 | automatically inferred. 67 | 68 | Automatical inference is based on the underlying data (train 69 | and test). The range will be set, such that all data can be 70 | drawn inside. 71 | x2_range (tuple, optional): Same as ``x1_range`` for the second 72 | input dimension. 73 | grid_size (int or tuple): How many input samples per dimension. 74 | If an integer is passed, then the same number grid size will be 75 | used for both dimension. The grid is build by equally spacing 76 | ``grid_size`` inside the ranges ``x1_range`` and ``x2_range``. 77 | 78 | Returns: 79 | (tuple): Tuple containing: 80 | 81 | - **x1_grid** (numpy.ndarray): A 2D array, containing the grid 82 | values of the first dimension. 
83 | - **x2_grid** (numpy.ndarray): A 2D array, containing the grid 84 | values of the second dimension. 85 | - **flattended_grid** (numpy.ndarray): A 2D array, containing all 86 | samples from the first dimension in the first column and all 87 | values corresponding to the second dimension in the second column. 88 | This format correspond to the input format as, for instance, 89 | returned by methods such as 90 | :meth:`data.dataset.Dataset.get_train_inputs`. 91 | """ 92 | if self.in_shape[0] != 2: 93 | raise ValueError('This method only applies to 2D datasets.') 94 | 95 | if isinstance(grid_size, int): 96 | grid_size = (grid_size, grid_size) 97 | else: 98 | assert len(grid_size) == 2 99 | 100 | if x1_range is None or x2_range is None: 101 | min_x1 = self._data['in_data'][:, 0].min() 102 | min_x2 = self._data['in_data'][:, 1].min() 103 | max_x1 = self._data['in_data'][:, 0].max() 104 | max_x2 = self._data['in_data'][:, 1].max() 105 | 106 | slack_1 = (max_x1 - min_x1) * 0.02 107 | slack_2 = (max_x2 - min_x2) * 0.02 108 | 109 | if x1_range is None: 110 | x1_range = (min_x1 - slack_1, max_x1 + slack_1) 111 | else: 112 | assert len(x1_range) == 2 113 | 114 | if x2_range is None: 115 | x2_range = (min_x2 - slack_2, max_x2 + slack_2) 116 | else: 117 | assert len(x2_range) == 2 118 | 119 | x1 = np.linspace(start=x1_range[0], stop=x1_range[1], num=grid_size[0]) 120 | x2 = np.linspace(start=x2_range[0], stop=x2_range[1], num=grid_size[1]) 121 | 122 | X1, X2 = np.meshgrid(x1, x2) 123 | X = np.vstack([X1.ravel(), X2.ravel()]).T 124 | 125 | return X1, X2, X 126 | 127 | def _plot_sample(self, fig, inner_grid, num_inner_plots, ind, inputs, 128 | outputs=None, predictions=None): 129 | colors = ListedColormap(['#FF0000', '#0000FF']) 130 | 131 | # Create plot 132 | fig = plt.figure(figsize=(15, 10)) 133 | ax = fig.add_subplot(111) 134 | 135 | x_train_0 = self.get_train_inputs() 136 | y_train_0 = self.get_train_outputs() 137 | x_test_0 = self.get_test_inputs() 138 | y_test_0 = self.get_test_outputs() 139 | 140 | ax.scatter(x_train_0[:, 0], x_train_0[:, 1], alpha=1, marker='o', c=np.argmax(y_train_0, 1), cmap=colors, 141 | edgecolors='k', s=50, label='Train') 142 | ax.scatter(x_test_0[:, 0], x_test_0[:, 1], alpha=0.6, marker='s', c=np.argmax(y_test_0, 1), cmap=colors, 143 | edgecolors='k', s=50, label='test') 144 | plt.title('Data', fontsize=30) 145 | plt.legend(loc=2, fontsize=30) 146 | plt.show() 147 | 148 | 149 | 150 | 151 | class Moons_alternative(Dataset): 152 | 153 | 154 | def __init__(self, rseed=1234, use_one_hot=True, noise=0.1, n_train=1500, n_test=200): 155 | 156 | super().__init__() 157 | 158 | if rseed is None: 159 | rand = np.random 160 | else: 161 | rand = np.random.RandomState(rseed) 162 | 163 | train_x, train_y = datasets.make_moons(n_samples=n_train, shuffle=True, noise=noise, random_state=rseed) 164 | test_x, test_y = datasets.make_moons(n_samples=n_test, shuffle=True, noise=noise, random_state=rseed) 165 | 166 | in_data = np.vstack([np.expand_dims(train_x.sum(1), 1), np.expand_dims(test_x.sum(1), 1)]) 167 | out_data = np.vstack([np.expand_dims(train_y, 1), np.expand_dims(test_y, 1)]) 168 | 169 | # Specify internal data structure. 
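        # Note: unlike ``Moons`` above, the inputs of this variant are
        # one-dimensional (the two moon coordinates are collapsed via
        # ``train_x.sum(1)``), which is why ``in_shape`` is set to [1] below.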
170 | self._data['classification'] = True 171 | self._data['sequence'] = False 172 | self._data['in_data'] = in_data 173 | self._data['in_shape'] = [1] 174 | self._data['num_classes'] = 2 175 | if use_one_hot: 176 | out_data = self._to_one_hot(out_data) 177 | self._data['out_data'] = out_data 178 | self._data['out_shape'] = [2] 179 | self._data['train_inds'] = np.arange(train_x.shape[0]) 180 | self._data['test_inds'] = np.arange(train_x.shape[0], train_x.shape[0] + test_x.shape[0]) 181 | 182 | def get_identifier(self): 183 | """Returns the name of the dataset.""" 184 | return 'Moons_dataset' 185 | 186 | def get_input_mesh(self, x1_range=None, x2_range=None, grid_size=1000): 187 | 188 | 189 | if isinstance(grid_size, int): 190 | grid_size = (grid_size, grid_size) 191 | else: 192 | assert len(grid_size) == 2 193 | 194 | if x1_range is None or x2_range is None: 195 | min_x1 = self._data['in_data'][:, 0].min() 196 | min_x2 = self._data['in_data'][:, 1].min() 197 | max_x1 = self._data['in_data'][:, 0].max() 198 | max_x2 = self._data['in_data'][:, 1].max() 199 | 200 | slack_1 = (max_x1 - min_x1) * 0.02 201 | slack_2 = (max_x2 - min_x2) * 0.02 202 | 203 | if x1_range is None: 204 | x1_range = (min_x1 - slack_1, max_x1 + slack_1) 205 | else: 206 | assert len(x1_range) == 2 207 | 208 | if x2_range is None: 209 | x2_range = (min_x2 - slack_2, max_x2 + slack_2) 210 | else: 211 | assert len(x2_range) == 2 212 | 213 | x1 = np.linspace(start=x1_range[0], stop=x1_range[1], num=grid_size[0]) 214 | x2 = np.linspace(start=x2_range[0], stop=x2_range[1], num=grid_size[1]) 215 | 216 | X1, X2 = np.meshgrid(x1, x2) 217 | X = np.vstack([X1.ravel(), X2.ravel()]).T 218 | 219 | return X1, X2, X 220 | 221 | def _plot_sample(self, fig, inner_grid, num_inner_plots, ind, inputs, 222 | outputs=None, predictions=None): 223 | colors = ListedColormap(['#FF0000', '#0000FF']) 224 | 225 | # Create plot 226 | fig = plt.figure(figsize=(15, 10)) 227 | ax = fig.add_subplot(111) 228 | 229 | x_train_0 = self.get_train_inputs() 230 | y_train_0 = self.get_train_outputs() 231 | x_test_0 = self.get_test_inputs() 232 | y_test_0 = self.get_test_outputs() 233 | 234 | ax.scatter(x_train_0[:, 0], x_train_0[:, 1], alpha=1, marker='o', c=np.argmax(y_train_0, 1), cmap=colors, 235 | edgecolors='k', s=50, label='Train') 236 | ax.scatter(x_test_0[:, 0], x_test_0[:, 1], alpha=0.6, marker='s', c=np.argmax(y_test_0, 1), cmap=colors, 237 | edgecolors='k', s=50, label='test') 238 | plt.title('Data', fontsize=30) 239 | plt.legend(loc=2, fontsize=30) 240 | plt.show() 241 | -------------------------------------------------------------------------------- /training/training_mnist.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import torch 3 | from utils.distributions import Unorm_post 4 | from utils.kernel import RBF 5 | matplotlib.use('Agg') 6 | import matplotlib.pyplot as plt 7 | from sklearn.metrics import pairwise_distances 8 | from sklearn.manifold import TSNE 9 | from datetime import datetime 10 | import numpy as np 11 | from skimage.filters import gaussian 12 | from utils.utils import ood_metrics_entropy 13 | 14 | 15 | def train(data_train, data_test, ensemble, device, config, writer): 16 | """Train the particles using a specific ensemble. 17 | 18 | Args: 19 | data_train: A DATASET instance used for training; ``data_test`` provides the out-of-distribution data for the OOD metrics. 20 | ensemble: The ensemble of particles (main networks) to be trained. 21 | device: Torch device (cpu or gpu). 22 | config: The command line arguments. 23 | writer: The tensorboard summary writer.
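    Note:
        The OOD metrics logged to TensorBoard score each input by the
        predictive entropy of the ensemble-averaged softmax, normalized by
        ``log2(num_classes)`` so that it lies in ``[0, 1]``. A sketch of this
        score for averaged probabilities ``p`` of shape
        ``(batch_size, num_classes)``:

        .. code-block:: python

            log_scale = torch.log2(torch.tensor(float(p.shape[1])))
            entropy = -torch.sum(p * torch.log2(p + 1e-20), dim=1) / log_scale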
24 | """ 25 | 26 | # particles.train() 27 | 28 | #W = ensemble.particles.clone() 29 | 30 | W = ensemble.particles 31 | samples = [] 32 | 33 | optimizer = torch.optim.Adam([W], config.lr, weight_decay=config.weight_decay, 34 | betas=[config.adam_beta1, 0.999]) 35 | # optimizer = torch.optim.SGD([W], config.lr) 36 | 37 | #K_p = covariance_K() 38 | K = RBF() 39 | #K = real_PP_K() 40 | # prior = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(ensemble.net.num_params), 41 | # config.prior_variance * torch.eye( 42 | # ensemble.net.num_params)) 43 | 44 | 45 | prior = torch.distributions.normal.Normal(torch.zeros(ensemble.net.num_params).to(device), 46 | torch.ones(ensemble.net.num_params).to(device) * config.prior_variance) 47 | 48 | noise = torch.distributions.normal.Normal(torch.zeros(data_train.in_shape[0]**2).to(device), 49 | torch.ones(data_train.in_shape[0]**2).to(device)) 50 | 51 | P = Unorm_post(ensemble, prior, config, data_train.num_train_samples) 52 | 53 | log_scale = torch.log2(torch.tensor(data_train.out_shape[0],dtype=torch.float)) 54 | 55 | x_test_0 = torch.tensor(data_train.get_test_inputs(), dtype=torch.float) 56 | y_test_0 = torch.tensor(data_train.get_test_outputs(), dtype=torch.float) 57 | 58 | blurred = gaussian(x_test_0[5],sigma=inte,multichannel=False) 59 | 60 | x_test_ood = torch.tensor(data_test.get_test_inputs(), dtype=torch.float) 61 | y_test_ood = torch.tensor(data_test.get_test_outputs(), dtype=torch.float) 62 | 63 | if config.method == 'SGLD': 64 | method = SGLD(P, K, optimizer) 65 | elif config.method == 'SGD': 66 | method = SGD(P,optimizer) 67 | elif config.method == 'SVGD_annealing': 68 | method = SVGD_annealing(P, K, optimizer,config) 69 | elif config.method == 'SVGD_debug': 70 | method = SVGD_debug(P,K,optimizer) 71 | elif config.method == 'SVGD': 72 | K = RBF() 73 | method = SVGD(P,K,optimizer) 74 | elif config.method == 'f_SVGD': 75 | method = functional_SVGD(P,K,optimizer) 76 | elif config.method == 'r_SGD': 77 | method = repulsive_SGD(P,K,optimizer) 78 | elif config.method == 'f_SGD': 79 | method = functional_SGD(P,K,optimizer) 80 | elif config.method == 'f_p_SVGD': 81 | K = real_PP_K() 82 | method = f_p_SVGD(P,K,optimizer) 83 | elif config.method == 'mixed_f_p_SVGD': 84 | #K_p = covariance_K() 85 | K = RBF() 86 | K_p = real_PP_K() 87 | method = mixed_f_p_SVGD(P,K,K_p,optimizer) 88 | elif config.method == 'log_p_SVGD': 89 | method = log_p_SVGD(P,K,optimizer) 90 | elif config.method == 'log_p_SGD': 91 | method = log_p_SGD(P,K,optimizer) 92 | elif config.method == 'fisher_SVGD': 93 | method = fisher_SVGD(P,K,optimizer) 94 | elif config.method == 'fisher_x_SVGD': 95 | K = real_PP_K() 96 | method = fisher_x_SVGD(P,K,optimizer) 97 | 98 | 99 | for i in range(config.epochs): 100 | 101 | optimizer.zero_grad() 102 | 103 | batch_train = data_train.next_train_batch(config.batch_size) 104 | batch_test = data_train.next_test_batch(config.batch_size) 105 | batch_ood = data_test.next_train_batch(config.batch_size) 106 | X = data_train.input_to_torch_tensor(batch_train[0], device, mode='train') 107 | T = data_train.output_to_torch_tensor(batch_train[1], device, mode='train') 108 | X_t = data_train.input_to_torch_tensor(batch_test[0], device, mode='train') 109 | T_t = data_train.output_to_torch_tensor(batch_test[1], device, mode='train') 110 | 111 | #X_ood = data_train.input_to_torch_tensor(batch_ood[0], device, mode='train') 112 | #T_ood = data_train.output_to_torch_tensor(batch_ood[1], device, mode='train') 113 | 114 | #Adding noise to test as oood 
115 | #x_test_ood = x_test_0 + noise.sample(torch.Size([x_test_0.shape[0]])) 116 | 117 | 118 | 119 | # if config.clip_grad_value != -1: 120 | # torch.nn.utils.clip_grad_value_(optimizer.param_groups[0]['params'], 121 | # config.clip_grad_value) 122 | # elif config.clip_grad_norm != -1: 123 | # torch.nn.utils.clip_grad_norm_(optimizer.param_groups[0]['params'], 124 | # config.clip_grad_norm) 125 | if config.method == 'SVGD_annealing': 126 | driving,repulsive = method.step(W, X, T,i) 127 | elif config.method == 'SVGD_debug': 128 | driving,repulsive = method.step(W, X, T) 129 | elif config.method == 'SGD': 130 | method.step(W, X, T) 131 | elif config.method == 'f_p_SVGD' or config.method == 'mixed_f_p_SVGD': 132 | noise_samples = noise.sample(torch.Size([config.batch_size])) 133 | driving,repulsive = method.step(W,X,T,noise_samples) 134 | #method.step(W, X, T, None) 135 | 136 | else: 137 | driving,repulsive = method.step(W, X, T) 138 | 139 | if i % 10 == 0: 140 | train_loss, train_pred = P.log_prob(W, X, T, return_loss=True) 141 | test_loss, test_pred = P.log_prob(W, x_test_0, y_test_0, return_loss=True) 142 | writer.add_scalar('train/train_loss', train_loss, i) 143 | writer.add_scalar('test/test_loss', test_loss, i) 144 | if config.method != 'SGD': 145 | writer.add_scalar('train/driving_force', torch.mean(driving.abs()), i) 146 | writer.add_scalar('train/repulsive_force', torch.mean(repulsive.abs()), i) 147 | # writer.add_scalar('train/bandwith', K.h, i) 148 | 149 | if ensemble.net.classification: 150 | Y = torch.mean(train_pred, 0) 151 | Y_t = torch.mean(test_pred, 0) 152 | entropies_test = -torch.sum(torch.log2(Y_t + 1e-20)/log_scale * Y_t, 1) 153 | 154 | train_accuracy = (torch.argmax(Y, 1) == torch.argmax(T, 1)).sum().item() / Y.shape[0] * 100 155 | test_accuracy = (torch.argmax(Y_t, 1) == torch.argmax(y_test_0, 1)).sum().item() / Y_t.shape[0] * 100 156 | writer.add_scalar('train/accuracy', train_accuracy, i) 157 | writer.add_scalar('test/accuracy', test_accuracy, i) 158 | writer.add_scalar('test/entropy', entropies_test.mean(), i) 159 | 160 | print('Train iter:',i, ' train acc:', train_accuracy, 'test_acc', test_accuracy) 161 | 162 | if i % 50 == 0: 163 | if ensemble.net.classification: 164 | 165 | #ood diversity 166 | softmax_ood = ensemble.forward(x_test_ood)[0] 167 | pred_ood = torch.argmax(softmax_ood,2) 168 | 169 | average_prob = softmax_ood.mean(0) 170 | KL_uniform = -torch.sum(average_prob * torch.log2( 171 | (torch.ones(average_prob.shape[1]) / average_prob.shape[1] )/ average_prob + 1e-20)/log_scale, 1) 172 | KL_uniform[KL_uniform != KL_uniform] = 0 173 | writer.add_scalar('ood_metrics/AV_KL_uniform', KL_uniform.mean(), i) 174 | entropies_ood = -torch.sum(torch.log2(average_prob + 1e-20)/log_scale * average_prob, 1) 175 | writer.add_scalar('ood_metrics/Av_entropy', entropies_ood.mean(), i) 176 | 177 | 178 | # diversity = torch.mean(torch.min(pred_ood.sum(0), torch.tensor(pred_ood.shape[0])-pred_ood.sum(0) )/(pred_ood.shape[0]/2)) 179 | # writer.add_scalar('ood_metrics/diversity', diversity, i) 180 | average_prob = softmax_ood.mean(0) 181 | #entropies = -torch.sum(torch.log(average_prob + 1e-20) * average_prob, 1) 182 | #average_pred_ood = torch.argmax(average_prob,1) 183 | 184 | rocauc = ood_metrics_entropy(entropies_test,entropies_ood,writer,config,i) 185 | writer.add_scalar('ood_metrics/AUROC', rocauc[0], i) 186 | writer.add_scalar('ood_metrics/AUPR_IN', rocauc[1], i) 187 | writer.add_scalar('ood_metrics/AUPR_OUT', rocauc[2], i) 188 | 189 | 
#writer.add_hparams(dict(config.__dict__), {})
190 | 
191 | 
192 | # Distance matrix between the particles (pairwise weight-space distances).
193 | 
194 | dist = pairwise_distances(W.detach().cpu().numpy())  # .cpu() so this also works when training on GPU
195 | plt.rcParams['figure.figsize'] = [10, 10]
196 | plt.matshow(dist + np.identity(config.n_particles) * np.max(dist))
197 | plt.colorbar()
198 | writer.add_figure('distance_particles', plt.gcf(),
199 | i, close=not config.show_plots)
200 | plt.close()
201 | embedded_particles = TSNE(n_components=2).fit_transform(W.detach().cpu().numpy())
202 | plt.figure(figsize=(10, 10))
203 | plt.ylim(-500, 500)
204 | plt.xlim(-500, 500)
205 | plt.scatter(embedded_particles[:, 0], embedded_particles[:, 1], s=100, alpha=0.8)
206 | writer.add_figure('2D_embedding', plt.gcf(),
207 | i, close=not config.show_plots)
208 | 
209 | if config.keep_samples != 0 and i % config.keep_samples == 0:
210 | samples.append(W.detach().clone())
211 | samples_models = torch.cat(samples)
212 | # pred_tensor_samples = ensemble.forward(x_test_0, samples_models)
213 | # plot_predictive_distributions(config,writer,i,data, [x_test_0.squeeze()], [pred_tensor_samples.mean(0).squeeze()],
214 | # [pred_tensor_samples.std(0).squeeze()], save_fig=False, publication_style=False,
215 | # name=config.method+'smp')
216 | 
217 | if config.save_particles != 0 and i % config.save_particles == 0:
218 | particles = ensemble.particles.detach().cpu().numpy()
219 | np.save(datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '.npy', particles)
220 | 
221 | return samples_models if len(samples) > 0 else W.detach().clone()  # guard: no samples are collected when config.keep_samples == 0
--------------------------------------------------------------------------------
/models/net_utils/torch_ckpts.py:
--------------------------------------------------------------------------------
1 | """
2 | 
3 | The original implementation can be found `here `__.
4 | 
5 | """
6 | 
7 | import os
8 | import torch
9 | import time
10 | import json
11 | 
12 | # Key that will be added to the state dictionary for maintenance reasons.
13 | _INTERNAL_KEY = '_ckpt_internal'
14 | 
15 | def save_checkpoint(ckpt_dict, file_path, performance_score, train_iter=None,
16 | max_ckpts_to_keep=5, keep_cktp_every=2, timestamp=None):
17 | """Save checkpoint to file.
18 | 
19 | Example:
20 | .. code-block:: python
21 | 
22 | save_checkpoint({
23 | 'state_dict': net.state_dict(),
24 | 'train_iter': curr_iteration
25 | }, 'ckpts/my_net', current_test_accuracy)
26 | 
27 | Args:
28 | ckpt_dict: A dict with mostly arbitrary content. Though, most important,
29 | it needs to include the state dict and should also include
30 | the current training iteration.
31 | file_path: Where to store the checkpoint. Note, the filepath should
32 | not change. Instead, ``train_iter`` should be provided,
33 | such that this method can handle the filenames by itself.
34 | 
35 | Note:
36 | The function currently assumes that within the same directory,
37 | no checkpoint filename is the prefix of another
38 | checkpoint filename (e.g., if several networks are checkpointed
39 | into the same directory).
40 | performance_score: A score that expresses the performance of the
41 | current network state, e.g., accuracy for a
42 | classification task. This score is used to
43 | maintain the list of kept checkpoints during
44 | training.
45 | train_iter (optional): If given, it will be added to the filename.
46 | Otherwise, existing checkpoints are simply overwritten.
47 | max_ckpts_to_keep: The maximum number of checkpoints to
48 | keep. This will use the performance score to determine the n-1
49 | checkpoints not to be deleted (where n is the number of
50 | checkpoints to keep).
The current checkpoint will always be saved. 51 | keep_cktp_every: If this option is not :code:`None`, 52 | then every n hours one checkpoint will be permanently saved, i.e., 53 | this checkpoint will not be maintained by 'max_ckpts_to_keep' 54 | anymore. The checkpoint to be kept will be the best one from the 55 | time window that spans the last n hours. 56 | timestamp (optional): The timestamp of this checkpoint. If not given, 57 | a current timestamp will be used. This option is useful when one 58 | aims to synchronize checkpoint savings from multiple networks. 59 | """ 60 | if timestamp is None: 61 | ts = time.time() # timestamp 62 | else: 63 | ts = timestamp 64 | 65 | assert('state_dict' in ckpt_dict.keys()) 66 | # We need to store internal (checkpoint maintenance related) information in 67 | # each checkpoint. 68 | internal_key = _INTERNAL_KEY 69 | assert(internal_key not in ckpt_dict.keys()) 70 | ckpt_dict[internal_key] = dict() 71 | ckpt_dict[internal_key]['permanent'] = False 72 | ckpt_dict[internal_key]['score'] = performance_score 73 | ckpt_dict[internal_key]['ts']= ts 74 | 75 | # FIXME We currently don't care about file extensions. 76 | dname, fname = os.path.split(file_path) 77 | # Where do we store meta data, needed for maintenance. 78 | meta_fn = ('.' if not fname.startswith('.') else '') + fname + '_meta' 79 | meta_fn = os.path.join(dname, meta_fn) 80 | 81 | if not os.path.exists(dname): 82 | os.makedirs(dname) 83 | 84 | # Needed for option 'keep_cktp_every'. When was the first ckpt stored? 85 | if not os.path.exists(meta_fn): 86 | with open(meta_fn, 'w') as f: 87 | json.dump({'init_ts': ts}, f) 88 | init_ts = ts 89 | else: 90 | with open(meta_fn) as f: 91 | meta_dict = json.load(f) 92 | init_ts = meta_dict['init_ts'] 93 | 94 | hrs_passed = (ts - init_ts) / (60 * 60) 95 | 96 | ### Iterate all existing checkpoints to determine which we remove. 97 | ckpt_fns = [os.path.join(dname, f) for f in os.listdir(dname) if 98 | os.path.isfile(os.path.join(dname, f)) and 99 | f.startswith(fname)] 100 | 101 | kept_ckpts = [] 102 | permanent_ckpts = [] 103 | 104 | for fn in ckpt_fns: 105 | # FIXME loading all checkpoints is expensive. 106 | ckpt = torch.load(fn) 107 | 108 | if not internal_key in ckpt: 109 | continue 110 | 111 | if ckpt[internal_key]['permanent']: 112 | permanent_ckpts.append((fn, ckpt[internal_key]['ts'])) 113 | else: 114 | kept_ckpts.append((fn, ckpt[internal_key]['ts'], 115 | ckpt[internal_key]['score'])) 116 | 117 | ## Decide, whether a new permanent checkpoint should be saved. 118 | if keep_cktp_every is not None and hrs_passed >= keep_cktp_every: 119 | perm_ckpt_needed = True 120 | 121 | num_wins = hrs_passed // keep_cktp_every 122 | win_start = (num_wins-1) * keep_cktp_every 123 | 124 | # Check whether a permanent checkpoint for the current window already 125 | # exists. 126 | if len(permanent_ckpts) > 0: 127 | permanent_ckpts.sort(key=lambda tup: tup[1], reverse=True) 128 | 129 | ts_last_perm = permanent_ckpts[0][1] 130 | if ((ts_last_perm - init_ts) / (60 * 60)) >= win_start: 131 | perm_ckpt_needed = False 132 | 133 | if perm_ckpt_needed: 134 | # Choose the checkpoint with the best score in the current window 135 | # as next permanent checkpoint. 
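# (In other words: the non-permanent checkpoints are scanned from newest to oldest; among those falling into
# the current time window, the one with the best score is promoted to a permanent checkpoint, unless the
# checkpoint currently being saved scores even higher, in which case it is marked permanent itself.)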
136 | kept_ckpts.sort(key=lambda tup: tup[1], reverse=True)
137 | max_score = -1
138 | max_ind = -1
139 | 
140 | for i, tup in enumerate(kept_ckpts):
141 | if ((tup[1] - init_ts) / (60 * 60)) < win_start:
142 | break
143 | 
144 | if max_ind == -1 or max_score < tup[2]:
145 | max_ind = i
146 | max_score = tup[2]
147 | 
148 | if max_ind != -1 and max_score > performance_score:
149 | # Transform an existing checkpoint into a permanent one.
150 | ckpt_tup = kept_ckpts[max_ind]
151 | # Important, we need to remove this item from the kept_ckpts,
152 | # as this list is used in the next step to determine which
153 | # checkpoints are removed.
154 | del kept_ckpts[max_ind]
155 | print('Checkpoint %s will be kept permanently.' % ckpt_tup[0])
156 | 
157 | # FIXME: We might need the device here as in the load method.
158 | ckpt = torch.load(ckpt_tup[0])
159 | ckpt[internal_key]['permanent'] = True
160 | torch.save(ckpt, ckpt_tup[0])
161 | 
162 | else:
163 | print('New checkpoint will be kept permanently.')
164 | ckpt_dict[internal_key]['permanent'] = True
165 | 
166 | ## Decide, whether a checkpoint has to be deleted.
167 | if len(kept_ckpts) >= max_ckpts_to_keep:
168 | kept_ckpts.sort(key=lambda tup: tup[2])
169 | 
170 | for i in range(len(kept_ckpts) - (max_ckpts_to_keep-1)):
171 | fn = kept_ckpts[i][0]
172 | print('Deleting old checkpoint: %s.' % fn)
173 | os.remove(fn)
174 | 
175 | ### Save new checkpoint.
176 | if train_iter is not None:
177 | file_path += '_%d' % train_iter
178 | 
179 | torch.save(ckpt_dict, file_path)
180 | print('Checkpoint saved to %s' % file_path)
181 | 
182 | def load_checkpoint(ckpt_path, net, device=None, ret_performance_score=False):
183 | """Load a checkpoint from file.
184 | 
185 | Args:
186 | ckpt_path: Path to checkpoint.
187 | net: The network that should load the state dict saved in this
188 | checkpoint.
189 | device (optional): The device currently used by the model. Can help to
190 | speed up loading the checkpoint.
191 | ret_performance_score: If True, the score associated with this
192 | checkpoint will be returned as well. See argument
193 | "performance_score" of method "save_checkpoint".
194 | 
195 | Returns:
196 | The loaded checkpoint. Note, the state_dict is already applied to the
197 | network. However, there might be other important dict elements.
198 | """
199 | # See here for details on how to load the checkpoint:
200 | # https://blog.floydhub.com/checkpointing-tutorial-for-tensorflow-keras-and-pytorch/
201 | if device is not None and device.type == 'cuda':
202 | ckpt = torch.load(ckpt_path)
203 | else:
204 | # Load GPU model on CPU
205 | ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
206 | 
207 | net.load_state_dict(ckpt['state_dict'])
208 | 
209 | if ret_performance_score:
210 | score = ckpt[_INTERNAL_KEY]['score']
211 | 
212 | # That key was added for maintenance reasons in the method save_checkpoint.
213 | if _INTERNAL_KEY in ckpt:
214 | del ckpt[_INTERNAL_KEY]
215 | 
216 | if ret_performance_score:
217 | return ckpt, score
218 | 
219 | return ckpt
220 | 
221 | def make_ckpt_list(file_path):
222 | """Creates a file that lists all checkpoints together with their scores,
223 | such that one can easily find the checkpoint associated with the maximum
224 | score.
225 | 
226 | Args:
227 | file_path: See method :func:`save_checkpoint`.
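Note:
The resulting file ``score_list_<fname>.txt`` is written to the checkpoint directory and contains one
line per checkpoint in the form ``<ckpt_filename>, <score>``, sorted by score in descending order. It is
consumed by :func:`get_best_ckpt_path`.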
228 | """ 229 | internal_key = _INTERNAL_KEY 230 | 231 | dname, fname = os.path.split(file_path) 232 | 233 | assert(os.path.exists(dname)) 234 | 235 | ckpt_fns = [(f, os.path.join(dname, f)) for f in os.listdir(dname) if 236 | os.path.isfile(os.path.join(dname, f)) and 237 | f.startswith(fname)] 238 | 239 | ckpts = [] 240 | 241 | for fn, fpath in ckpt_fns: 242 | ckpt = torch.load(fpath) 243 | 244 | if not internal_key in ckpt: 245 | continue 246 | 247 | score = ckpt[internal_key]['score'] 248 | 249 | ckpts.append((fn, score)) 250 | 251 | ckpts.sort(key=lambda tup: tup[1], reverse=True) 252 | 253 | with open(os.path.join(dname, 'score_list_' + fname + '.txt'), 'w') as f: 254 | for tup in ckpts: 255 | f.write('%s, %f\n' % (tup[0], tup[1])) 256 | 257 | def get_best_ckpt_path(file_path): 258 | """Returns the path to the checkpoint with the highest score. 259 | 260 | Args: 261 | file_path: See method :func:`save_checkpoints`. 262 | """ 263 | dname, fname = os.path.split(file_path) 264 | assert(os.path.exists(dname)) 265 | 266 | # See method make_ckpt_list. 267 | ckpt_list_fn = os.path.join(dname, 'score_list_' + fname + '.txt') 268 | if os.path.exists(ckpt_list_fn): 269 | with open(ckpt_list_fn, 'r') as f: 270 | # Get first word from file. Note, the filename ends with a comma. 271 | best_ckpt_fname = f.readline().split(None, 1)[0][:-1] 272 | 273 | return os.path.join(dname, best_ckpt_fname) 274 | 275 | # Go through each checkpoint and evaluate the score achieved. 276 | ckpt_fns = [(f, os.path.join(dname, f)) for f in os.listdir(dname) if 277 | os.path.isfile(os.path.join(dname, f)) and 278 | f.startswith(fname)] 279 | 280 | best_ckpt_path = None 281 | best_score = -1 282 | 283 | for fn, fpath in ckpt_fns: 284 | ckpt = torch.load(fpath) 285 | 286 | if not _INTERNAL_KEY in ckpt: 287 | continue 288 | 289 | score = ckpt[_INTERNAL_KEY]['score'] 290 | if score > best_score: 291 | best_score = score 292 | best_ckpt_path = fpath 293 | 294 | return best_ckpt_path 295 | 296 | if __name__ == '__main__': 297 | pass 298 | 299 | 300 | -------------------------------------------------------------------------------- /data/svhn/data_svhn_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Street View House Numbers (SVHN) Dataset 3 | ---------------------------------------- 4 | 5 | The original implementation can be found `here `__. 6 | 7 | 8 | The module :mod:`data.svhn_data` contains a handler for the 9 | `SVHN `__ dataset. 10 | 11 | The dataset was introduced in: 12 | 13 | Netzer et al., `Reading Digits in Natural Images with Unsupervised Feature \ 14 | Learning `__, 15 | 2011. 16 | 17 | This module contains a simple wrapper from the corresponding 18 | `torchvision `__ 19 | class :class:`torchvision.datasets.SVHN` to our dataset interface 20 | :class:`data.dataset.Dataset`. 21 | """ 22 | 23 | import matplotlib.pyplot as plt 24 | import numpy as np 25 | import os 26 | from torchvision.datasets import SVHN 27 | 28 | from data.cifar.cifar10_data import CIFAR10Data 29 | from data.dataset import Dataset 30 | 31 | class SVHNData(Dataset): 32 | """An instance of the class shall represent the SVHN dataset. 33 | 34 | Note: 35 | By default, input samples are provided in a range of ``[0, 1]``. 36 | 37 | Args: 38 | data_path (str): Where should the dataset be read from? If not existing, 39 | the dataset will be downloaded into this folder. 40 | use_one_hot (bool): Whether the class labels should be 41 | represented in a one-hot encoding. 
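If enabled, targets are stored as 10-dimensional one-hot vectors (``out_shape`` becomes ``[10]``);
otherwise, labels are kept as scalar class indices.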
42 | validation_size (int): The number of validation samples. Validation
43 | samples will be taken from the training set (the first :math:`n`
44 | samples).
45 | use_torch_augmentation (bool): Note, this option currently only applies
46 | to input batches that are transformed using the class member
47 | :meth:`input_to_torch_tensor` (hence, **only available for
48 | PyTorch**, so far).
49 | 
50 | The augmentation will be identical to the one provided by class
51 | :class:`data.cifar10_data.CIFAR10Data`, **except** that during
52 | training no random horizontal flips are applied.
53 | 
54 | Note:
55 | If activated, the statistics of test samples are changed as
56 | a normalization is applied (identical to that of class
57 | :class:`data.cifar10_data.CIFAR10Data`).
58 | use_cutout (bool): Whether option ``apply_cutout`` of method
59 | :meth:`torch_input_transforms` should be set. We use cutouts of size
60 | ``20 x 20`` as recommended
61 | `here `__.
62 | 
63 | Note:
64 | Only applies if ``use_torch_augmentation`` is set.
65 | include_train_extra (bool): The training dataset can be extended by
66 | "531,131 additional, somewhat less difficult samples" (see
67 | `here `__).
68 | 
69 | Note, as long as the validation set size is smaller than the
70 | original training set size, all validation samples would be taken
71 | from the original training set (and thus not contain those "less
72 | difficult" samples).
73 | """
74 | # In which subfolder of the datapath should the data be stored.
75 | _SUBFOLDER = 'SVHN'
76 | 
77 | def __init__(self, data_path, use_one_hot=False, validation_size=0,
78 | use_torch_augmentation=False, use_cutout=False,
79 | include_train_extra=False):
80 | super().__init__()
81 | 
82 | # Actual data path
83 | data_path = os.path.join(data_path, SVHNData._SUBFOLDER)
84 | if not os.path.exists(data_path):
85 | os.makedirs(data_path)
86 | 
87 | svhn_train = SVHN(data_path, split='train', download=True)
88 | svhn_test = SVHN(data_path, split='test', download=True)
89 | svhn_extra = None
90 | if include_train_extra:
91 | svhn_extra = SVHN(data_path, split='extra', download=True)
92 | 
93 | assert np.all(np.equal(svhn_train.data.shape, [73257, 3, 32, 32]))
94 | assert np.all(np.equal(svhn_test.data.shape, [26032, 3, 32, 32]))
95 | assert not include_train_extra or \
96 | np.all(np.equal(svhn_extra.data.shape, [531131, 3, 32, 32]))
97 | 
98 | train_inputs = svhn_train.data
99 | test_inputs = svhn_test.data
100 | train_labels = svhn_train.labels
101 | test_labels = svhn_test.labels
102 | if include_train_extra:
103 | train_inputs = np.concatenate([train_inputs, svhn_extra.data],
104 | axis=0)
105 | train_labels = np.concatenate([train_labels, svhn_extra.labels],
106 | axis=0)
107 | 
108 | images = np.concatenate([train_inputs, test_inputs], axis=0)
109 | labels = np.concatenate([train_labels, test_labels], axis=0)
110 | 
111 | # Note, images are currently encoded in a way that their shape
112 | # corresponds to (3, 32, 32). For consistency reasons, we would like to
113 | # change that to (32, 32, 3).
114 | images = np.rollaxis(images, 1, 4)
115 | # Scale images into a range between 0 and 1.
116 | images = images / 255.
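# The images are then flattened to 3072-dimensional vectors (32 * 32 * 3) and the labels reshaped to a
# column vector; e.g., without the 'extra' split the combined array has shape (99289, 3072) for the 73257
# training plus 26032 test samples. The original [32, 32, 3] layout is still recorded in 'in_shape' below,
# so samples can be reshaped for plotting.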
117 | 118 | images = images.reshape(-1, 32 * 32 * 3) 119 | labels = labels.reshape(-1, 1) 120 | 121 | val_inds = None 122 | train_inds = np.arange(train_labels.size) 123 | test_inds = np.arange(train_labels.size, 124 | train_labels.size + test_labels.size) 125 | 126 | if validation_size > 0: 127 | if validation_size >= train_inds.size: 128 | raise ValueError('Validation set must contain less than %d ' \ 129 | % (train_inds.size) + 'samples!') 130 | 131 | val_inds = np.arange(validation_size) 132 | train_inds = np.arange(validation_size, train_inds.size) 133 | 134 | # Bring everything into the internal structure of the Dataset class. 135 | self._data['classification'] = True 136 | self._data['sequence'] = False 137 | self._data['num_classes'] = 10 138 | self._data['is_one_hot'] = use_one_hot 139 | self._data['in_data'] = images 140 | self._data['in_shape'] = [32, 32, 3] 141 | self._data['out_shape'] = [10 if use_one_hot else 1] 142 | self._data['val_inds'] = val_inds 143 | self._data['train_inds'] = train_inds 144 | self._data['test_inds'] = test_inds 145 | 146 | if use_one_hot: 147 | labels = self._to_one_hot(labels) 148 | 149 | self._data['out_data'] = labels 150 | 151 | # Dataset specific attributes. 152 | self._data['svhn'] = dict() 153 | # 0 - original train, 1 - extra train, 2 - test 154 | # Note, independent of whether samples are now in the validation set. 155 | self._data['svhn']['type'] = np.zeros(self._data['in_data'].shape[0]) 156 | if include_train_extra: 157 | self._data['svhn']['type'][svhn_train.labels.size:] = 1 158 | self._data['svhn']['type'][test_inds] = 2 159 | 160 | # Initialize PyTorch data augmentation. 161 | self._augment_inputs = False 162 | if use_torch_augmentation: 163 | self._augment_inputs = True 164 | # Note, horizontal flips change the meaning of digits! 165 | self._train_transform, self._test_transform = \ 166 | CIFAR10Data.torch_input_transforms(apply_rand_hflips=False, 167 | apply_cutout=use_cutout, cutout_length=20) 168 | 169 | print('Created %s.' % (str(self))) 170 | 171 | def get_identifier(self): 172 | """Returns the name of the dataset.""" 173 | return 'SVHN' 174 | 175 | def input_to_torch_tensor(self, x, device, mode='inference', 176 | force_no_preprocessing=False, sample_ids=None): 177 | """This method can be used to map the internal numpy arrays to PyTorch 178 | tensors. 179 | 180 | Note, this method has been overwritten from the base class. 181 | 182 | The input images are preprocessed if data augmentation is enabled. 183 | Preprocessing involves normalization and (for training mode) random 184 | perturbations. 185 | 186 | Args: 187 | (....): See docstring of method 188 | :meth:`data.dataset.Dataset.input_to_torch_tensor`. 189 | 190 | Returns: 191 | (torch.Tensor): The given input ``x`` as PyTorch tensor. 192 | """ 193 | # FIXME Method copied from `CIFAR100Data`. 194 | if self._augment_inputs and not force_no_preprocessing: 195 | if mode == 'inference': 196 | transform = self._test_transform 197 | elif mode == 'train': 198 | transform = self._train_transform 199 | else: 200 | raise ValueError('"%s" not a valid value for argument "mode".' 
201 | % mode) 202 | 203 | return CIFAR10Data.torch_augment_images(x, device, transform) 204 | 205 | else: 206 | return Dataset.input_to_torch_tensor(self, x, device, 207 | mode=mode, force_no_preprocessing=force_no_preprocessing, 208 | sample_ids=sample_ids) 209 | 210 | def _plot_sample(self, fig, inner_grid, num_inner_plots, ind, inputs, 211 | outputs=None, predictions=None, batch_ids=None): 212 | """Implementation of abstract method 213 | :meth:`data.dataset.Dataset._plot_sample`. 214 | 215 | Args: 216 | batch_ids (numpy.ndarray, optional): If provided, then samples 217 | stemming from the "extra" training set will be marked in the 218 | caption. 219 | """ 220 | ax = plt.Subplot(fig, inner_grid[0]) 221 | 222 | lbl = 'SVHN sample' 223 | if batch_ids is not None: 224 | stype = self._data['svhn']['type'][batch_ids[ind]] 225 | if stype == 1: 226 | lbl = 'SVHN (extra) sample' 227 | 228 | if outputs is None: 229 | ax.set_title(lbl) 230 | else: 231 | assert(np.size(outputs) == 1) 232 | label = np.asscalar(outputs) 233 | 234 | if predictions is None: 235 | ax.set_title('%s\nLabel: %d' % (lbl, label)) 236 | else: 237 | if np.size(predictions) == self.num_classes: 238 | pred_label = np.argmax(predictions) 239 | else: 240 | pred_label = np.asscalar(predictions) 241 | 242 | ax.set_title('%s\nLabel: %d, Prediction: %d' % \ 243 | (lbl, label, pred_label)) 244 | 245 | ax.set_axis_off() 246 | ax.imshow(np.squeeze(np.reshape(inputs, self.in_shape))) 247 | fig.add_subplot(ax) 248 | 249 | if num_inner_plots == 2: 250 | ax = plt.Subplot(fig, inner_grid[1]) 251 | ax.set_title('Predictions') 252 | bars = ax.bar(range(self.num_classes), np.squeeze(predictions)) 253 | ax.set_xticks(range(self.num_classes)) 254 | if outputs is not None: 255 | bars[int(label)].set_color('r') 256 | fig.add_subplot(ax) 257 | 258 | def _plot_config(self, inputs, outputs=None, predictions=None): 259 | """Re-Implementation of method 260 | :meth:`data.dataset.Dataset._plot_config`. 261 | 262 | This method has been overriden to ensure, that there are 2 subplots, 263 | in case the predictions are given. 264 | """ 265 | plot_configs = Dataset._plot_config(self, inputs, outputs=outputs, 266 | predictions=predictions) 267 | 268 | if predictions is not None and \ 269 | np.shape(predictions)[1] == self.num_classes: 270 | plot_configs['outer_hspace'] = 0.6 271 | plot_configs['inner_hspace'] = 0.4 272 | plot_configs['num_inner_rows'] = 2 273 | #plot_configs['num_inner_cols'] = 1 274 | plot_configs['num_inner_plots'] = 2 275 | 276 | return plot_configs 277 | 278 | def __str__(self): 279 | return 'SVHN Dataset with %d training, %d validation and %d test ' % \ 280 | (self.num_train_samples, self.num_val_samples, 281 | self.num_test_samples) + 'samples' 282 | 283 | if __name__ == '__main__': 284 | pass 285 | 286 | 287 | --------------------------------------------------------------------------------
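A minimal usage sketch of the handler above, assuming a local ``./datasets`` folder and that the repository root is on the Python path:

```python
import torch
from data.svhn.data_svhn_data import SVHNData

# Download/read SVHN into ./datasets/SVHN and hold out 5000 validation samples.
svhn = SVHNData('./datasets', use_one_hot=True, validation_size=5000)
print(svhn)  # SVHN Dataset with 68257 training, 5000 validation and 26032 test samples

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
inputs, targets = svhn.next_train_batch(64)
X = svhn.input_to_torch_tensor(inputs, device, mode='train')
T = svhn.output_to_torch_tensor(targets, device, mode='train')
print(X.shape)  # e.g. torch.Size([64, 3072]), i.e. flattened 32x32x3 images
```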