├── __init__.py
├── src
│   ├── __init__.py
│   ├── train.py
│   └── models.py
├── experiments
│   ├── __init__.py
│   ├── supervised-cifar10.py
│   ├── supervised_spiral.py
│   ├── supervised_uci.py
│   ├── layer_generators.py
│   └── load_data.py
└── readme.md
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/experiments/__init__.py:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
Implementation of the Unbounded Depth Neural Network in PyTorch.

## Quick run
Generates a spiral classification dataset and fits a UDN with fully connected hidden layers.
`python -m experiments.supervised_spiral`

## Model and the Variational Depth
The Unbounded Depth Neural Network is implemented in PyTorch at [`src.models.UnboundedDepthNetwork`](src/models.py).

The abstract class `src.models.VariationalDepth` represents the variational posterior on the depth L. Any implementation
of this class can be given to the `UnboundedDepthNetwork`.
- `TruncatedPoisson` implements the variational distribution introduced in the paper.
- `FixedDepth` is a constant distribution simulating a regular (bounded) neural network.

## Training
Some helpful functions for training and evaluating the UDN are available in [`src/train.py`](src/train.py).

## Experiments
The three main experiments of the paper (cifar10, spiral, uci) can be reproduced using the code in `experiments`.
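## Minimal example
The snippet below is a condensed sketch of what `experiments/supervised_spiral.py` does; every name it uses is defined in this repository, but the hyperparameters are illustrative.

```python
import pandas as pd
import torch
import torch.optim as optim

from experiments.layer_generators import make_generators_fcn
from experiments.load_data import load_data_spiral
from src.models import TruncatedPoisson, UnboundedDepthNetwork
from src.train import train_one_epoch_classification

train_loader, valid_loader, test_loader = load_data_spiral(1, 256, seed=0)

# Generators that build hidden/output layers lazily, one depth at a time.
generator_layers, generator_output = make_generators_fcn(8, 2, 2)

model = UnboundedDepthNetwork(
    len(train_loader.sampler),  # number of training observations, to scale the prior
    generator_layers,
    generator_output,
    TruncatedPoisson(1.0),      # variational posterior q(L) over the depth
    2,                          # input dimension
    2,                          # output dimension
)
model.set_device(torch.device("cpu"))

optimizer = optim.Adam(model.parameters(), lr=0.01)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[])
model.set_optimizer(optimizer)  # new layers are registered with the optimizer as the depth grows

# the training helper logs metrics to this CSV and expects it to exist
pd.DataFrame({"depth": [], "nu_L": [], "test_acc": []}).to_csv("tmp.%s.csv" % model.model_name)

for epoch in range(10):
    train_one_epoch_classification(
        epoch, train_loader, valid_loader, test_loader, model, optimizer, scheduler
    )
    scheduler.step()
```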

## Citation
```
@inproceedings{nazaret2022variational,
  title={Variational Inference for Infinitely Deep Neural Networks},
  author={Nazaret, Achille and Blei, David},
  booktitle={International Conference on Machine Learning},
  year={2022},
}
```
--------------------------------------------------------------------------------
/experiments/supervised-cifar10.py:
--------------------------------------------------------------------------------
import argparse
import time

import pandas as pd
import torch.optim as optim
import torch.utils.data

from experiments.load_data import load_data_cifar
from experiments.layer_generators import generator_layers_cifar, generator_output_cifar
from src.models import UnboundedDepthNetwork, TruncatedPoisson, FixedDepth
from src.train import train_one_epoch_classification

INPUT_SIZE = 3 * 32 * 32  # CIFAR-10 images are 3x32x32
OUTPUT_SIZE = 10
BATCH_SIZE = 256

CUDA = True
DEVICE = torch.device("cuda" if CUDA and torch.cuda.is_available() else "cpu")

print(DEVICE)


from torch.optim.lr_scheduler import _LRScheduler


class ExplicitLR(_LRScheduler):
    """Scheduler that follows an explicit list of per-epoch learning rates,
    then keeps the last learning rate once the list is exhausted."""

    def __init__(self, optimizer, lrs, last_epoch=-1, verbose=False):
        self.lrs = lrs
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if self.last_epoch >= len(self.lrs):
            return [group["lr"] for group in self.optimizer.param_groups]
        return [self.lrs[self.last_epoch] for _ in self.optimizer.param_groups]


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("-e", "--epochs", default=500, type=int, help="number of epochs")
    parser.add_argument("-l", "--layers", default=-1, help="number of layers", type=int)
    parser.add_argument("-s", "--seed", default=0, help="random seed", type=int)
    parser.add_argument("--categories", default="", help="filter categories", type=str)

    args = parser.parse_args()
    SEED = args.seed
    torch.manual_seed(SEED)
    L = args.layers

    device = DEVICE

    # LOAD DATA
    # Subsample categories to change the dataset complexity: [4,1] is (deer/car); [5,3] is (dog/cat)
    filter_labels = None
    if args.categories:
        filter_labels = [int(c) for c in args.categories]

    train_loader, valid_loader, test_loader = load_data_cifar(
        BATCH_SIZE, seed=SEED, filter_labels=filter_labels, validation_size=0
    )
    N_train = len(train_loader.sampler)

    # CREATE MODEL

    torch.manual_seed(SEED)

    if L < 0:
        vpost = TruncatedPoisson(5.0)
    else:
        vpost = FixedDepth(L)

    model = UnboundedDepthNetwork(
        N_train,
        lambda l: generator_layers_cifar(l, True),
        generator_output_cifar,
        vpost,
        INPUT_SIZE,
        OUTPUT_SIZE,
        L_prior_poisson=1,
        theta_prior_scale=1.0,
    )

    model.model_name += ".cifar.v2" + ("-f" + args.categories if args.categories else "") + "-s%d" % SEED

    print(model.n_obs)
    model.set_device(device)

    # TRAINING LOOP
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
    scheduler = ExplicitLR(optimizer, [0.01] * 5 + [0.1] * 195 + [0.01] * 100 + [0.001] * 100)
    model.set_optimizer(optimizer)

    tmp = pd.DataFrame({"depth": [], "nu_L": [], "test_acc": []})
    tmp.to_csv("tmp.%s.csv" % model.model_name)

    for epoch in range(args.epochs):
        start_time = time.time()
        test_accuracy = train_one_epoch_classification(
            epoch,
            train_loader,
            valid_loader,
            test_loader,
            model,
            optimizer,
            scheduler,
            normalize_loss=True,
        )
        scheduler.step()
        print(time.time() - start_time)
--------------------------------------------------------------------------------
/experiments/supervised_spiral.py:
--------------------------------------------------------------------------------
import argparse
import os
import time

import pandas as pd
import torch.optim as optim
import torch.utils.data

from experiments.layer_generators import make_generators_fcn, make_generators_fcn_DUN
from experiments.load_data import load_data_spiral
from src.models import (
    UnboundedDepthNetwork,
    TruncatedPoisson,
    FixedDepth,
    CategoricalDUN,
)
from src.train import train_one_epoch_classification

INPUT_SIZE = 2
OUTPUT_SIZE = 2
BATCH_SIZE = 256
CUDA = False
DEVICE = torch.device("cuda" if CUDA and torch.cuda.is_available() else "cpu")
print(DEVICE)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("-e", "--epochs", default=1000, type=int, help="number of epochs")
    parser.add_argument("--lr", default=0.01, type=float, help="learning rate")
    parser.add_argument("-l", "--layers", default=-1, help="number of layers", type=int)
    parser.add_argument("-s", "--seed", default=0, help="random seed", type=int)
    parser.add_argument("-r", "--spiral", default=1, help="spiral phase", type=int)
    parser.add_argument("-d", "--dun", action="store_true")
    parser.add_argument("-i", "--init", default=1.0, type=float, help="truncated poisson init")
    parser.add_argument("-p", "--prior", default=1.0, type=float, help="poisson prior")
    args = parser.parse_args()

    SEED = args.seed
    spiral_R = args.spiral
    L = args.layers
    LR = args.lr
    DUN = args.dun

    torch.manual_seed(SEED)
    device = DEVICE

    # LOAD DATA
    train_loader, valid_loader, test_loader = load_data_spiral(spiral_R, BATCH_SIZE, SEED)
    N_train = len(train_loader.sampler)

    # CREATE MODEL
    if DUN:
        generator_layers, generator_residual = make_generators_fcn_DUN(8, 2, 2)
    else:
        generator_layers, generator_residual = make_generators_fcn(8, 2, 2)

    # Negative L means UDN (TruncatedPoisson); non-negative L means either standard NN (FixedDepth) or a DUN (CategoricalDUN)
    if L < 0:
        if DUN:
            raise ValueError
        else:
            vpost = TruncatedPoisson(1.0)
    else:
        if DUN:
            vpost = CategoricalDUN(L)
        else:
            vpost = FixedDepth(L)

    model = UnboundedDepthNetwork(
        N_train,
        generator_layers,
        generator_residual,
        vpost,
        INPUT_SIZE,
        OUTPUT_SIZE,
        L_prior_poisson=0.5,
        theta_prior_scale=1.0,
        seed=SEED,
    )

    model.model_name += ".spiral.v2.O-%d.seed-%d.lr%.4f" % (spiral_R, SEED, LR)
    model.set_device(device)

    # reduce the LR for the variational posterior qL
    optimizer = optim.Adam(
        [
            {"params": [p for n, p in model.named_parameters() if n != "variational_posterior_L._nu_L"]},
            {
                "params": [p for n, p in model.named_parameters() if n == "variational_posterior_L._nu_L"],
                "lr": LR / 10,
            },
        ],
        lr=LR,
    )
    # optimizer = optim.Adam(model.parameters(), lr=LR)

    STEP = 100
    # the scheduler is a no-op here; only the CIFAR experiment uses an LR schedule
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[], gamma=0.6)
    model.set_optimizer(optimizer)
    PREFIX = "log-spiral/"
    os.makedirs(PREFIX, exist_ok=True)

    tmp = pd.DataFrame({"depth": [], "nu_L": [], "test_acc": []})
    tmp.to_csv(PREFIX + "tmp.%s.csv" % model.model_name)

    for epoch in range(args.epochs):
        start_time = time.time()
        test_accuracy = train_one_epoch_classification(
            epoch,
            train_loader,
            valid_loader,
            test_loader,
            model,
            optimizer,
            scheduler,
            PREFIX=PREFIX,
        )
        scheduler.step()
        print(time.time() - start_time)

    # torch.save(model.state_dict(), PREFIX + "model.%s.pth" % model.model_name)
--------------------------------------------------------------------------------
/experiments/supervised_uci.py:
--------------------------------------------------------------------------------
import argparse
import os
import time

import pandas as pd
import torch.optim as optim
import torch.utils.data

from experiments.layer_generators import make_generators_fcn, make_generators_fcn_DUN
from experiments.load_data import load_data_uci
from src.models import (
    UnboundedDepthNetwork,
    TruncatedPoisson,
    FixedDepth,
    CategoricalDUN,
)
from src.train import train_one_epoch_regression

OUTPUT_SIZE = 1
BATCH_SIZE = 256
CUDA = False
DEVICE = torch.device("cuda" if CUDA and torch.cuda.is_available() else "cpu")
print(DEVICE)

DATASET_NAMES = [
    "boston",
    "concrete",
    "energy",
    "power",
    "wine",
    "yacht",
    "kin8nm",
    "naval",
    "protein",
]


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("-e", "--epochs", default=1000, type=int, help="number of epochs")
    parser.add_argument("--lr", default=0.01, type=float, help="learning rate")
    parser.add_argument("-l", "--layers", default=-1, help="number of layers", type=int)
    parser.add_argument("-s", "--seed", default=0, help="random seed", type=int)
    parser.add_argument("--dataset", choices=DATASET_NAMES, help="UCI dataset name", type=str)
    parser.add_argument("-d", "--dun", action="store_true")
    parser.add_argument("-i", "--init", default=1.0, type=float, help="truncated poisson init")
    parser.add_argument("-p", "--prior", default=1.0, type=float, help="poisson prior")
    parser.add_argument("--split", default=0, type=int, help="split id")

    args = parser.parse_args()

    SEED = args.seed
    L = args.layers
    LR = args.lr
    DUN = args.dun

    torch.manual_seed(SEED)
    device = DEVICE

    # LOAD DATA
    train_loader, valid_loader, test_loader, INPUT_SIZE = load_data_uci(
        args.dataset, args.split, BATCH_SIZE, SEED
    )
    N_train = len(train_loader.sampler)
    # N_test = len(test_loader.sampler)

    # CREATE MODEL
    if DUN:
        generator_layers, generator_residual = make_generators_fcn_DUN(8, INPUT_SIZE, 1)
    else:
        generator_layers, generator_residual = make_generators_fcn(8, INPUT_SIZE, 1)

    # Negative L means UDN (TruncatedPoisson); non-negative L means either standard NN (FixedDepth) or a DUN (CategoricalDUN)
    if L < 0:
        if DUN:
            raise ValueError
        else:
            vpost = TruncatedPoisson(1.0)
    else:
        if DUN:
            vpost = CategoricalDUN(L)
        else:
            vpost = FixedDepth(L)

    model = UnboundedDepthNetwork(
        N_train,
        generator_layers,
        generator_residual,
        vpost,
        INPUT_SIZE,
        OUTPUT_SIZE,
        L_prior_poisson=0.5,
        theta_prior_scale=1.0,
        seed=SEED,
        mode="regression",
    )

    model.model_name += ".uci.%s-split%d.v1.O.seed-%d.lr%.4f" % (
        args.dataset,
        args.split,
        SEED,
        LR,
    )

    model.set_device(device)

    optimizer = optim.Adam(
        [
            {"params": [p for n, p in model.named_parameters() if n != "variational_posterior_L._nu_L"]},
            {
                "params": [p for n, p in model.named_parameters() if n == "variational_posterior_L._nu_L"],
                "lr": LR / 10,
            },
        ],
        lr=LR,
    )
    # optimizer = optim.Adam(model.parameters(), lr=LR)

    STEP = 100
    # the scheduler is a no-op here; only the CIFAR experiment uses an LR schedule
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[], gamma=0.6)
    model.set_optimizer(optimizer)
    PREFIX = "log-uci/"
    os.makedirs(PREFIX, exist_ok=True)

    tmp = pd.DataFrame({"depth": [], "nu_L": [], "test_acc": []})
    tmp.to_csv(PREFIX + "tmp.%s.csv" % model.model_name)

    for epoch in range(args.epochs):
        start_time = time.time()
        test_accuracy = train_one_epoch_regression(
            epoch,
            train_loader,
            valid_loader,
            test_loader,
            model,
            optimizer,
            scheduler,
            PREFIX=PREFIX,
        )
        scheduler.step()
        print(time.time() - start_time)

    # torch.save(model.state_dict(), PREFIX + "model.%s.pth" % model.model_name)
--------------------------------------------------------------------------------
/experiments/layer_generators.py:
--------------------------------------------------------------------------------
from torch import nn
import torch.nn.functional as F


def generator_layers_cifar(layer_id: int, is_residual: bool = True):
    """Generator for the hidden layers of the UDN for CIFAR-10.

    At every call, rebuilds the network layer by layer until reaching the target `layer_id`,
    so that the input and output dimensions of each layer are consistent.

    Parameters
    ----------
    layer_id: int
        Hidden layer to generate.
    is_residual: bool, default True
        Whether to use skip connections (residual network).

    Returns
    -------
    (nn.Module, int, int):
        Hidden layer of depth `layer_id` in the CNN for CIFAR classification,
        followed by its number of output channels and its output spatial dimension.

    """
    if layer_id == -1:
        return None, 3, 32

    FIRST_LAYER_CHANNEL = 64
    out_dim = 32
    if layer_id == 0:
        bloc = [
            nn.Conv2d(3, FIRST_LAYER_CHANNEL, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(FIRST_LAYER_CHANNEL),
            nn.ReLU(),
        ]
        return nn.Sequential(*bloc), FIRST_LAYER_CHANNEL, out_dim

    in_channel = FIRST_LAYER_CHANNEL
    total = 0
    for block_id in range(1, 4):
        n_of_blocks = [0, 3, 5, 1000][block_id]
        channels_of_block = 2 ** (block_id + 5)

        for i in range(n_of_blocks):
            total += 1
            if i == 0:
                stride = 2
            else:
                stride = 1

            out_channel = channels_of_block * BottleneckCNN.expansion
            out_dim //= stride

            if total == layer_id:
                bloc = BottleneckCNN(in_channel, channels_of_block, stride, is_residual)
                return bloc, out_channel, out_dim

            in_channel = out_channel


def generator_output_cifar(layer_id: int, generator_hidden_layers):
    """Generator for the output layers of the UDN for CIFAR-10.

    At every call, generates the hidden layer to obtain the correct input dimension.

    Parameters
    ----------
    layer_id: int
        Hidden layer whose output this layer maps to the class logits.
    generator_hidden_layers: callable
        Generator of the hidden layers.

    Returns
    -------
    nn.Module:
        Output layer of depth `layer_id` in the CNN for CIFAR classification.

    """
    _, last_channels, last_dim = generator_hidden_layers(layer_id)
    last_hidden_size = (last_dim // 4) ** 2 * last_channels
    layers = [
        nn.AvgPool2d(4),
        nn.Flatten(),
        nn.Linear(last_hidden_size, 10),
    ]

    return nn.Sequential(*layers)


class BottleneckCNN(nn.Module):
    """Bottleneck construction block for the CNN (ResNet-style)."""

    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1, residual=True):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, self.expansion * out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * out_channels)

        self.is_residual = residual
        self.shortcut = nn.Sequential()
        if (stride != 1 or in_channels != self.expansion * out_channels) and self.is_residual:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    self.expansion * out_channels,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(self.expansion * out_channels),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        if self.is_residual:
            out += self.shortcut(x)
        out = F.relu(out)
        return out


def make_generators_fcn(hidden_size, input_size, output_size):
    """Returns generators for a simple infinitely deep fully connected neural network of constant hidden dimension."""

    def hidden_layers(layer_id):
        if layer_id == -1:
            return None, input_size

        if layer_id == 0:
            return (
                nn.Sequential(nn.Linear(input_size, hidden_size), nn.ReLU()),
                hidden_size,
            )

        return (
            nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.ReLU()),
            hidden_size,
        )

    def output_layers(layer_id, hidden_layers):
        _, s = hidden_layers(layer_id)
        return nn.Linear(s, output_size)

    return hidden_layers, output_layers


def make_generators_fcn_DUN(hidden_size, input_size, output_size):
    """Returns generators for a simple infinitely deep fully connected neural network of constant hidden dimension.
    DUN is a bit delicate because the output layers are shared: the same layer needs to be returned by multiple calls
    to output_layers.
    """

    def main_layers(layer_id):
        if layer_id == -1:
            return None, input_size

        if layer_id == 0:
            return (
                nn.Sequential(nn.Linear(input_size, hidden_size), nn.ReLU()),
                hidden_size,
            )

        return (
            nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.ReLU()),
            hidden_size,
        )

    o_layer = nn.Linear(input_size, output_size)
    final_layer = nn.Linear(hidden_size, output_size)

    def output_layers(layer_id, hidden_layers):
        if layer_id >= 0:
            return final_layer
        else:
            return o_layer

    return main_layers, output_layers
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
import time

import numpy as np
import pandas as pd
import torch
import tqdm

from src.models import UnboundedDepthNetwork


def train_one_epoch_classification(
    epoch,
    train_loader,
    valid_loader,
    test_loader,
    model: UnboundedDepthNetwork,
    optimizer,
    scheduler,
    PREFIX="./",
    normalize_loss=False,
):
    """Train the model for one epoch and log a lot of metrics."""
    train_loss_epoch, train_one_epoch_time = train(model, train_loader, optimizer, normalize_loss)
    (
        validation_accuracy,
        validation_predictive_loss,
        validation_accuracy_counts_per_layer,
        validation_predictions,
        validation_brier_score,
        validation_true_labels,
    ) = evaluate_classification(model, valid_loader)

    (
        test_accuracy,
        test_predictive_loss,
        test_accuracy_counts_per_layer,
        test_predictions,
        test_brier_score,
        test_true_labels,
    ) = evaluate_classification(model, test_loader)

    # The isinstance check is probably not needed: model.variational_posterior_L.compute_depth() and .mean() should always work.
    if isinstance(model, UnboundedDepthNetwork):
        depth_max = model.current_depth
        depth_mean = model.variational_posterior_L.mean()
    else:
        depth_max = model.n_layers
        depth_mean = model.n_layers

    log_string = "Epoch: {}, Train Loss: {:.8f}, Test Accuracy: {:.8f}, Mean Post L: {:.2f}"
    print(log_string.format(epoch + 1, train_loss_epoch, test_accuracy, depth_mean))

    tmp = pd.read_csv(PREFIX + "tmp.%s.csv" % model.model_name, index_col=0)
    df_args = {
        "depth": [depth_max],
        "nu_L": depth_mean,
        "test_accuracy": test_accuracy,
        "validation_accuracy": validation_accuracy,
        "test_predictive_LL": test_predictive_loss,
        "validation_predictive_LL": validation_predictive_loss,
        "lr": scheduler.get_last_lr()[0],
        "test_brier": test_brier_score,
        "validation_brier": validation_brier_score,
        "train_time_one_epoch": train_one_epoch_time,
        "size_train": len(train_loader.sampler),
        "size_validation": len(valid_loader.sampler),
        "size_test": len(test_loader.sampler),
        "train_loss": train_loss_epoch,
    }

    for i, acc in enumerate(test_accuracy_counts_per_layer):
        df_args["test_accuracy_layer_%d" % (i)] = acc

    # append the epoch metrics as a new row (DataFrame.append is gone in recent pandas)
    tmp = pd.concat([tmp, pd.DataFrame(df_args)])
    tmp.to_csv(PREFIX + "tmp.%s.csv" % model.model_name)

    return test_accuracy


def train(model, train_loader, optimizer, normalize_loss=False):
    """
    Train the model for one epoch
    """
    train_loss_epoch = 0
    iterations = 0

    start_time = time.time()
    model.train()
    for features, target in tqdm.tqdm(train_loader):
        optimizer.zero_grad()
        features = features.to(model.device)
        target = target.to(model.device)
        loss = model.loss(features, target)
        if normalize_loss:
            loss = loss / model.n_obs
        train_loss_epoch += loss.item()

        loss.backward()
        optimizer.step()
        iterations += 1
    train_loss_epoch = train_loss_epoch / iterations
    train_one_epoch_time = time.time() - start_time

    return train_loss_epoch, train_one_epoch_time


def evaluate_classification(model, evaluation_loader):
    """Evaluate the model for classification."""
    accuracy_counts = 0
    accuracy_counts_per_layer = 0
    predictions = []
    true_labels = []
    predictive_loss = torch.tensor(0.0)
    brier_score = torch.tensor(0.0)

    model.eval()

    for features, labels in evaluation_loader:
        features = features.to(model.device)
        labels = labels.cpu()
        forward_pass = model(features)
        pred = forward_pass["predictions_global"].detach().cpu()

        accuracy_counts_per_layer_batch = np.array(
            [
                (torch.max(p.detach().cpu(), dim=1).indices == labels).sum().item()
                for p in forward_pass["predictions_per_layer"]
            ]
        )
        accuracy_counts_per_layer += accuracy_counts_per_layer_batch
        predictions.append(pred)
        true_labels.append(labels)
        predictive_loss += torch.gather(pred, 1, labels.view(-1, 1)).log().sum()
        brier_score += (pred.pow(2).sum() + (1 - 2 * torch.gather(pred, 1, labels.view(-1, 1))).sum()).item()
        accuracy_counts += (torch.max(pred, dim=1).indices == labels).sum().item()

    if len(evaluation_loader.sampler):
        accuracy = accuracy_counts / len(evaluation_loader.sampler)
        accuracy_counts_per_layer = accuracy_counts_per_layer / len(evaluation_loader.sampler)
        predictions = torch.cat(predictions, dim=0).numpy()
        true_labels = torch.cat(true_labels, dim=0).numpy()
    else:
        accuracy = 0
        accuracy_counts_per_layer = 0

    predictive_loss = predictive_loss.item()
    brier_score = brier_score.item()

    return accuracy, predictive_loss, accuracy_counts_per_layer, predictions, brier_score, true_labels


def train_one_epoch_regression(
    epoch,
    train_loader,
    valid_loader,
    test_loader,
    model: UnboundedDepthNetwork,
    optimizer,
    scheduler,
    PREFIX="./",
    normalize_loss=False,
):
    """Train the model for one epoch and log the regression metrics."""
    train_loss_epoch, train_one_epoch_time = train(model, train_loader, optimizer, normalize_loss)
    validation_predictive_loss = evaluate_regression(model, valid_loader)
    test_predictive_loss = evaluate_regression(model, test_loader)

    # The isinstance check is probably not needed: model.variational_posterior_L.compute_depth() and .mean() should always work.
    if isinstance(model, UnboundedDepthNetwork):
        depth_max = model.current_depth
        depth_mean = model.variational_posterior_L.mean()
    else:
        depth_max = model.n_layers
        depth_mean = model.n_layers
    print(
        "Epoch: {}, Train Loss: {:.8f}, Val RMSE: {:.8f}, Mean Post L: {:.2f}".format(
            epoch + 1, train_loss_epoch, validation_predictive_loss, depth_mean
        ),
    )

    tmp = pd.read_csv(PREFIX + "tmp.%s.csv" % model.model_name, index_col=0)
    df_args = {
        "depth": [depth_max],
        "nu_L": depth_mean,
        "test_rmse": test_predictive_loss,
        "validation_rmse": validation_predictive_loss,
        "lr": scheduler.get_last_lr()[0],
        "train_time_one_epoch": train_one_epoch_time,
        "size_train": len(train_loader.sampler),
        "size_validation": len(valid_loader.sampler),
        "size_test": len(test_loader.sampler),
        "train_loss": train_loss_epoch,
    }

    # append the epoch metrics as a new row (DataFrame.append is gone in recent pandas)
    tmp = pd.concat([tmp, pd.DataFrame(df_args)])
    tmp.to_csv(PREFIX + "tmp.%s.csv" % model.model_name)

    return test_predictive_loss


def evaluate_regression(model, evaluation_loader):
    """Evaluate the model for regression; returns the RMSE over the loader."""
    predictive_loss = torch.tensor(0.0)
    model.eval()
    for features, labels in evaluation_loader:
        features = features.to(model.device)
        labels = labels.cpu()
        pred = model(features)["predictions_global"].detach().cpu()
        predictive_loss += ((pred - labels) ** 2).sum()

    if len(evaluation_loader.sampler):
        predictive_loss = (predictive_loss.item() / len(evaluation_loader.sampler)) ** 0.5

    return predictive_loss
--------------------------------------------------------------------------------
/experiments/load_data.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import torch
import torchvision
from torch.utils.data import SubsetRandomSampler
from torchvision import transforms


def generate_spiral(phase: float, n: int = 1000, seed: int = 0):
    """Generate a spiral dataset.

    Parameters
    ----------
    phase: float
        Phase of the spiral.
    n: int
        Number of samples.
    seed: int
        Random seed.

    Returns
    -------
    labels: list of int
        Labels of each generated point, corresponding to which arm of the spiral the point is from.
    xy: array-like
        Coordinates of the generated points.

    """
    omega = lambda x: phase * np.pi / 2 * np.abs(x)
    rng = np.random.default_rng(seed)
    ts = rng.uniform(-1, 1, n)
    ts = np.sign(ts) * np.sqrt(np.abs(ts))
    xy = np.array([ts * np.cos(omega(ts)), ts * np.sin(omega(ts))]).T
    xy = rng.normal(xy, 0.02)
    labels = (ts >= 0).astype(int)
    return labels, xy


def load_data_spiral(phase: float, batch_size: int, seed: int = 0):
    """Build the dataloaders for the spiral dataset.
    The spiral datasets are generated and then wrapped in a dataloader.

    Parameters
    ----------
    phase: float
        Phase of the spiral.
    batch_size: int
        Batch size of the dataloaders.
    seed: int
        Random seed to use to generate the spiral data.

    Returns
    -------
    train_loader: DataLoader
        DataLoader for the training points of the spiral dataset.
    valid_loader: DataLoader
        DataLoader for the validation points of the spiral dataset.
    test_loader: DataLoader
        DataLoader for the testing points of the spiral dataset.

    """
    t, xy = generate_spiral(phase, n=1024, seed=0 + seed)
    t_val, xy_val = generate_spiral(phase, n=1024, seed=1 + seed)
    t_test, xy_test = generate_spiral(phase, n=1024, seed=2 + seed)

    torch.manual_seed(seed)

    train_dataset = torch.utils.data.TensorDataset(torch.FloatTensor(xy), torch.LongTensor(t))
    val_dataset = torch.utils.data.TensorDataset(torch.FloatTensor(xy_val), torch.LongTensor(t_val))
    test_dataset = torch.utils.data.TensorDataset(torch.FloatTensor(xy_test), torch.LongTensor(t_test))

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valid_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, valid_loader, test_loader


def load_data_cifar(batch_size: int, seed: int = 0, validation_size: float = 0.2, filter_labels=None):
    """
    Load the CIFAR-10 dataset and wrap it in torch DataLoaders: train, validation, test.
    Has 10 classes: ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    Parameters
    ----------
    batch_size: int
        Batch size for the data loaders.
    seed: int
        Seed for splitting the dataset into train and validation.
    validation_size: float between 0.0 and 1.0, default 0.2
        Proportion of the train set used for the validation set.
    filter_labels: list of int, optional
        If given, keep only the examples whose label is in this list.

    Returns
    -------
    train_loader, valid_loader, test_loader

    """
    transform_train = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    )

    transform_test = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    )

    train_set = torchvision.datasets.CIFAR10(
        root="./data", train=True, download=True, transform=transform_train
    )

    test_set = torchvision.datasets.CIFAR10(
        root="./data", train=False, download=True, transform=transform_test
    )

    torch.manual_seed(seed)
    np.random.seed(seed)
    if filter_labels is not None:
        train_val_indices = [i for i, v in enumerate(train_set.targets) if v in filter_labels]
    else:
        train_val_indices = np.arange(len(train_set))
    np.random.shuffle(train_val_indices)
    split = int(np.floor(validation_size * len(train_val_indices)))
    train_idx, valid_idx = train_val_indices[split:], train_val_indices[:split]

    if filter_labels is not None:
        test_idx = [i for i, v in enumerate(test_set.targets) if v in filter_labels]
    else:
        test_idx = np.arange(len(test_set))

    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)
    test_sampler = SubsetRandomSampler(test_idx)

    if os.environ.get("HOME") == "/Users/achille":
        # should be my laptop
        print("Local laptop")
        num_workers = 0
    elif os.environ.get("HOME") == "/home/achille":
        # Should be the cluster
        num_workers = 2
    else:
        # any other machine
        num_workers = 2
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, sampler=train_sampler, num_workers=num_workers
    )
    valid_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, sampler=valid_sampler, num_workers=num_workers
    )
    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        sampler=test_sampler,
    )

    return train_loader, valid_loader, test_loader


def load_data_uci(dataset_name, n_split, batchsize, seed=0):
    base_dir = "data"
    dir_load = base_dir + "/UCI_for_sharing/standard/" + dataset_name + "/data/"

    data = np.loadtxt(dir_load + "data.txt")
    feature_idx = np.loadtxt(dir_load + "index_features.txt").astype(int)
    target_idx = np.loadtxt(dir_load + "index_target.txt").astype(int)

    np.random.seed(n_split)
    indices = np.array(list(range(data.shape[0])))
    np.random.shuffle(indices)
    train_end, val_end = int(len(indices) * 0.8), int(len(indices) * 0.9)
    train_idx = indices[:train_end]
    validation_idx = indices[train_end:val_end]
    test_idx = indices[val_end:]
    np.random.seed(0)

    data_train = data[train_idx]
    data_validation = data[validation_idx]
    data_test = data[test_idx]

    X_train = data_train[:, feature_idx].astype(np.float32)
    X_test = data_test[:, feature_idx].astype(np.float32)
    X_validation = data_validation[:, feature_idx].astype(np.float32)
    y_train = data_train[:, target_idx].astype(np.float32)
    y_test = data_test[:, target_idx].astype(np.float32)
    y_validation = data_validation[:, target_idx].astype(np.float32)

    x_means, x_stds = X_train.mean(axis=0), X_train.std(axis=0)
    y_means, y_stds = y_train.mean(axis=0), y_train.std(axis=0)

    x_stds[x_stds < 1e-10] = 1.0

    X_train = (X_train - x_means) / x_stds
    y_train = ((y_train - y_means) / y_stds)[:, np.newaxis]
    X_test = (X_test - x_means) / x_stds
    y_test = ((y_test - y_means) / y_stds)[:, np.newaxis]
    X_validation = (X_validation - x_means) / x_stds
    y_validation = ((y_validation - y_means) / y_stds)[:, np.newaxis]

    torch.manual_seed(seed)
    print(y_stds)

    train_dataset = torch.utils.data.TensorDataset(torch.FloatTensor(X_train), torch.FloatTensor(y_train))
    val_dataset = torch.utils.data.TensorDataset(
        torch.FloatTensor(X_validation), torch.FloatTensor(y_validation)
    )
    test_dataset = torch.utils.data.TensorDataset(torch.FloatTensor(X_test), torch.FloatTensor(y_test))
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batchsize, shuffle=True)
    valid_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batchsize, shuffle=False)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batchsize, shuffle=False)

    return train_loader, valid_loader, test_loader, x_means.shape[0]
--------------------------------------------------------------------------------
/src/models.py:
--------------------------------------------------------------------------------
import abc

import numpy as np
import scipy.stats as st
import torch
import torch.distributions as dist
import torch.utils.data
import torch.nn as nn


def softplus_inverse(x):
    """Inverse of the softplus: log(exp(x) - 1)."""
    return torch.where(x > 10, x, x.expm1().log())


class VariationalDepth(nn.Module, abc.ABC):
    """
    Abstract class for a variational posterior approximation q(L) over the depth L of the UDN.

    Methods
    -------
    compute_depth()
        Returns the largest value of L with non-zero mass: max(i | q(L=i) > 0).
    probability_vector()
        Returns the vector of probabilities of q over integers up to its depth: [q(L=i) for i=0 to depth].
    mean()
        Returns the expectation of q: E[L] for L ~ q(L).
    """

    def __init__(self):
        super().__init__()

    @abc.abstractmethod
    def compute_depth(self):
        pass

    @abc.abstractmethod
    def probability_vector(self):
        pass

    @abc.abstractmethod
    def mean(self):
        pass


class FixedDepth(VariationalDepth):
    """
    Variational posterior approximation q(L) which is a constant mass at a given depth `n_layers`.
    Used to emulate a standard finite neural network with `n_layers` layers.

    Parameters
    ----------
    n_layers: int
        Depth such that q(L=n_layers) = 1.

    """

    def __init__(self, n_layers: int):
        super().__init__()
        self.n_layers = n_layers

    def compute_depth(self):
        return self.n_layers

    def mean(self):
        return self.n_layers

    def probability_vector(self):
        v = torch.zeros(self.n_layers + 1, requires_grad=False)
        v[-1] = 1.0
        return v


class TruncatedPoisson(VariationalDepth):
    """
    Variational posterior approximation q(L) which is a truncated Poisson.
    Used to adapt the depth during training.

    Parameters
    ----------
    initial_nu_L: float, default 2.0
        Initial value of the variational parameter nu_L. nu_L is almost equal to the
        mean of the TruncatedPoisson, so the default of 2.0 starts with two layers.

    truncation_quantile: float (between 0.0 and 1.0), default 0.95
        Truncation level of the truncated Poisson; recommended to leave at 0.95.

    """

    def __init__(self, initial_nu_L: float = 2.0, truncation_quantile: float = 0.95):
        super().__init__()
        self.truncation_quantile = truncation_quantile
        self._nu_L = nn.Parameter(softplus_inverse(torch.tensor(float(initial_nu_L))))

    @property
    def nu_L(self):
        """Returns the variational parameter nu_L, which is reparametrized to be positive."""
        return nn.Softplus()(self._nu_L)

    def compute_depth(self):
        p = st.poisson(self.nu_L.item())
        for a in range(int(self.nu_L.item()) + 1, 10000):
            if p.cdf(a) >= self.truncation_quantile:
                return a + 1
        raise Exception("could not reach the truncation quantile within 10000 iterations")

    def probability_vector(self):
        depth = self.compute_depth()
        ks = torch.arange(0, depth, dtype=self._nu_L.dtype, device=self._nu_L.device)
        alpha_L = (ks * self.nu_L.log() - torch.lgamma(ks + 1)).exp()
        alpha_L = torch.cat([torch.zeros(1, device=ks.device, dtype=ks.dtype), alpha_L])
        return alpha_L / alpha_L.sum()

    def mean(self):
        proba = self.probability_vector().cpu()
        return (proba * torch.arange(len(proba))).sum().item()


class CategoricalDUN(VariationalDepth):
    """
    Variational posterior approximation q(L) which is a categorical (i.e., non-parametric) distribution of fixed
    depth. It is equivalent to using a Depth Uncertainty Network [1].

    Parameters
    ----------
    max_depth: int
        Depth of the categorical q(L).

    References
    ----------
    [1] Antorán, Allingham, and Hernández-Lobato. "Depth Uncertainty in Neural Networks." NeurIPS 2020.

    """

    def __init__(self, max_depth):
        super().__init__()
        self.max_depth = max_depth
        self.logits = nn.Parameter(torch.ones(self.max_depth))

    def compute_depth(self):
        return self.max_depth

    def probability_vector(self):
        probs = nn.Softmax(dim=0)(self.logits)
        return torch.cat([torch.zeros(1, device=probs.device, dtype=probs.dtype), probs])

    def mean(self):
        proba = self.probability_vector().cpu()
        return (proba * torch.arange(len(proba))).sum().item()


# Shift a Poisson distribution by 1 so it only takes (strictly) positive values.
PositivePoisson = lambda p: torch.distributions.TransformedDistribution(
    torch.distributions.Poisson(p, validate_args=False),
    torch.distributions.AffineTransform(1, 1),
)


class UnboundedDepthNetwork(nn.Module):
    """
    The Unbounded Depth Neural Network: a network of unbounded depth, trained jointly with a variational
    posterior approximation q(L) over its depth L.

    Parameters
    ----------
    n_obs: int
        Number of observations that will be used for training. It is needed to scale the prior when doing
        stochastic variational optimization.
    hidden_layer_generator: callable
        Function that takes an integer L and returns a torch.nn.Module representing hidden layer L.
    output_layer_generator: callable
        Function that takes an integer L (and the hidden layer generator) and returns a torch.nn.Module
        representing output layer L.
    L_variational_distribution: VariationalDepth
        The variational distribution q(L).
    in_dimension: int
        Input dimension of the neural network.
    out_dimension: int
        Output dimension of the neural network.
    mode: str {"classification", "regression"}
        Specify if the neural network is for regression or for classification. It impacts the forward pass.
    L_prior_poisson: float
        Mean of the Poisson prior.
    theta_prior_scale: float
        Standard deviation (scale) of the Gaussian prior for the neural network weights.
    seed: int
        Random seed for the initialization of the neural network layers.

    Methods
    -------
    set_optimizer()
        Set the optimizer, to later add the dynamically created layers' parameters to it.
    set_device()
        Set the device of the model.
    update_depth()
        Compute the current maximal depth of the variational posterior q(L) and create new layers if needed.
    loss()
        Compute the loss (to minimize) of the UDN for the variational inference.
    elbo()
        Compute the ELBO (to maximize) of the UDN for the variational inference.
    forward()
        Compute the forward pass and the per-layer quantities entering the ELBO.

    """

    def __init__(
        self,
        n_obs: int,
        hidden_layer_generator,
        output_layer_generator,
        L_variational_distribution: VariationalDepth,
        in_dimension: int,
        out_dimension: int,
        mode: str = "classification",
        L_prior_poisson=1.0,
        theta_prior_scale=10.0,
        seed=0,
    ):
        super().__init__()
        torch.manual_seed(seed)

        self.n_obs = n_obs
        self.device = None
        self.optimizer = None

        self.hidden_layer_generator = hidden_layer_generator
        self.output_layer_generator = output_layer_generator

        self.variational_posterior_L = L_variational_distribution

        # set priors
        self.theta_prior_scale = theta_prior_scale
        self.prior_theta = dist.Normal(0, self.theta_prior_scale)
        self.prior_L = PositivePoisson(torch.tensor(float(L_prior_poisson)))
        if isinstance(self.variational_posterior_L, FixedDepth):
            # In this case the prior doesn't matter (since q is fixed), yet if q(L) is set to have a depth of 0
            # then the prior cannot be a PositivePoisson, as the posterior would have mass outside the prior
            # support. We set it to a regular Poisson just to avoid computation errors.
            if self.variational_posterior_L.n_layers == 0:
                self.prior_L = torch.distributions.Poisson(torch.tensor(float(L_prior_poisson)))
        elif isinstance(self.variational_posterior_L, CategoricalDUN):
            # For the DUN, we set a uniform prior.
            d = self.variational_posterior_L.max_depth
            self.prior_L = torch.distributions.Categorical(torch.tensor([0.0] + [1 / d] * d))

        self.in_dimension = in_dimension
        self.out_dimension = out_dimension
        self.mode = mode

        # We generate only the first output layer at first (from the input directly to the output)
        self.hidden_layers = nn.ModuleList([])
        self.output_layers = nn.ModuleList([self.output_layer_generator(-1, self.hidden_layer_generator)])

        self.current_depth = None
        # self.update_depth()

        if isinstance(self.variational_posterior_L, TruncatedPoisson):
            self.model_name = "UDN-inf"
        elif isinstance(self.variational_posterior_L, CategoricalDUN):
            self.model_name = "UDN-DUN%d" % self.variational_posterior_L.max_depth
        else:
            self.model_name = "UDN-f%d" % self.variational_posterior_L.n_layers

        # The DUN shares the output layers. We won't add the additional output layers to the optimizer.
        if isinstance(self.variational_posterior_L, CategoricalDUN):
            self._add_output_layer_to_optimizer = False
        else:
            self._add_output_layer_to_optimizer = True

    def set_optimizer(self, optimizer: torch.optim.Optimizer):
        """Set the optimizer, to later add the dynamically created layers' parameters to it."""
        self.optimizer = optimizer

    def set_device(self, device):
        """Set the device of the model."""
        self.to(device)
        self.device = device

    def update_depth(self):
        """
        Compute the current maximal depth of the variational posterior q(L) and create new layers if needed.
        """
        self.current_depth = self.variational_posterior_L.compute_depth()
        while self.current_depth > len(self.hidden_layers):
            layer, *_ = self.hidden_layer_generator(len(self.hidden_layers))
            output_layer = self.output_layer_generator(len(self.hidden_layers), self.hidden_layer_generator)

            layer.to(self.device)
            output_layer.to(self.device)

            self.hidden_layers.append(layer)
            self.output_layers.append(output_layer)

            if self.optimizer is not None:
                self.optimizer.param_groups[0]["params"].extend(self.hidden_layers[-1].parameters())
                if self._add_output_layer_to_optimizer or len(self.output_layers) == 2:
                    self.optimizer.param_groups[0]["params"].extend(self.output_layers[-1].parameters())

    def loss(self, X, y):
        """Compute the loss (to minimize) of the UDN for the variational inference."""
        return -self.elbo(X, y)

    def elbo(self, X, y):
        """Compute the ELBO (to maximize) of the UDN for the variational inference."""
        res = self.forward(X, y)
        return sum(res["losses"]) + res["entropy_qL"]

    def forward(self, X, y=None):
        """
        Compute the neural network output, in a single forward pass.
        Returns a detailed description of the forward pass.
        If y is given, also computes the ELBO terms; otherwise, just computes the predictions.

        Returns
        -------
        predictions_global: array-like of shape (X.shape[0], self.out_dimension)
            Posterior predictive expectation (averaged over the layers according to the posterior q(L)).
        predictions_per_layer: list of array-like of shape (X.shape[0], self.out_dimension)
            Posterior predictive expectation of each layer.
        losses: list of scalar torch.Tensor
            Loss for each layer: all the ELBO terms related to that layer, except the entropy of q(L).
        entropy_qL: scalar torch.Tensor
            Entropy of q(L).
        logp_per_layer:
            Detailed access to layer-specific ELBO terms, here the reconstruction.
        logp_L_per_layer:
            Detailed access to layer-specific ELBO terms, here the prior regularization for L.
        logp_theta_per_layer:
            Detailed access to layer-specific ELBO terms, here the prior regularization for the weights theta.
        logp_y_per_layer:
            Detailed access to layer-specific ELBO terms, here the predictive likelihood per layer.

        """
        TRAIN_OUTPUT_LAYERS = True

        self.update_depth()
        variational_qL_probabilities = self.variational_posterior_L.probability_vector()

        intermediary_state_list = []
        output_state_list = []

        logp_theta_hidden_list = []
        logp_theta_output_list = []
        logp_theta_both_list = []
        logp_L_list = []
        logp_list = []
        logp_y_list = []
        predictions_per_layer = []
        losses = []

        log_theta_hidden_cumulative = torch.tensor(0.0, device=X.device)
        global_predictions = torch.zeros(X.shape[0], self.out_dimension, device=X.device, dtype=X.dtype)

        current_state = X
        i = 0
        while len(intermediary_state_list) - 1 < self.current_depth:
            a = variational_qL_probabilities[i]

            if i > 0:
                hidden_layer = self.hidden_layers[i - 1]
                current_state = hidden_layer(current_state)
                # logp_theta_hidden = sum([self.prior_theta.log_prob(p).sum() for p in hidden_layer.parameters()])
                logp_theta_hidden = sum(
                    [-(p ** 2).sum() / 2 / self.theta_prior_scale ** 2 for p in hidden_layer.parameters()]
                )
            else:
                logp_theta_hidden = torch.tensor(0.0, device=X.device)

            output_layer = self.output_layers[i]
            # logp_theta_output = sum([self.prior_theta.log_prob(p).sum() for p in output_layer.parameters()])
            logp_theta_output = sum(
                [-(p ** 2).sum() / 2 / self.theta_prior_scale ** 2 for p in output_layer.parameters()]
            )
            logp_L = self.prior_L.log_prob(torch.tensor(i).float())

            log_theta_hidden_cumulative += logp_theta_hidden
            # log_theta_output_cumulative += logp_theta_output

            if a.item() > 0:
                current_output = output_layer(current_state)
            else:
                # No weight on this layer; we compute just for inspection, with no gradient to propagate
                current_output = output_layer(current_state.detach())

            intermediary_state_list.append(current_state)
            output_state_list.append(current_output)
            logp_theta_hidden_list.append(log_theta_hidden_cumulative.item())
            logp_theta_output_list.append(logp_theta_output.item())
            logp_theta_both_list.append(log_theta_hidden_cumulative.item() + logp_theta_output.item())
            logp_L_list.append(logp_L.item())

            if y is not None:
                if self.mode == "classification":
                    logpy = -nn.CrossEntropyLoss(reduction="mean")(current_output, y) * self.n_obs
                elif self.mode == "regression":
                    # logpy = -nn.GaussianNLLLoss(reduction="mean")(current_output, y, torch.ones_like(y)) * self.n_obs
                    logpy = (-((current_output - y) ** 2).mean() / 2 - np.log(2 * np.pi) / 2) * self.n_obs
                else:
                    raise NotImplementedError
            else:
                logpy = torch.tensor(0.0, device=X.device)

            logp_y_list.append(logpy.item())

            if a.item() > 0:
                logp = logpy + log_theta_hidden_cumulative + logp_theta_output + logp_L
                losses.append(a * logp)
            else:
                if TRAIN_OUTPUT_LAYERS:
                    logp = logpy + logp_theta_output
                else:
                    logp = torch.tensor(0.0)
                losses.append(logp)

            logp_list.append(logp.item())

            if self.mode == "classification":
                current_predictions = nn.Softmax(dim=-1)(current_output)
            elif self.mode == "regression":
                current_predictions = current_output
            else:
                raise NotImplementedError

            predictions_per_layer.append(current_predictions)
            global_predictions = global_predictions + (a * current_predictions).detach()

            i += 1

        entropy_qL = torch.distributions.Categorical(variational_qL_probabilities).entropy()

        return dict(
            predictions_global=global_predictions,
            predictions_per_layer=predictions_per_layer,
            losses=losses,
            entropy_qL=entropy_qL,
            logp_per_layer=logp_list,
            logp_L_per_layer=logp_L_list,
            logp_theta_per_layer=logp_theta_both_list,
            logp_y_per_layer=logp_y_list,
        )
--------------------------------------------------------------------------------
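As a quick illustration of the `VariationalDepth` interface used by `UnboundedDepthNetwork.forward` to weight the per-layer ELBO terms, the snippet below (illustrative only, not a file from the repository) exercises the two main implementations:

```python
import torch

from src.models import FixedDepth, TruncatedPoisson

# FixedDepth puts all of the mass of q(L) on a single depth.
q_fixed = FixedDepth(3)
print(q_fixed.probability_vector())  # tensor([0., 0., 0., 1.])
print(q_fixed.mean())                # 3

# TruncatedPoisson keeps every depth up to the truncation quantile of a Poisson(nu_L).
q_inf = TruncatedPoisson(initial_nu_L=2.0)
depth = q_inf.compute_depth()        # grows or shrinks with nu_L during training
probs = q_inf.probability_vector()   # probs[0] = q(L=0) = 0; the rest is a renormalized Poisson
print(depth, probs.sum())            # probs sums to 1 over depths 0..depth
```

The forward pass multiplies each layer's log-likelihood and prior terms by the corresponding entry of `probability_vector()`, which is why layers with zero mass contribute no gradient to the hidden weights.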