├── .gitignore ├── README.md ├── extraction └── extract_continual.py ├── figures ├── dist_taskA_taskB.png ├── results_continual_conv.png ├── results_dissimilar.png └── results_similar.png ├── main_continual.py └── utils ├── BBBConvmodel.py ├── BBBdistributions.py ├── BBBlayers.py └── __init__.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | data 3 | utils/__pycache__/ 4 | results 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Continual learning with a Bayesian CNN 2 | 3 | ## Instructions 4 | `main_continual.py` is the main file. Run it once for each of the following hyperparameter configurations: 5 | 6 | 1. `pretrained = False` will give you `weights_1.pkl` (leave everything else at its default values) 7 | 2. `pretrained = True` and `task = 2` will give you `weights_2.pkl` to evaluate forward transfer 8 | 3. `is_training = False`, `pretrained = True`, `task = 2` to evaluate backward transfer, i.e. the ability to overcome catastrophic forgetting 9 | 4. `pretrained = True`, `task = 3` will give you `weights_3.pkl` 10 | 5. `is_training = False`, `pretrained = True`, `task = 3` 11 | 6. (and so on for as many tasks as you want) 12 | 13 | ### How we transfer parameters 14 | 15 | ![posterior of task A = prior of task B](figures/dist_taskA_taskB.png) 16 | 17 | ### Results 18 | 19 | #### Results on permutations of MNIST 20 | ![Results on permutations of MNIST](figures/results_continual_conv.png) 21 | 22 | #### Results on *similar* CIFAR-100 classes ("leopards", "tigers", "lions") 23 | ![Results on *similar* CIFAR-100 classes ("leopards", "tigers", "lions")](figures/results_similar.png) 24 | 25 | #### Results on *dissimilar* CIFAR-100 classes ("leopards", "palms", "bicycles") 26 | ![Results on *dissimilar* CIFAR-100 classes ("leopards", "palms", "bicycles")](figures/results_dissimilar.png) 27 | 28 | 29 | [Bayesian CNN repository](https://github.com/felix-laumann/Bayesian_CNN) 30 | -------------------------------------------------------------------------------- /extraction/extract_continual.py: -------------------------------------------------------------------------------- 1 | import os 2 | import matplotlib 3 | matplotlib.use("Agg") 4 | import matplotlib.pyplot as plt 5 | plt.style.use("seaborn-whitegrid") 6 | import re 7 | import numpy as np 8 | plt.rc('font', family='serif', size=32) 9 | plt.rcParams.update({'xtick.labelsize': 32, 'ytick.labelsize': 32, 'axes.labelsize': 32}) 10 | 11 | os.chdir("/home/felix/Dropbox/publications/Bayesian_CNN_continual/results/") 12 | 13 | with open("diagnostics_1.txt", 'r') as file: 14 | acc = re.findall(r"'acc':\s+tensor\((.*?)\)", file.read()) 15 | print(acc) 16 | 17 | train_1 = acc[1::2] 18 | valid_1 = acc[0::2] 19 | train_1 = np.array(train_1).astype(np.float32) 20 | valid_1 = np.array(valid_1).astype(np.float32) 21 | 22 | with open("diagnostics_2.txt", 'r') as file: 23 | acc = re.findall(r"'acc':\s+tensor\((.*?)\)", file.read()) 24 | print(acc) 25 | 26 | train_2 = acc[1::2] 27 | valid_2 = acc[0::2] 28 | train_2 = np.array(train_2).astype(np.float32) 29 | valid_2 = np.array(valid_2).astype(np.float32) 30 | 31 | with open("diagnostics_2_eval.txt", 'r') as file: 32 | valid_2_eval_A = re.findall(r"'acc':\s+tensor\((.*?)\)", file.read()) 33 | print(valid_2_eval_A) 34 | 35 | valid_2_eval_A = np.array(valid_2_eval_A).astype(np.float32) 36 | 37 | with open("diagnostics_3.txt", 'r')
as file: 38 | acc = re.findall(r"'acc':\s+tensor\((.*?)\)", file.read()) 39 | print(acc) 40 | 41 | train_3 = acc[1::2] 42 | valid_3 = acc[0::2] 43 | train_3 = np.array(train_3).astype(np.float32) 44 | valid_3 = np.array(valid_3).astype(np.float32) 45 | 46 | with open("diagnostics_3_eval.txt", 'r') as file: 47 | valid_3_eval_B = re.findall(r"'acc':\s+tensor\((.*?)\)", file.read()) 48 | print(valid_3_eval_B) 49 | 50 | valid_3_eval_B = np.array(valid_3_eval_B).astype(np.float32) 51 | """ 52 | with open("diagnostics_3_eval_A.txt", 'r') as file: 53 | valid_3_eval_A = re.findall(r"'acc':\s+tensor\((.*?)\)", file.read()) 54 | print(valid_3_eval_A) 55 | 56 | valid_3_eval_A = np.array(valid_3_eval_A).astype(np.float32) 57 | 58 | 59 | with open("diagnostics_4.txt", 'r') as file: 60 | acc = re.findall(r"'acc':\s+tensor\((.*?)\)", file.read()) 61 | print(acc) 62 | 63 | train_4 = acc[1::2] 64 | valid_4 = acc[0::2] 65 | train_4 = np.array(train_4).astype(np.float32) 66 | valid_4 = np.array(valid_4).astype(np.float32) 67 | 68 | with open("diagnostics_4_eval.txt", 'r') as file: 69 | valid_4_eval = re.findall(r"'acc':\s+tensor\((.*?)\)", file.read()) 70 | print(valid_4_eval) 71 | 72 | valid_4_eval = np.array(valid_4_eval).astype(np.float32) 73 | """ 74 | f = plt.figure(figsize=(20, 16)) 75 | 76 | plt.plot(valid_1, label=r"Validation task A , prior: $U(a, b)$", color='maroon') 77 | plt.plot(valid_2, label=r"Validation task B, prior: $q(w | \theta_{A})$", color='darkblue') 78 | plt.plot(valid_2_eval_A, label=r"Validation task A after training task B", color='#89c765') 79 | plt.plot(valid_3, label=r"Validation task C, prior: $q(w | \theta_{B})$", color='peru') 80 | plt.plot(valid_3_eval_B, label=r"Validation task B after training task C", color='m') 81 | #plt.plot(valid_3_eval_A, label=r"Validation task A after training task C", color='gray') 82 | #plt.plot(valid_4, "--", label=r"Validation, prior: $q(w | \theta_C)$", color='gray') 83 | #plt.plot(valid_4_eval, "--", label=r"Validation task C after training task D", color='black') 84 | 85 | 86 | plt.xlabel("Epochs") 87 | plt.ylabel("Accuracy") 88 | x_ticks = range(len(valid_1)) 89 | plt.xticks(x_ticks[9::10], map(lambda x: x+1, x_ticks[9::10])) 90 | 91 | plt.legend(loc='center right', fontsize=28) 92 | 93 | # , bbox_to_anchor=(0.25, 0.58) 94 | 95 | plt.savefig("results_continual.png", linewidth=10.0) 96 | -------------------------------------------------------------------------------- /figures/dist_taskA_taskB.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/felix-laumann/Bayesian_CNN_ContinualLearning/1c0357a736ff1cd8d3b449ddbfb8aabffea34862/figures/dist_taskA_taskB.png -------------------------------------------------------------------------------- /figures/results_continual_conv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/felix-laumann/Bayesian_CNN_ContinualLearning/1c0357a736ff1cd8d3b449ddbfb8aabffea34862/figures/results_continual_conv.png -------------------------------------------------------------------------------- /figures/results_dissimilar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/felix-laumann/Bayesian_CNN_ContinualLearning/1c0357a736ff1cd8d3b449ddbfb8aabffea34862/figures/results_dissimilar.png -------------------------------------------------------------------------------- /figures/results_similar.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/felix-laumann/Bayesian_CNN_ContinualLearning/1c0357a736ff1cd8d3b449ddbfb8aabffea34862/figures/results_similar.png -------------------------------------------------------------------------------- /main_continual.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pickle 3 | import torch 4 | import torch.cuda 5 | import torchvision.transforms as transforms 6 | import torch.utils.data as data 7 | import torchvision.datasets as dsets 8 | import os 9 | from utils.BBBConvmodel import BBBAlexNet, BBBLeNet, BBB3Conv3FC 10 | from utils.BBBlayers import GaussianVariationalInference 11 | 12 | cuda = torch.cuda.is_available() 13 | 14 | ''' 15 | HYPERPARAMETERS 16 | ''' 17 | is_training = True # set to "False" to evaluate the network's ability to remember previous tasks 18 | pretrained = False # change pretrained to "True" for continual learning 19 | 20 | if pretrained is False: 21 | task = 1 22 | noise = 0 23 | elif pretrained is True: 24 | task = 2 # change to 3, 4, 5, etc. for more tasks 25 | noise = 0.025 # standard deviation of the Gaussian noise added to the inputs 26 | 27 | num_samples = 10 # each batch is repeated this many times (Casper's trick) 28 | batch_size = 32 29 | beta_type = "Blundell" 30 | net = BBB3Conv3FC # BBBLeNet, BBB3Conv3FC, or BBBAlexNet 31 | dataset = 'CIFAR-100-classes' # MNIST, CIFAR-10, CIFAR-100, or CIFAR-100-classes 32 | num_epochs = 100 33 | p_logvar_init = 0 34 | q_logvar_init = -10 35 | lr = 0.00001 36 | weight_decay = 0.0005 37 | 38 | 39 | # dimensions of input and output 40 | if dataset == 'MNIST': # train with MNIST 41 | outputs = 10 42 | inputs = 1 43 | elif dataset == 'CIFAR-10': # train with CIFAR-10 44 | outputs = 10 45 | inputs = 3 46 | elif dataset == 'CIFAR-100': # train with CIFAR-100 47 | outputs = 100 48 | inputs = 3 49 | elif dataset == 'CIFAR-100-classes': # train with 3 selected CIFAR-100 classes 50 | outputs = 3 51 | inputs = 3 52 | 53 | 54 | if net is BBBLeNet or net is BBB3Conv3FC: 55 | resize = 32 56 | elif net is BBBAlexNet: 57 | resize = 227 58 | 59 | ''' 60 | LOADING DATASET 61 | ''' 62 | 63 | if dataset == 'MNIST': 64 | transform = transforms.Compose([transforms.Resize((resize, resize)), transforms.ToTensor(), 65 | transforms.Lambda(lambda x: x + noise * torch.randn(x.size())), 66 | transforms.Normalize((0.1307,), (0.3081,))]) 67 | train_dataset = dsets.MNIST(root="data", download=True, transform=transform) 68 | val_dataset = dsets.MNIST(root="data", download=True, train=False, transform=transform) 69 | 70 | elif dataset == 'CIFAR-100': 71 | transform = transforms.Compose([transforms.Resize((resize, resize)), transforms.ToTensor(), 72 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) 73 | train_dataset = dsets.CIFAR100(root="data", download=True, transform=transform) 74 | val_dataset = dsets.CIFAR100(root="data", download=True, train=False, transform=transform) 75 | 76 | elif dataset == 'CIFAR-10': 77 | transform = transforms.Compose([transforms.Resize((resize, resize)), transforms.ToTensor(), 78 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) 79 | train_dataset = dsets.CIFAR10(root="data", download=True, transform=transform) 80 | val_dataset = dsets.CIFAR10(root="data", download=True, train=False, transform=transform) 81 | 82 | elif dataset == 'CIFAR-100-classes': 83 | transform = transforms.Compose([transforms.Resize((resize, resize)), transforms.ToTensor(), 84 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224,
0.225))]) 85 | train_dataset = dsets.ImageFolder(root="./similar/leopards/train/", transform=transform) 86 | val_dataset = dsets.ImageFolder(root="./similar/leopards/val/", transform=transform) 87 | 88 | ''' 89 | MAKING DATASET ITERABLE 90 | ''' 91 | 92 | print('length of training dataset:', len(train_dataset)) 93 | n_iterations = num_epochs * (len(train_dataset) / batch_size) 94 | n_iterations = int(n_iterations) 95 | print('Number of iterations: ', n_iterations) 96 | 97 | loader_train = data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True) 98 | 99 | loader_val = data.DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False) 100 | 101 | 102 | # enable loading of weights to transfer learning 103 | def cnnmodel(pretrained, task): 104 | model = net(outputs=outputs, inputs=inputs) 105 | 106 | if pretrained: 107 | if is_training: 108 | # load pretrained posterior distribution of one task as prior of next task 109 | with open("weights_{}.pkl".format(task-1), "rb") as previous: 110 | d = pickle.load(previous) 111 | model.load_prior(d) 112 | else: 113 | # evaluate accuracy of previous task 114 | with open("weights_{}.pkl".format(task), "rb") as previous: 115 | d = pickle.load(previous) 116 | model.load_prior(d) 117 | 118 | return model 119 | 120 | 121 | ''' 122 | INSTANTIATE MODEL 123 | ''' 124 | 125 | model = cnnmodel(pretrained=pretrained, task=task) 126 | 127 | if cuda: 128 | model.cuda() 129 | 130 | ''' 131 | INSTANTIATE VARIATIONAL INFERENCE AND OPTIMISER 132 | ''' 133 | vi = GaussianVariationalInference(torch.nn.CrossEntropyLoss()) 134 | optimiser = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=weight_decay) 135 | 136 | ''' 137 | check parameter matrix shapes 138 | ''' 139 | 140 | # how many parameter matrices do we have? 
141 | print('Number of parameter matrices: ', len(list(model.parameters()))) 142 | 143 | for i in range(len(list(model.parameters()))): 144 | print(list(model.parameters())[i].size()) 145 | 146 | ''' 147 | TRAIN MODEL 148 | ''' 149 | 150 | if is_training: 151 | logfile = os.path.join('diagnostics_{}.txt'.format(task)) 152 | else: 153 | logfile = os.path.join('diagnostics_{}_eval.txt'.format(task)) 154 | 155 | with open(logfile, 'w') as lf: 156 | lf.write('') 157 | 158 | 159 | def run_epoch(loader, epoch, is_training=False): 160 | m = math.ceil(len(loader.dataset) / loader.batch_size) 161 | 162 | accuracies = [] 163 | likelihoods = [] 164 | kls = [] 165 | losses = [] 166 | 167 | for i, (images, labels) in enumerate(loader): 168 | # Repeat samples (Casper's trick) 169 | x = images.view(-1, inputs, resize, resize).repeat(num_samples, 1, 1, 1) 170 | y = labels.repeat(num_samples) 171 | 172 | if cuda: 173 | x = x.cuda() 174 | y = y.cuda() 175 | 176 | if beta_type == "Blundell": 177 | beta = 2 ** (m - (i + 1)) / (2 ** m - 1) 178 | elif beta_type == "Soenderby": 179 | beta = min(epoch / (num_epochs//4), 1) 180 | elif beta_type == "Standard": 181 | beta = 1 / m 182 | else: 183 | beta = 0 184 | 185 | logits, kl = model.probforward(x) 186 | loss = vi(logits, y, kl, beta) 187 | ll = -loss.data.mean() + beta*kl.data.mean() 188 | 189 | if is_training: 190 | optimiser.zero_grad() 191 | loss.backward() 192 | optimiser.step() 193 | 194 | _, predicted = logits.max(1) 195 | accuracy = (predicted.data.cpu() == y.cpu()).float().mean() 196 | 197 | accuracies.append(accuracy) 198 | losses.append(loss.data.mean()) 199 | kls.append(beta*kl.data.mean()) 200 | likelihoods.append(ll) 201 | 202 | diagnostics = {'loss': sum(losses)/len(losses), 203 | 'acc': sum(accuracies)/len(accuracies), 204 | 'kl': sum(kls)/len(kls), 205 | 'likelihood': sum(likelihoods)/len(likelihoods)} 206 | 207 | return diagnostics 208 | 209 | 210 | for epoch in range(num_epochs): 211 | if is_training: 212 | diagnostics_train = run_epoch(loader_train, epoch, is_training=True) 213 | diagnostics_val = run_epoch(loader_val, epoch) 214 | diagnostics_train = dict({"type": "train", "epoch": epoch}, **diagnostics_train) 215 | diagnostics_val = dict({"type": "validation", "epoch": epoch}, **diagnostics_val) 216 | print(diagnostics_train) 217 | print(diagnostics_val) 218 | 219 | with open(logfile, 'a') as lf: 220 | lf.write(str(diagnostics_train)) 221 | lf.write(str(diagnostics_val)) 222 | else: 223 | diagnostics_val = run_epoch(loader_val, epoch) 224 | diagnostics_val = dict({"type": "validation", "epoch": epoch}, **diagnostics_val) 225 | print(diagnostics_val) 226 | 227 | with open(logfile, 'a') as lf: 228 | lf.write(str(diagnostics_val)) 229 | 230 | ''' 231 | SAVE PARAMETERS 232 | ''' 233 | 234 | if is_training: 235 | weightsfile = os.path.join("weights_{}.pkl".format(task)) 236 | with open(weightsfile, "wb") as wf: 237 | pickle.dump(model.state_dict(), wf) 238 | -------------------------------------------------------------------------------- /utils/BBBConvmodel.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .BBBdistributions import Normal 3 | from .BBBlayers import BBBConv2d, BBBLinearFactorial, FlattenLayer 4 | 5 | 6 | class BBBAlexNet(nn.Module): 7 | def __init__(self, outputs, inputs): 8 | # create AlexNet with probabilistic weights 9 | super(BBBAlexNet, self).__init__() 10 | 11 | # FEATURES 12 | self.conv1 = BBBConv2d(inputs, 64, kernel_size=11, stride=4,
padding=2) 13 | self.conv1a = nn.Sequential( 14 | nn.Softplus(), 15 | # nn.BatchNorm2d(64), 16 | nn.MaxPool2d(kernel_size=3, stride=2) 17 | ) 18 | self.conv2 = BBBConv2d(64, 192, kernel_size=5, padding=2) 19 | self.conv2a = nn.Sequential( 20 | nn.Softplus(), 21 | # nn.BatchNorm2d(192), 22 | nn.MaxPool2d(kernel_size=3, stride=2) 23 | ) 24 | self.conv3 = BBBConv2d(192, 384, kernel_size=3, padding=1) 25 | self.conv3a = nn.Sequential( 26 | nn.Softplus(), 27 | # nn.BatchNorm2d(384), 28 | ) 29 | self.conv4 = BBBConv2d(384, 256, kernel_size=3, padding=1) 30 | self.conv4a = nn.Sequential( 31 | nn.Softplus(), 32 | # nn.BatchNorm2d(256), 33 | ) 34 | self.conv5 = BBBConv2d(256, 256, kernel_size=3, padding=1) 35 | self.conv5a = nn.Sequential( 36 | nn.Softplus(), 37 | # nn.BatchNorm2d(256), 38 | nn.MaxPool2d(kernel_size=3, stride=2) 39 | ) 40 | # CLASSIFIER 41 | self.flatten = FlattenLayer(256 * 6 * 6) 42 | self.drop1 = nn.Dropout() 43 | self.fc1 = BBBLinearFactorial(256 * 6 * 6, 4096) 44 | self.relu1 = nn.Softplus() 45 | self.drop2 = nn.Dropout() 46 | self.fc2 = BBBLinearFactorial(4096, 4096) 47 | self.relu2 = nn.Softplus() 48 | self.fc3 = BBBLinearFactorial(4096, outputs) 49 | 50 | layers = [self.conv1, self.conv1a, self.conv2, self.conv2a, self.conv3, self.conv3a, self.conv4, self.conv4a, 51 | self.conv5, self.conv5a, self.flatten, self.drop1, self.fc1, self.relu1, self.drop2, self.fc2, self.relu2, self.fc3] 52 | 53 | layers_cont = [self.conv1, self.conv1a, self.conv2, self.conv2a, self.conv3, self.conv3a, self.conv4, self.conv4a, 54 | self.conv5, self.conv5a] 55 | 56 | self.layers = nn.ModuleList(layers) 57 | self.layers_cont = nn.ModuleList(layers_cont) 58 | 59 | def probforward(self, x): 60 | kl = 0 61 | for layer in self.layers: 62 | if hasattr(layer, 'convprobforward') and callable(layer.convprobforward): 63 | x, _kl, = layer.convprobforward(x) 64 | kl += _kl 65 | 66 | elif hasattr(layer, 'fcprobforward') and callable(layer.fcprobforward): 67 | x, _kl, = layer.fcprobforward(x) 68 | kl += _kl 69 | else: 70 | x = layer(x) 71 | logits = x 72 | print('logits', logits) 73 | return logits, kl 74 | 75 | def load_prior(self, state_dict): 76 | d_q = {k: v for k, v in state_dict.items() if "q" in k} 77 | for i, layer in enumerate(self.layers_cont): 78 | if type(layer) is BBBConv2d: 79 | layer.pw = Normal(mu=d_q["layers.{}.qw_mean".format(i)], 80 | logvar=d_q["layers.{}.qw_logvar".format(i)]) 81 | # layer.pb = Normal(mu=d_q["layers.{}.qb_mean".format(i)], logvar=d_q["layers.{}.qb_logvar".format(i)]) 82 | 83 | elif type(layer) is BBBLinearFactorial: 84 | layer.pw = Normal(mu=(d_q["layers.{}.qw_mean".format(i)]), 85 | logvar=(d_q["layers.{}.qw_logvar".format(i)])) 86 | 87 | layer.pb = Normal(mu=(d_q["layers.{}.qb_mean".format(i)]), 88 | logvar=(d_q["layers.{}.qb_logvar".format(i)])) 89 | 90 | 91 | class BBBLeNet(nn.Module): 92 | def __init__(self, outputs, inputs): 93 | super(BBBLeNet, self).__init__() 94 | self.conv1 = BBBConv2d(inputs, 6, 5, stride=1) 95 | self.relu1 = nn.Softplus() 96 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) 97 | 98 | self.conv2 = BBBConv2d(6, 16, 5, stride=1) 99 | self.relu2 = nn.Softplus() 100 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) 101 | 102 | self.flatten = FlattenLayer(5 * 5 * 16) 103 | self.fc1 = BBBLinearFactorial(5 * 5 * 16, 120) 104 | self.relu3 = nn.Softplus() 105 | 106 | self.fc2 = BBBLinearFactorial(120, 84) 107 | self.relu4 = nn.Softplus() 108 | 109 | self.fc3 = BBBLinearFactorial(84, outputs) 110 | 111 | layers = [self.conv1, self.relu1, self.pool1, 
self.conv2, self.relu2, self.pool2, 112 | self.flatten, self.fc1, self.relu3, self.fc2, self.relu4, self.fc3] 113 | 114 | layers_cont = [self.conv1, self.relu1, self.pool1, self.conv2, self.relu2, self.pool2] 115 | 116 | self.layers = nn.ModuleList(layers) 117 | self.layers_cont = nn.ModuleList(layers_cont) 118 | 119 | def probforward(self, x): 120 | kl = 0 121 | for layer in self.layers: 122 | if hasattr(layer, 'convprobforward') and callable(layer.convprobforward): 123 | x, _kl, = layer.convprobforward(x) 124 | kl += _kl 125 | 126 | elif hasattr(layer, 'fcprobforward') and callable(layer.fcprobforward): 127 | x, _kl, = layer.fcprobforward(x) 128 | kl += _kl 129 | else: 130 | x = layer(x) 131 | logits = x 132 | print('logits', logits) 133 | return logits, kl 134 | 135 | # load priors for continual tasks 136 | def load_prior(self, state_dict): 137 | d_q = {k: v for k, v in state_dict.items() if "q" in k} 138 | for i, layer in enumerate(self.layers_cont): 139 | if type(layer) is BBBConv2d: 140 | layer.pw = Normal(mu=d_q["layers.{}.qw_mean".format(i)], 141 | logvar=d_q["layers.{}.qw_logvar".format(i)]) 142 | 143 | elif type(layer) is BBBLinearFactorial: 144 | layer.pw = Normal(mu=(d_q["layers.{}.qw_mean".format(i)]), 145 | logvar=(d_q["layers.{}.qw_logvar".format(i)])) 146 | 147 | #layer.pb = Normal(mu=(d_q["layers.{}.qb_mean".format(i)]), logvar=(d_q["layers.{}.qb_logvar".format(i)])) 148 | 149 | 150 | class BBB3Conv3FC(nn.Module): 151 | def __init__(self, outputs, inputs): 152 | super(BBB3Conv3FC, self).__init__() 153 | self.conv1 = BBBConv2d(inputs, 32, 5, stride=1, padding=2) 154 | self.soft1 = nn.Softplus() 155 | self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2) 156 | 157 | self.conv2 = BBBConv2d(32, 64, 5, stride=1, padding=2) 158 | self.soft2 = nn.Softplus() 159 | self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2) 160 | 161 | self.conv3 = BBBConv2d(64, 128, 5, stride=1, padding=1) 162 | self.soft3 = nn.Softplus() 163 | self.pool3 = nn.MaxPool2d(kernel_size=3, stride=2) 164 | 165 | self.flatten = FlattenLayer(2 * 2 * 128) 166 | self.fc1 = BBBLinearFactorial(2 * 2 * 128, 1000) 167 | self.soft5 = nn.Softplus() 168 | 169 | self.fc2 = BBBLinearFactorial(1000, 1000) 170 | self.soft6 = nn.Softplus() 171 | 172 | self.fc3 = BBBLinearFactorial(1000, outputs) 173 | 174 | layers = [self.conv1, self.soft1, self.pool1, self.conv2, self.soft2, self.pool2, 175 | self.conv3, self.soft3, self.pool3, self.flatten, self.fc1, self.soft5, 176 | self.fc2, self.soft6, self.fc3] 177 | 178 | layers_cont = [self.conv1, self.soft1, self.pool1, self.conv2, self.soft2, self.pool2, 179 | self.conv3, self.soft3, self.pool3] 180 | 181 | self.layers = nn.ModuleList(layers) 182 | self.layers_cont = nn.ModuleList(layers_cont) 183 | 184 | def probforward(self, x): 185 | kl = 0 186 | for layer in self.layers: 187 | if hasattr(layer, 'convprobforward') and callable(layer.convprobforward): 188 | x, _kl, = layer.convprobforward(x) 189 | kl += _kl 190 | 191 | elif hasattr(layer, 'fcprobforward') and callable(layer.fcprobforward): 192 | x, _kl, = layer.fcprobforward(x) 193 | kl += _kl 194 | else: 195 | x = layer(x) 196 | logits = x 197 | print('logits', logits) 198 | return logits, kl 199 | 200 | # load priors for continual tasks 201 | def load_prior(self, state_dict): 202 | d_q = {k: v for k, v in state_dict.items() if "q" in k} 203 | for i, layer in enumerate(self.layers_cont): 204 | if type(layer) is BBBConv2d: 205 | layer.pw = Normal(mu=d_q["layers.{}.qw_mean".format(i)], 206 | logvar=d_q["layers.{}.qw_logvar".format(i)]) 
207 | 208 | elif type(layer) is BBBLinearFactorial: 209 | layer.pw = Normal(mu=(d_q["layers.{}.qw_mean".format(i)]), 210 | logvar=(d_q["layers.{}.qw_logvar".format(i)])) 211 | 212 | #layer.pb = Normal(mu=(d_q["layers.{}.qb_mean".format(i)]), logvar=(d_q["layers.{}.qb_logvar".format(i)])) 213 | 214 | -------------------------------------------------------------------------------- /utils/BBBdistributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import numpy as np 4 | import torch.nn as nn 5 | 6 | 7 | class Distribution(object): 8 | """ 9 | Base class for torch-based probability distributions. 10 | """ 11 | def pdf(self, x): 12 | raise NotImplementedError 13 | 14 | def logpdf(self, x): 15 | raise NotImplementedError 16 | 17 | def cdf(self, x): 18 | raise NotImplementedError 19 | 20 | def logcdf(self, x): 21 | raise NotImplementedError 22 | 23 | def sample(self): 24 | raise NotImplementedError 25 | 26 | def forward(self, x): 27 | raise NotImplementedError 28 | 29 | 30 | class Normal(Distribution): 31 | # scalar version 32 | def __init__(self, mu, logvar): 33 | self.mu = mu 34 | self.logvar = logvar 35 | self.shape = mu.size() 36 | 37 | super(Normal, self).__init__() 38 | 39 | def logpdf(self, x): 40 | c = - float(0.5 * math.log(2 * math.pi)) 41 | return c - 0.5 * self.logvar - (x - self.mu).pow(2) / (2 * torch.exp(self.logvar)) 42 | 43 | def pdf(self, x): 44 | return torch.exp(self.logpdf(x)) 45 | 46 | def sample(self): 47 | if self.mu.is_cuda: 48 | eps = torch.cuda.FloatTensor(self.shape).normal_() 49 | else: 50 | eps = torch.FloatTensor(self.shape).normal_() 51 | # local reparameterization trick 52 | return self.mu + torch.exp(0.5 * self.logvar) * eps 53 | 54 | def entropy(self): 55 | return 0.5 * math.log(2. * math.pi * math.e) + 0.5 * self.logvar 56 | 57 | 58 | class FixedNormal(Distribution): 59 | # takes mu and logvar as float values and assumes they are shared across all weights 60 | def __init__(self, mu, logvar): 61 | self.mu = mu 62 | self.logvar = logvar 63 | super(FixedNormal, self).__init__() 64 | 65 | def logpdf(self, x): 66 | c = - float(0.5 * math.log(2 * math.pi)) 67 | return c - 0.5 * self.logvar - (x - self.mu).pow(2) / (2 * math.exp(self.logvar)) 68 | 69 | 70 | class Normalout(Distribution): 71 | # scalar version 72 | def __init__(self, mu, si): 73 | self.mu = mu 74 | self.si = si 75 | self.shape = mu.size() 76 | 77 | super(Normalout, self).__init__() 78 | 79 | def logpdf(self, x): 80 | c = - float(0.5 * math.log(2 * math.pi)) 81 | return c - 0.5 * self.si - (x - self.mu).pow(2) / (2 * torch.exp(self.si)) 82 | 83 | def pdf(self, x): 84 | return torch.exp(self.logpdf(x)) 85 | 86 | def sample(self): 87 | if self.mu.is_cuda: 88 | eps = torch.cuda.FloatTensor(self.shape).normal_() 89 | else: 90 | eps = torch.FloatTensor(self.shape).normal_() 91 | # local reparameterization trick 92 | return self.mu + torch.exp(0.5 * self.si) * eps 93 | 94 | def entropy(self): 95 | return 0.5 * math.log(2. 
* math.pi * math.e) + 0.5 * self.si 96 | 97 | 98 | class FixedMixtureNormal(nn.Module): 99 | # scale mixture Gaussian prior (with scale mixture factor pi) 100 | def __init__(self, mu, logvar, pi): 101 | super(FixedMixtureNormal, self).__init__() 102 | # Ensure convex combination 103 | assert abs(sum(pi) - 1) < 0.0001 104 | self.mu = nn.Parameter(torch.from_numpy(np.array(mu)).float(), requires_grad=False) 105 | self.logvar = nn.Parameter(torch.from_numpy(np.array(logvar)).float(), requires_grad=False) 106 | self.pi = nn.Parameter(torch.from_numpy(np.array(pi)).float(), requires_grad=False) 107 | 108 | def _component_logpdf(self, x): 109 | ndim = len(x.size()) 110 | shape_expand = ndim * (None,) 111 | x = x.unsqueeze(-1) 112 | 113 | c = - float(0.5 * math.log(2 * math.pi)) 114 | mu = self.mu[shape_expand] 115 | logvar = self.logvar[shape_expand] 116 | pi = self.pi[shape_expand] 117 | 118 | return c - 0.5 * logvar - (x - mu).pow(2) / (2 * torch.exp(logvar)) 119 | 120 | def logpdf(self, x): 121 | ndim = len(x.size()) 122 | shape_expand = ndim * (None,) 123 | pi = self.pi[shape_expand] 124 | px = torch.exp(self._component_logpdf(x)) # ... x num_components 125 | return torch.log(torch.sum(pi * px, -1)) 126 | 127 | 128 | def distribution_selector(mu, logvar, pi): 129 | if isinstance(logvar, (list, tuple)) and isinstance(pi, (list, tuple)): 130 | assert len(logvar) == len(pi) 131 | num_components = len(logvar) 132 | if not isinstance(mu, (list, tuple)): 133 | mu = (mu,) * num_components 134 | return FixedMixtureNormal(mu, logvar, pi) 135 | else: 136 | return FixedNormal(mu, logvar) 137 | -------------------------------------------------------------------------------- /utils/BBBlayers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from torch.nn import Parameter 5 | import torch.nn.functional as F 6 | from .BBBdistributions import Normal, Normalout, distribution_selector 7 | from torch.nn.modules.utils import _pair 8 | 9 | cuda = torch.cuda.is_available() 10 | 11 | 12 | class FlattenLayer(nn.Module): 13 | 14 | def __init__(self, num_features): 15 | super(FlattenLayer, self).__init__() 16 | self.num_features = num_features 17 | 18 | def forward(self, x): 19 | return x.view(-1, self.num_features) 20 | 21 | 22 | class _ConvNd(nn.Module): 23 | """ 24 | Describes a Bayesian convolutional layer with 25 | a distribution over each of the weights and biases 26 | in the layer. 27 | """ 28 | 29 | def __init__(self, in_channels, out_channels, kernel_size, stride, 30 | padding, dilation, output_padding, groups, p_logvar_init=-3, p_pi=1.0, q_logvar_init=-5): 31 | super(_ConvNd, self).__init__() 32 | if in_channels % groups != 0: 33 | raise ValueError('in_channels must be divisible by groups') 34 | if out_channels % groups != 0: 35 | raise ValueError('out_channels must be divisible by groups') 36 | 37 | self.in_channels = in_channels 38 | self.out_channels = out_channels 39 | self.kernel_size = kernel_size 40 | self.stride = stride 41 | self.padding = padding 42 | self.dilation = dilation 43 | self.output_padding = output_padding 44 | self.groups = groups 45 | 46 | # initialize log variance of p and q 47 | self.p_logvar_init = p_logvar_init 48 | self.q_logvar_init = q_logvar_init 49 | 50 | # approximate posterior weights...
51 | self.qw_mean = Parameter(torch.Tensor(out_channels, in_channels // groups, *kernel_size)) 52 | self.qw_logvar = Parameter(torch.Tensor(out_channels, in_channels // groups, *kernel_size)) 53 | # self.qb_mean = Parameter(torch.Tensor(out_channels)) 54 | # self.qb_logvar = Parameter(torch.Tensor(out_channels)) 55 | 56 | # ...and output... 57 | self.conv_qw_mean = Parameter(torch.Tensor(out_channels, in_channels // groups, *kernel_size)) 58 | self.conv_qw_si = Parameter(torch.Tensor(out_channels, in_channels // groups, *kernel_size)) 59 | 60 | # ...as normal distributions 61 | self.qw = Normal(mu=self.qw_mean, logvar=self.qw_logvar) 62 | # self.qb = Normal(mu=self.qb_mean, logvar=self.qb_logvar) 63 | self.conv_qw = Normalout(mu=self.conv_qw_mean, si=self.conv_qw_si) 64 | 65 | # initialise 66 | self.log_alpha = Parameter(torch.Tensor(1, 1)) 67 | 68 | # prior model 69 | # (does not have any trainable parameters so we use fixed normal or fixed mixture normal distributions) 70 | self.pw = distribution_selector(mu=0.0, logvar=p_logvar_init, pi=p_pi) 71 | #self.pb = distribution_selector(mu=0.0, logvar=p_logvar_init, pi=p_pi) 72 | 73 | # initialize all parameters 74 | self.reset_parameters() 75 | 76 | def reset_parameters(self): 77 | # initialise (learnable) approximate posterior parameters 78 | n = self.in_channels 79 | for k in self.kernel_size: 80 | n *= k 81 | stdv = 1. / math.sqrt(n) 82 | self.qw_mean.data.uniform_(-stdv, stdv) 83 | self.qw_logvar.data.uniform_(-stdv, stdv).add_(self.q_logvar_init) 84 | #self.qb_mean.data.uniform_(-stdv, stdv) 85 | #self.qb_logvar.data.uniform_(-stdv, stdv).add_(self.q_logvar_init) 86 | self.conv_qw_mean.data.uniform_(-stdv, stdv) 87 | self.conv_qw_si.data.uniform_(-stdv, stdv).add_(self.q_logvar_init) 88 | self.log_alpha.data.uniform_(-stdv, stdv) 89 | 90 | def extra_repr(self): 91 | s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' 92 | ', stride={stride}') 93 | if self.padding != (0,) * len(self.padding): 94 | s += ', padding={padding}' 95 | if self.dilation != (1,) * len(self.dilation): 96 | s += ', dilation={dilation}' 97 | if self.output_padding != (0,) * len(self.output_padding): 98 | s += ', output_padding={output_padding}' 99 | if self.groups != 1: 100 | s += ', groups={groups}' 101 | if self.bias is None: 102 | s += ', bias=False' 103 | return s.format(**self.__dict__) 104 | 105 | 106 | class BBBConv2d(_ConvNd): 107 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 108 | padding=0, dilation=1, groups=1): 109 | 110 | kernel_size = _pair(kernel_size) 111 | stride = _pair(stride) 112 | padding = _pair(padding) 113 | dilation = _pair(dilation) 114 | 115 | super(BBBConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, _pair(0), groups) 116 | 117 | def forward(self, input): 118 | raise NotImplementedError() 119 | 120 | def convprobforward(self, input): 121 | """ 122 | Convolutional probabilistic forwarding method. 
123 | :param input: data tensor 124 | :return: output, KL-divergence 125 | """ 126 | 127 | # local reparameterization trick for convolutional layer 128 | conv_qw_mean = F.conv2d(input=input, weight=self.qw_mean, stride=self.stride, padding=self.padding, 129 | dilation=self.dilation, groups=self.groups) 130 | conv_qw_si = torch.sqrt(1e-8 + F.conv2d(input=input.pow(2), weight=torch.exp(self.log_alpha)*self.qw_mean.pow(2), 131 | stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)) 132 | 133 | if cuda: 134 | conv_qw_mean.cuda() 135 | conv_qw_si.cuda() 136 | 137 | # sample from output 138 | if cuda: 139 | output = conv_qw_mean + conv_qw_si * (torch.randn(conv_qw_mean.size())).cuda() 140 | else: 141 | output = conv_qw_mean + conv_qw_si * (torch.randn(conv_qw_mean.size())) 142 | 143 | if cuda: 144 | output.cuda() 145 | 146 | w_sample = self.conv_qw.sample() 147 | 148 | # KL divergence 149 | qw_logpdf = self.conv_qw.logpdf(w_sample) 150 | 151 | kl = torch.sum(qw_logpdf - self.pw.logpdf(w_sample)) 152 | 153 | return output, kl 154 | 155 | 156 | class BBBLinearFactorial(nn.Module): 157 | """ 158 | Describes a Linear fully connected Bayesian layer with 159 | a distribution over each of the weights and biases 160 | in the layer. 161 | """ 162 | def __init__(self, in_features, out_features, p_logvar_init=-3, p_pi=1.0, q_logvar_init=-5): 163 | # p_logvar_init, p_pi can be either 164 | # (list/tuple): the prior model is a mixture of Gaussians with components=len(p_pi)=len(p_logvar_init) 165 | # float: a single Gaussian distribution 166 | # q_logvar_init: float, the approximate posterior is currently always a factorized Gaussian 167 | super(BBBLinearFactorial, self).__init__() 168 | 169 | self.in_features = in_features 170 | self.out_features = out_features 171 | self.p_logvar_init = p_logvar_init 172 | self.q_logvar_init = q_logvar_init 173 | 174 | # Approximate posterior weights... 175 | self.qw_mean = Parameter(torch.Tensor(out_features, in_features)) 176 | self.qw_logvar = Parameter(torch.Tensor(out_features, in_features)) 177 | #self.qb_mean = Parameter(torch.Tensor(out_features)) 178 | #self.qb_logvar = Parameter(torch.Tensor(out_features)) 179 | 180 | # ...and output... 181 | self.fc_qw_mean = Parameter(torch.Tensor(out_features, in_features)) 182 | self.fc_qw_si = Parameter(torch.Tensor(out_features, in_features)) 183 | 184 | # ...as normal distributions 185 | self.qw = Normal(mu=self.qw_mean, logvar=self.qw_logvar) 186 | #self.qb = Normal(mu=self.qb_mean, logvar=self.qb_logvar) 187 | self.fc_qw = Normalout(mu=self.fc_qw_mean, si=self.fc_qw_si) 188 | 189 | # initialise 190 | self.log_alpha = Parameter(torch.Tensor(1, 1)) 191 | 192 | # prior model 193 | self.pw = distribution_selector(mu=0.0, logvar=p_logvar_init, pi=p_pi) 194 | #self.pb = distribution_selector(mu=0.0, logvar=p_logvar_init, pi=p_pi) 195 | 196 | # initialize all parameters 197 | self.reset_parameters() 198 | 199 | def reset_parameters(self): 200 | # initialize (learnable) approximate posterior parameters 201 | stdv = 10.
/ math.sqrt(self.in_features) 202 | self.qw_mean.data.uniform_(-stdv, stdv) 203 | self.qw_logvar.data.uniform_(-stdv, stdv).add_(self.q_logvar_init) 204 | #self.qb_mean.data.uniform_(-stdv, stdv) 205 | #self.qb_logvar.data.uniform_(-stdv, stdv).add_(self.q_logvar_init) 206 | self.fc_qw_mean.data.uniform_(-stdv, stdv) 207 | self.fc_qw_si.data.uniform_(-stdv, stdv).add_(self.q_logvar_init) 208 | self.log_alpha.data.uniform_(-stdv, stdv) 209 | 210 | def forward(self, input): 211 | raise NotImplementedError() 212 | 213 | def fcprobforward(self, input): 214 | """ 215 | Probabilistic forwarding method. 216 | :param input: data tensor 217 | :return: output, kl-divergence 218 | """ 219 | 220 | fc_qw_mean = F.linear(input=input, weight=self.qw_mean) 221 | fc_qw_si = torch.sqrt(1e-8 + F.linear(input=input.pow(2), weight=torch.exp(self.log_alpha)*self.qw_mean.pow(2))) 222 | 223 | if cuda: 224 | fc_qw_mean.cuda() 225 | fc_qw_si.cuda() 226 | 227 | # sample from output 228 | if cuda: 229 | output = fc_qw_mean + fc_qw_si * (torch.randn(fc_qw_mean.size())).cuda() 230 | else: 231 | output = fc_qw_mean + fc_qw_si * (torch.randn(fc_qw_mean.size())) 232 | 233 | if cuda: 234 | output.cuda() 235 | 236 | w_sample = self.fc_qw.sample() 237 | 238 | # KL divergence 239 | qw_logpdf = self.fc_qw.logpdf(w_sample) 240 | 241 | kl = torch.sum(qw_logpdf - self.pw.logpdf(w_sample)) 242 | 243 | return output, kl 244 | 245 | def __repr__(self): 246 | return self.__class__.__name__ + ' (' \ 247 | + str(self.in_features) + ' -> ' \ 248 | + str(self.out_features) + ')' 249 | 250 | 251 | class GaussianVariationalInference(nn.Module): 252 | def __init__(self, loss=nn.CrossEntropyLoss()): 253 | super(GaussianVariationalInference, self).__init__() 254 | self.loss = loss 255 | 256 | def forward(self, logits, y, kl, beta): 257 | logpy = -self.loss(logits, y) 258 | 259 | ll = logpy - beta * kl # ELBO 260 | loss = -ll 261 | 262 | return loss 263 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | --------------------------------------------------------------------------------
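
For orientation, the task-chaining loop described in the README can be summarised in a short, illustrative sketch: the posterior learned on one task is saved as `weights_<task>.pkl` and reloaded as the prior of the next task through `load_prior`. This sketch is not part of the repository; it assumes the `utils` package is importable and that `weights_1.pkl` has already been produced by running `main_continual.py` with `pretrained = False`. The actual training and evaluation loop lives in `main_continual.py`.

```python
# Illustrative sketch only: posterior of task A -> prior of task B
import pickle

from utils.BBBConvmodel import BBB3Conv3FC

# same architecture and dimensions as main_continual.py uses for 'CIFAR-100-classes'
model = BBB3Conv3FC(outputs=3, inputs=3)

# load_prior() copies the saved qw_mean / qw_logvar of the shared convolutional
# layers into those layers' priors, so q(w | theta_A) becomes p(w) for task B
with open("weights_1.pkl", "rb") as f:
    model.load_prior(pickle.load(f))

# ... train on task B exactly as in main_continual.py (pretrained=True, task=2) ...

# afterwards, the new posterior is saved and becomes the prior of task C
with open("weights_2.pkl", "wb") as wf:
    pickle.dump(model.state_dict(), wf)
```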