├── evidential_deep_learning
│   ├── __init__.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── conv2d.py
│   │   └── dense.py
│   └── losses
│       ├── __init__.py
│       ├── discrete.py
│       └── continuous.py
├── assets
│   ├── banner.png
│   ├── animation.gif
│   └── cite.bib
├── neurips2020
│   ├── data
│   │   └── uci
│   │       ├── naval
│   │       │   ├── README.txt
│   │       │   └── Features.txt
│   │       ├── concrete
│   │       │   ├── Concrete_Data.xls
│   │       │   └── Concrete_Readme.txt
│   │       ├── power-plant
│   │       │   ├── Folds5x2_pp.ods
│   │       │   ├── Folds5x2_pp.xlsx
│   │       │   └── Readme.txt
│   │       ├── energy-efficiency
│   │       │   └── ENB2012_data.xlsx
│   │       ├── wine-quality
│   │       │   └── winequality.names
│   │       └── yacht
│   │           └── yacht_hydrodynamics.data
│   ├── trainers
│   │   ├── __init__.py
│   │   ├── util.py
│   │   ├── deterministic.py
│   │   ├── gaussian.py
│   │   ├── bbbp.py
│   │   ├── dropout.py
│   │   ├── ensemble.py
│   │   └── evidential.py
│   ├── download_data.sh
│   ├── models
│   │   ├── depth
│   │   │   ├── deterministic.py
│   │   │   ├── ensemble.py
│   │   │   ├── gaussian.py
│   │   │   ├── evidential.py
│   │   │   ├── bbbp.py
│   │   │   └── dropout.py
│   │   ├── toy
│   │   │   ├── deterministic.py
│   │   │   ├── gaussian.py
│   │   │   ├── evidential.py
│   │   │   ├── bbbp.py
│   │   │   ├── ensemble.py
│   │   │   ├── dropout.py
│   │   │   └── h_params.py
│   │   └── __init__.py
│   ├── preprocess
│   │   ├── cache_apolloscape.py
│   │   ├── preprocess_nyu_depth.m
│   │   └── cache_nyu_depth.py
│   ├── train_depth.py
│   ├── README.md
│   ├── run_uci_dataset_tests.py
│   ├── run_cubic_tests.py
│   ├── data_loader.py
│   └── gen_depth_results.py
├── CITATION.cff
├── setup.py
├── hello_world.py
├── .gitignore
├── README.md
├── environment.yml
└── LICENSE
/evidential_deep_learning/__init__.py: -------------------------------------------------------------------------------- 1 | from . import layers 2 | from . import losses 3 | -------------------------------------------------------------------------------- /evidential_deep_learning/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .dense import * 2 | from .conv2d import * 3 | -------------------------------------------------------------------------------- /assets/banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aamini/evidential-deep-learning/HEAD/assets/banner.png -------------------------------------------------------------------------------- /evidential_deep_learning/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .continuous import * 2 | from .discrete import * 3 | -------------------------------------------------------------------------------- /assets/animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aamini/evidential-deep-learning/HEAD/assets/animation.gif -------------------------------------------------------------------------------- /neurips2020/data/uci/naval/README.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aamini/evidential-deep-learning/HEAD/neurips2020/data/uci/naval/README.txt -------------------------------------------------------------------------------- /neurips2020/data/uci/concrete/Concrete_Data.xls: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aamini/evidential-deep-learning/HEAD/neurips2020/data/uci/concrete/Concrete_Data.xls -------------------------------------------------------------------------------- /neurips2020/data/uci/power-plant/Folds5x2_pp.ods: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/aamini/evidential-deep-learning/HEAD/neurips2020/data/uci/power-plant/Folds5x2_pp.ods -------------------------------------------------------------------------------- /neurips2020/data/uci/power-plant/Folds5x2_pp.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aamini/evidential-deep-learning/HEAD/neurips2020/data/uci/power-plant/Folds5x2_pp.xlsx -------------------------------------------------------------------------------- /neurips2020/data/uci/energy-efficiency/ENB2012_data.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aamini/evidential-deep-learning/HEAD/neurips2020/data/uci/energy-efficiency/ENB2012_data.xlsx -------------------------------------------------------------------------------- /neurips2020/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .bbbp import BBBP 2 | from .dropout import Dropout 3 | from .ensemble import Ensemble 4 | from .evidential import Evidential 5 | from .gaussian import Gaussian 6 | from .deterministic import Deterministic 7 | -------------------------------------------------------------------------------- /assets/cite.bib: -------------------------------------------------------------------------------- 1 | @article{amini2020deep, 2 | title={Deep evidential regression}, 3 | author={Amini, Alexander and Schwarting, Wilko and Soleimany, Ava and Rus, Daniela}, 4 | journal={Advances in Neural Information Processing Systems}, 5 | volume={33}, 6 | year={2020} 7 | } 8 | 9 | -------------------------------------------------------------------------------- /neurips2020/download_data.sh: -------------------------------------------------------------------------------- 1 | DEPTH_DATA_URL="https://www.dropbox.com/s/qtab28cauzalqi7/depth_data.tar.gz?dl=1" 2 | DATA_EXTRACT_DIR="./data" 3 | 4 | PRETRAINED_URL="https://www.dropbox.com/s/356r36lfpyzhcht/pretrained_models.tar.gz?dl=1" 5 | PRETRAINED_EXTRACT_DIR="./" 6 | mkdir -p $DATA_EXTRACT_DIR 7 | wget -c $DEPTH_DATA_URL -O - | tar -xz -C $DATA_EXTRACT_DIR 8 | 9 | mkdir -p $PRETRAINED_EXTRACT_DIR 10 | wget -c $PRETRAINED_URL -O - | tar -xz -C $PRETRAINED_EXTRACT_DIR 11 | -------------------------------------------------------------------------------- /neurips2020/models/depth/deterministic.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras.layers import Conv2D, MaxPooling2D, \ 3 | UpSampling2D, Cropping2D, concatenate, ZeroPadding2D 4 | 5 | import functools 6 | from . import dropout 7 | 8 | def create(input_shape, activation=tf.nn.relu, num_class=1): 9 | opts = locals().copy() 10 | model, opts = dropout.create(input_shape, drop_prob=0.0, sigma=False, activation=activation, num_class=num_class) 11 | return model, opts 12 | -------------------------------------------------------------------------------- /neurips2020/models/depth/ensemble.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from .
import dropout 3 | 4 | def create(input_shape, num_ensembles=5, sigma=True, activation=tf.nn.relu, num_class=1): 5 | opts = locals().copy() 6 | 7 | def create_single_model(): 8 | model, dropout_options = dropout.create(input_shape, drop_prob=0.0, sigma=sigma, activation=activation, num_class=num_class) 9 | return model 10 | 11 | models = [create_single_model() for _ in range(num_ensembles)] 12 | return models, opts 13 | -------------------------------------------------------------------------------- /neurips2020/models/toy/deterministic.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import functools 3 | 4 | def create( 5 | input_shape, 6 | num_neurons=100, 7 | num_layers=2, 8 | activation=tf.nn.relu, 9 | ): 10 | 11 | options = locals().copy() 12 | 13 | Dense = functools.partial(tf.keras.layers.Dense, activation=activation) 14 | 15 | layers = [] 16 | for _ in range(num_layers): 17 | layers.append(Dense(num_neurons)) 18 | layers.append(Dense(1, activation=tf.identity)) 19 | 20 | model = tf.keras.models.Sequential(layers) 21 | 22 | return model, options 23 | -------------------------------------------------------------------------------- /neurips2020/models/depth/gaussian.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | # tf.enable_eager_execution() 3 | from tensorflow.keras.layers import Conv2D, MaxPooling2D, \ 4 | UpSampling2D, Cropping2D, concatenate, ZeroPadding2D, SpatialDropout2D 5 | import functools 6 | 7 | from evidential_deep_learning.layers import Conv2DNormal 8 | from . import dropout 9 | 10 | def create(input_shape, activation=tf.nn.relu, num_class=1): 11 | opts = locals().copy() 12 | model, dropout_options = dropout.create(input_shape, drop_prob=0.0, sigma=True, activation=activation, num_class=num_class) 13 | return model, dropout_options 14 | -------------------------------------------------------------------------------- /neurips2020/models/toy/gaussian.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow_probability as tfp 3 | import functools 4 | import evidential_deep_learning as edl 5 | 6 | def create( 7 | input_shape, 8 | num_neurons=50, 9 | num_layers=1, 10 | activation=tf.nn.relu, 11 | ): 12 | 13 | options = locals().copy() 14 | 15 | inputs = tf.keras.Input(input_shape) 16 | x = inputs 17 | for _ in range(num_layers): 18 | x = tf.keras.layers.Dense(num_neurons, activation=activation)(x) 19 | output = edl.layers.DenseNormal(1)(x) 20 | model = tf.keras.Model(inputs=inputs, outputs=output) 21 | 22 | return model, options 23 | -------------------------------------------------------------------------------- /neurips2020/models/toy/evidential.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow_probability as tfp 3 | import functools 4 | import evidential_deep_learning as edl 5 | 6 | def create( 7 | input_shape, 8 | num_neurons=50, 9 | num_layers=1, 10 | activation=tf.nn.relu, 11 | ): 12 | 13 | options = locals().copy() 14 | 15 | inputs = tf.keras.Input(input_shape) 16 | x = inputs 17 | for _ in range(num_layers): 18 | x = tf.keras.layers.Dense(num_neurons, activation=activation)(x) 19 | output = edl.layers.DenseNormalGamma(1)(x) 20 | model = tf.keras.Model(inputs=inputs, outputs=output) 21 | 22 | return model, options 23 | 
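An illustrative aside (not part of the repository files): the toy evidential model above ends in `edl.layers.DenseNormalGamma(1)`, whose four concatenated output channels can be unpacked with `tf.split`, exactly as `hello_world.py` later in this dump does. Under the Normal Inverse-Gamma parameterization, `beta / (alpha - 1)` is the expected aleatoric variance and `beta / (v * (alpha - 1))` the epistemic variance. A minimal sketch:

```python
import numpy as np
import tensorflow as tf
import evidential_deep_learning as edl

# Sketch: a one-hidden-layer evidential regressor, mirroring models/toy/evidential.py.
inputs = tf.keras.Input((1,))
h = tf.keras.layers.Dense(50, activation=tf.nn.relu)(inputs)
outputs = edl.layers.DenseNormalGamma(1)(h)
model = tf.keras.Model(inputs, outputs)

x = np.linspace(-4, 4, 16).reshape(-1, 1).astype(np.float32)
mu, v, alpha, beta = tf.split(model(x), 4, axis=-1)  # unpack the 4 evidential channels
aleatoric = beta / (alpha - 1)        # E[sigma^2]: expected data noise
epistemic = beta / (v * (alpha - 1))  # Var[mu]: model (evidential) uncertainty
```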
-------------------------------------------------------------------------------- /neurips2020/models/toy/bbbp.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow_probability as tfp 3 | import functools 4 | 5 | def create( 6 | input_shape, 7 | num_neurons=100, 8 | num_layers=2, 9 | activation=tf.nn.relu, 10 | ): 11 | 12 | options = locals().copy() 13 | 14 | DenseReparameterization = functools.partial(tfp.layers.DenseReparameterization, activation=activation) 15 | layers = [] 16 | for _ in range(num_layers): 17 | layers.append(DenseReparameterization(num_neurons)) 18 | layers.append(DenseReparameterization(1, activation=tf.identity)) 19 | 20 | model = tf.keras.models.Sequential(layers) 21 | 22 | return model, options 23 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: "1.1.0" 2 | message: "If you use this software, please cite it using these metadata." 3 | title: "Deep Evidential Regression" 4 | authors: 5 | - 6 | family-names: Amini 7 | given-names: Alexander 8 | orcid: "https://orcid.org/0000-0002-9673-1267" 9 | - 10 | family-names: Schwarting 11 | given-names: Wilko 12 | - 13 | family-names: Soleimany 14 | given-names: Ava 15 | orcid: "https://orcid.org/0000-0002-8601-6040" 16 | - 17 | family-names: Rus 18 | given-names: Daniela 19 | conference: "Advances in Neural Information Processing Systems (NeurIPS)" 20 | year: 2020 21 | volume: 33 22 | repository-code: "https://github.com/aamini/evidential-deep-learning" 23 | -------------------------------------------------------------------------------- /neurips2020/trainers/util.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | def normalize(x, crop=True): 5 | if crop and len(x.shape) == 4: 6 | x = x[:, 10:-10,5:-5] 7 | elif crop: 8 | x = x[10:-10,5:-5] 9 | 10 | x_min = tf.reduce_min(x, axis=(-1,-2,-3), keepdims=True) 11 | x_max = tf.reduce_max(x, axis=(-1,-2,-3), keepdims=True) 12 | return (x - x_min)/(x_max - x_min) 13 | 14 | def gallery(array, ncols=3): 15 | nindex, height, width, intensity = array.shape 16 | nrows = nindex//ncols 17 | assert nindex == nrows*ncols 18 | # want result.shape = (height*nrows, width*ncols, intensity) 19 | result = (array.reshape(nrows, ncols, height, width, intensity) 20 | .swapaxes(1,2) 21 | .reshape(height*nrows, width*ncols, intensity)) 22 | return result 23 | -------------------------------------------------------------------------------- /neurips2020/models/toy/ensemble.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import functools 3 | import evidential_deep_learning as edl 4 | 5 | def create( 6 | input_shape, 7 | num_neurons=50, 8 | num_layers=1, 9 | activation=tf.nn.relu, 10 | num_ensembles=5, 11 | sigma=True 12 | ): 13 | 14 | options = locals().copy() 15 | 16 | def create_model(): 17 | inputs = tf.keras.Input(input_shape) 18 | x = inputs 19 | for _ in range(num_layers): 20 | x = tf.keras.layers.Dense(num_neurons, activation=activation)(x) 21 | output = edl.layers.DenseNormal(1)(x) 22 | model = tf.keras.Model(inputs=inputs, outputs=output) 23 | return model 24 | 25 | models = [create_model() for _ in range(num_ensembles)] 26 | 27 | return models, options 28 | --------------------------------------------------------------------------------
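An illustrative aside (not from the repository): the list of models returned by the `ensemble.create` factories above is consumed by the `Ensemble` trainer, but the standard deep-ensemble readout is simply to average member predictions and use their spread as epistemic uncertainty. A minimal sketch, assuming members were built with `sigma=False` (as in `depth/ensemble.py`) so each outputs a point estimate; `ensemble_predict` is a hypothetical helper:

```python
import numpy as np

def ensemble_predict(models, x):
    # Stack member outputs: shape (num_ensembles, batch, ...).
    preds = np.stack([model(x).numpy() for model in models], axis=0)
    mean = preds.mean(axis=0)    # ensemble prediction
    spread = preds.var(axis=0)   # disagreement across members (epistemic proxy)
    return mean, spread
```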
/neurips2020/data/uci/naval/Features.txt: -------------------------------------------------------------------------------- 1 | 1 - Lever position (lp) [ ] 2 | 2 - Ship speed (v) [knots] 3 | 3 - Gas Turbine shaft torque (GTT) [kN m] 4 | 4 - Gas Turbine rate of revolutions (GTn) [rpm] 5 | 5 - Gas Generator rate of revolutions (GGn) [rpm] 6 | 6 - Starboard Propeller Torque (Ts) [kN] 7 | 7 - Port Propeller Torque (Tp) [kN] 8 | 8 - HP Turbine exit temperature (T48) [C] 9 | 9 - GT Compressor inlet air temperature (T1) [C] 10 | 10 - GT Compressor outlet air temperature (T2) [C] 11 | 11 - HP Turbine exit pressure (P48) [bar] 12 | 12 - GT Compressor inlet air pressure (P1) [bar] 13 | 13 - GT Compressor outlet air pressure (P2) [bar] 14 | 14 - Gas Turbine exhaust gas pressure (Pexh) [bar] 15 | 15 - Turbine Injection Control (TIC) [%] 16 | 16 - Fuel flow (mf) [kg/s] 17 | 17 - GT Compressor decay state coefficient. 18 | 18 - GT Turbine decay state coefficient. -------------------------------------------------------------------------------- /neurips2020/models/toy/dropout.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras.regularizers import l2 3 | import functools 4 | 5 | def create( 6 | input_shape, 7 | num_neurons=50, 8 | num_layers=1, 9 | activation=tf.nn.relu, 10 | drop_prob=0.05, 11 | lam=1e-3, 12 | l=1e-2, 13 | sigma=False 14 | ): 15 | 16 | options = locals().copy() 17 | 18 | Dense = functools.partial(tf.keras.layers.Dense, kernel_regularizer=l2(lam), bias_regularizer=l2(lam), activation=activation) 19 | Dropout = functools.partial(tf.keras.layers.Dropout, drop_prob) 20 | n_out = 2 if sigma else 1 21 | 22 | layers = [] 23 | for _ in range(num_layers): 24 | layers.append(Dense(num_neurons)) 25 | layers.append(Dropout()) 26 | layers.append(Dense(n_out, activation=tf.identity)) 27 | 28 | model = tf.keras.models.Sequential(layers) 29 | 30 | return model, options 31 | -------------------------------------------------------------------------------- /neurips2020/models/toy/h_params.py: -------------------------------------------------------------------------------- 1 | # This file contains the hyperparameters for reproducing the benchmark UCI 2 | # dataset regression tasks. The hyperparameters, which include the learning_rate 3 | # and batch_size, were optimized using grid search on an 80-20 train-test split 4 | # of the dataset, with the optimal resulting hyperparameters saved in this file 5 | # for quick reloading.
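# [Illustrative aside, not part of h_params.py] run_uci_dataset_tests.py
# (reproduced later in this dump) consumes these entries as, e.g.:
#   batch_size = h_params[dataset]["batch_size"]
#   trainer = trainer_obj(model, opts, dataset, learning_rate=h_params[dataset]["learning_rate"])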
6 | 7 | h_params = { 8 | 'yacht': {'learning_rate': 5e-4, 'batch_size': 1}, 9 | 'naval': {'learning_rate': 5e-4, 'batch_size': 1}, 10 | 'concrete': {'learning_rate': 5e-3, 'batch_size': 1}, 11 | 'energy-efficiency': {'learning_rate': 2e-3, 'batch_size': 1}, 12 | 'kin8nm': {'learning_rate': 1e-3, 'batch_size': 1}, 13 | 'power-plant': {'learning_rate': 1e-3, 'batch_size': 2}, 14 | 'boston': {'learning_rate': 1e-3, 'batch_size': 8}, 15 | 'wine': {'learning_rate': 1e-4, 'batch_size': 32}, 16 | 'protein': {'learning_rate': 1e-3, 'batch_size': 64}, 17 | } 18 | -------------------------------------------------------------------------------- /neurips2020/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .toy import bbbp 2 | from .toy import dropout 3 | from .toy import ensemble 4 | from .toy import evidential 5 | from .toy import gaussian 6 | from .toy import deterministic 7 | from .toy.h_params import h_params 8 | 9 | from .depth import bbbp 10 | from .depth import dropout 11 | from .depth import ensemble 12 | from .depth import evidential 13 | from .depth import gaussian 14 | from .depth import deterministic 15 | 16 | 17 | def get_correct_model(dataset, trainer): 18 | """ Hacky helper function to grab the right model for a given dataset and trainer. """ 19 | dataset_loader = globals()[dataset] 20 | trainer_lookup = trainer.__name__.lower() 21 | model_pointer = dataset_loader.__dict__[trainer_lookup] 22 | return model_pointer 23 | 24 | def load_depth_model(path, compile=False): 25 | import glob 26 | import tensorflow as tf 27 | import edl 28 | 29 | model_paths = glob.glob(path) 30 | if model_paths == []: 31 | model_paths = [path] 32 | 33 | custom_objects = {'Conv2DNormal': edl.layers.Conv2DNormal, 34 | 'Conv2DNormalGamma': edl.layers.Conv2DNormalGamma} 35 | 36 | models = [tf.keras.models.load_model(model_path, custom_objects, compile=compile) for model_path in model_paths] 37 | if len(models) == 1: 38 | models = models[0] 39 | 40 | return models 41 | -------------------------------------------------------------------------------- /neurips2020/preprocess/cache_apolloscape.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import os 3 | from scipy.io import loadmat 4 | import cv2 5 | from tqdm import tqdm 6 | import numpy as np 7 | import glob 8 | 9 | root = "/data/apollo" 10 | test_path = "/data/apolloscape_test.h5" 11 | if not os.path.isdir(root): 12 | raise ValueError("Please download the Apolloscape dataset to /data/apollo.
\ 13 | Or follow the instructions in the README to download a subset of the \ 14 | data to test the code provided in this repository.") 15 | 16 | 17 | img_paths = sorted(glob.glob(os.path.join(root, "camera_5/*.jpg"))) 18 | disp_paths = sorted(glob.glob(os.path.join(root, "disparity/*.png"))) 19 | inds = np.random.choice(len(img_paths), 1000, replace=False) 20 | 21 | 22 | sz = (160, 128) 23 | def resize(I): 24 | w_ = int(I.shape[1] / (float(I.shape[0]) / float(sz[1]))) 25 | I = cv2.resize(I, (w_, sz[1])) 26 | I = I[:, 100:-100] 27 | I = cv2.resize(I, sz) 28 | return I 29 | 30 | imgs = [] 31 | disps = [] 32 | for i, ind in enumerate(tqdm(inds)): 33 | img = cv2.imread(img_paths[ind]) 34 | imgs.append(resize(img)) 35 | 36 | disp = cv2.imread(disp_paths[ind], 0) 37 | disps.append(np.expand_dims(resize(disp), -1)) 38 | 39 | 40 | print("saving") 41 | 42 | f = h5py.File(test_path, 'w') 43 | f.create_dataset("image", data=imgs, dtype=np.uint8) 44 | f.create_dataset("depth", data=disps, dtype=np.uint8) 45 | f.close() 46 | 47 | import pdb; pdb.set_trace() 48 | -------------------------------------------------------------------------------- /evidential_deep_learning/losses/discrete.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | 5 | def Dirichlet_SOS(y, alpha, t): 6 | def KL(alpha): 7 | beta=tf.constant(np.ones((1,alpha.shape[1])),dtype=tf.float32) 8 | S_alpha = tf.reduce_sum(alpha,axis=1,keepdims=True) 9 | S_beta = tf.reduce_sum(beta,axis=1,keepdims=True) 10 | lnB = tf.math.lgamma(S_alpha) - tf.reduce_sum(tf.math.lgamma(alpha),axis=1,keepdims=True) 11 | lnB_uni = tf.reduce_sum(tf.math.lgamma(beta),axis=1,keepdims=True) - tf.math.lgamma(S_beta) 12 | 13 | 14 | dg0 = tf.math.digamma(S_alpha) 15 | dg1 = tf.math.digamma(alpha) 16 | 17 | kl = tf.reduce_sum((alpha - beta)*(dg1-dg0),axis=1,keepdims=True) + lnB + lnB_uni 18 | return kl 19 | 20 | S = tf.reduce_sum(alpha, axis=1, keepdims=True) 21 | evidence = alpha - 1 22 | m = alpha / S 23 | 24 | A = tf.reduce_sum((y-m)**2, axis=1, keepdims=True) 25 | B = tf.reduce_sum(alpha*(S-alpha)/(S*S*(S+1)), axis=1, keepdims=True) 26 | 27 | # annealing_coef = tf.minimum(1.0,tf.cast(global_step/annealing_step,tf.float32)) 28 | alpha_hat = y + (1-y)*alpha 29 | C = KL(alpha_hat) 30 | 31 | C = tf.reduce_mean(C, axis=1) 32 | return tf.reduce_mean(A + B + C) 33 | 34 | def Sigmoid_CE(y, y_logits): 35 | loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=y_logits) 36 | return tf.reduce_mean(loss) 37 | -------------------------------------------------------------------------------- /neurips2020/preprocess/preprocess_nyu_depth.m: -------------------------------------------------------------------------------- 1 | clear; 2 | 3 | dataset_path = '/data/nyu_depth'; 4 | output_path = '/data/nyu_depth_processed'; 5 | skip_n = 1; 6 | 7 | addpath(genpath('~/Documents/MATLAB')); 8 | 9 | places = dir(dataset_path); 10 | places = places(3:end); 11 | for i=1:length(places) 12 | place = places(i).name; 13 | 14 | places_done = dir(output_path); 15 | done = false; 16 | for j=1:length(places_done) 17 | if strcmp(place, places_done(j).name) 18 | done = true; 19 | break 20 | end 21 | end 22 | 23 | if done 24 | fprintf('skipping %s \n',place); 25 | continue 26 | end 27 | mkdir(strcat(output_path, '/', place)); 28 | sync = get_synched_frames(strcat(dataset_path, '/', place)); 29 | 
fprintf('preprocessing %d in %s \n', int16(length(sync)/skip_n), place); 30 | for j=1:skip_n:length(sync) 31 | try 32 | rgb = imread(strcat(dataset_path, '/', place, '/', sync(j).rawRgbFilename)); 33 | depth = imread(strcat(dataset_path, '/', place, '/', sync(j).rawDepthFilename)); 34 | depth_proj = project_depth_map(swapbytes(depth), rgb); 35 | depth_fill = fill_depth_colorization(double(rgb)/255,depth_proj,0.9); 36 | 37 | disp = (1./depth_fill) * 255.; 38 | disp = uint8(max(0, min(255, disp))); 39 | catch 40 | continue; 41 | end 42 | 43 | save(strcat(output_path, '/', place, '/scan_', num2str(j)), 'rgb','disp'); 44 | 45 | fprintf('.'); 46 | end 47 | fprintf('\n'); 48 | 49 | end 50 | -------------------------------------------------------------------------------- /neurips2020/data/uci/power-plant/Readme.txt: -------------------------------------------------------------------------------- 1 | The dataset contains 9568 data points collected from a Combined Cycle Power Plant over 6 years (2006-2011), when the power plant was set to work with full load. Features consist of hourly average ambient variables Temperature (T), Ambient Pressure (AP), Relative Humidity (RH) and Exhaust Vacuum (V) to predict the net hourly electrical energy output (EP) of the plant. 2 | A combined cycle power plant (CCPP) is composed of gas turbines (GT), steam turbines (ST) and heat recovery steam generators. In a CCPP, the electricity is generated by gas and steam turbines, which are combined in one cycle, and is transferred from one turbine to another. While the Vacuum is collected from and has an effect on the Steam Turbine, the other three ambient variables affect the GT performance. 3 | For comparability with our baseline studies, and to allow 5x2-fold statistical tests to be carried out, we provide the data shuffled five times. For each shuffling 2-fold CV is carried out and the resulting 10 measurements are used for statistical testing. 4 | We provide the data both in .ods and in .xlsx formats. 5 | 6 | Relevant Papers to cite: 7 | 8 | Pınar Tüfekci, Prediction of full load electrical power output of a base load operated combined cycle power plant using machine learning methods, International Journal of Electrical Power & Energy Systems, Volume 60, September 2014, Pages 126-140, ISSN 0142-0615, http://dx.doi.org/10.1016/j.ijepes.2014.02.027. 9 | (http://www.sciencedirect.com/science/article/pii/S0142061514000908) 10 | 11 | Heysem Kaya, Pınar Tüfekci, Sadık Fikret Gürgen: Local and Global Learning Methods for Predicting Power of a Combined Gas & Steam Turbine, Proceedings of the International Conference on Emerging Trends in Computer and Electronics Engineering ICETCEE 2012, pp. 13-18 (Mar. 2012, Dubai) 12 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Alexander Amini 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from setuptools import setup, find_packages 16 | 17 | with open('README.md') as readme_file: 18 | readme = readme_file.read() 19 | 20 | version = "0.4.0" 21 | 22 | setup( 23 | name="evidential_deep_learning", 24 | version=version, 25 | packages=find_packages(), 26 | description= 27 | "Learn fast, scalable, and calibrated measures of uncertainty using neural networks!", 28 | long_description=readme, 29 | long_description_content_type='text/markdown', 30 | url="https://github.com/aamini/evidential-deep-learning", 31 | download_url=f"https://github.com/aamini/evidential-deep-learning/archive/v{version}.tar.gz", 32 | author="Alexander Amini", 33 | author_email="amini@mit.edu", 34 | license="Apache License 2.0", 35 | install_requires=[ 36 | "numpy", 37 | "matplotlib", 38 | ], # Tensorflow must be installed manually 39 | python_requires='>=3.7', 40 | classifiers=[ 41 | "Programming Language :: Python", 42 | "Programming Language :: Python :: 3.7", 43 | "Operating System :: Unix", 44 | "Operating System :: Microsoft :: Windows", 45 | "Operating System :: MacOS", 46 | "Intended Audience :: Science/Research", 47 | "Topic :: Scientific/Engineering", 48 | "Topic :: Software Development", 49 | ], 50 | ) 51 | -------------------------------------------------------------------------------- /evidential_deep_learning/layers/conv2d.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tensorflow.keras.layers import Layer, Conv2D 4 | 5 | 6 | class Conv2DNormal(Layer): 7 | def __init__(self, filters, kernel_size, **kwargs): 8 | self.filters = filters 9 | self.kernel_size = kernel_size 10 | super(Conv2DNormal, self).__init__() 11 | self.conv = Conv2D(2 * filters, kernel_size, **kwargs) 12 | 13 | def call(self, x): 14 | output = self.conv(x) 15 | mu, logsigma = tf.split(output, 2, axis=-1) 16 | sigma = tf.nn.softplus(logsigma) + 1e-6 17 | # return [mu, sigma] 18 | return tf.concat([mu, sigma], axis=-1) 19 | 20 | def compute_output_shape(self, input_shape): 21 | return self.conv.compute_output_shape(input_shape) 22 | 23 | def get_config(self): 24 | base_config = super(Conv2DNormal, self).get_config() 25 | base_config['filters'] = self.filters 26 | base_config['kernel_size'] = self.kernel_size 27 | return base_config 28 | 29 | 30 | class Conv2DNormalGamma(Layer): 31 | def __init__(self, filters, kernel_size, **kwargs): 32 | self.filters = filters 33 | self.kernel_size = kernel_size 34 | super(Conv2DNormalGamma, self).__init__() 35 | self.conv = Conv2D(4 * filters, kernel_size, **kwargs) 36 | 37 | def evidence(self, x): 38 | # return tf.exp(x) 39 | return tf.nn.softplus(x) 40 | 41 | def call(self, x): 42 | output = self.conv(x) 43 | mu, logv, logalpha, logbeta = tf.split(output, 4, axis=-1) 44 | v = self.evidence(logv) 45 | alpha = self.evidence(logalpha) + 1 46 | beta = self.evidence(logbeta) 47 | return tf.concat([mu, v, alpha, beta], axis=-1) 48 | 49 | def compute_output_shape(self, input_shape): 50 | return self.conv.compute_output_shape(input_shape) 51 | 52 | def get_config(self): 53 | base_config = super(Conv2DNormalGamma, self).get_config() 54 | base_config['filters'] = self.filters 55 | base_config['kernel_size'] = self.kernel_size 56 | return base_config 57 | 58 | 59 | # Conv2DNormalGamma(32, (5,5)) 60 | -------------------------------------------------------------------------------- /neurips2020/train_depth.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import 
cv2 3 | import h5py 4 | import numpy as np 5 | import os 6 | import time 7 | import tensorflow as tf 8 | 9 | import edl 10 | import data_loader 11 | import models 12 | import trainers 13 | 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--model", default="evidential", type=str, 17 | choices=["evidential", "dropout", "ensemble"]) 18 | parser.add_argument("--batch-size", default=32, type=int) 19 | parser.add_argument("--iters", default=60000, type=int) 20 | parser.add_argument("--learning-rate", default=5e-5, type=float) 21 | args = parser.parse_args() 22 | 23 | ### Try to limit GPU memory to fit ensembles on RTX 2080Ti 24 | gpus = tf.config.experimental.list_physical_devices('GPU') 25 | if gpus: 26 | try: 27 | tf.config.experimental.set_virtual_device_configuration( 28 | gpus[0], [tf.config.experimental.VirtualDeviceConfiguration( 29 | memory_limit=9000)]) 30 | except RuntimeError as e: 31 | print(e) 32 | 33 | ### Load the data 34 | (x_train, y_train), (x_test, y_test) = data_loader.load_depth() 35 | 36 | ### Create the trainer 37 | if args.model == "evidential": 38 | trainer_obj = trainers.Evidential 39 | model_generator = models.get_correct_model(dataset="depth", trainer=trainer_obj) 40 | model, opts = model_generator.create(input_shape=x_train.shape[1:]) 41 | trainer = trainer_obj(model, opts, args.learning_rate, lam=2e-1, epsilon=0., maxi_rate=0.) 42 | 43 | elif args.model == "dropout": 44 | trainer_obj = trainers.Dropout 45 | model_generator = models.get_correct_model(dataset="depth", trainer=trainer_obj) 46 | model, opts = model_generator.create(input_shape=x_train.shape[1:], sigma=False) 47 | trainer = trainer_obj(model, opts, args.learning_rate) 48 | 49 | elif args.model == "ensemble": 50 | trainer_obj = trainers.Ensemble 51 | model_generator = models.get_correct_model(dataset="depth", trainer=trainer_obj) 52 | model, opts = model_generator.create(input_shape=x_train.shape[1:], sigma=False) 53 | trainer = trainer_obj(model, opts, args.learning_rate) 54 | 55 | 56 | ### Train the model 57 | model, rmse, nll = trainer.train(x_train, y_train, x_test, y_test, np.array([[1.]]), iters=args.iters, batch_size=args.batch_size, verbose=True) 58 | tf.keras.backend.clear_session() 59 | 60 | print("Done training!") 61 | -------------------------------------------------------------------------------- /neurips2020/preprocess/cache_nyu_depth.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import os 3 | from scipy.io import loadmat 4 | import cv2 5 | from tqdm import tqdm 6 | import numpy as np 7 | 8 | nyu_root = "/data/nyu_depth_processed" 9 | train_path = "./data/depth_train_big_new.h5" 10 | test_path = "./data/depth_test_big_new.h5" 11 | 12 | if not os.path.isdir(nyu_root): 13 | raise ValueError("Please download the NYU Depth dataset to /data/nyu_depth_processed.
\ 14 | Or follow the instructions in the README to download a subset of the \ 15 | data to test the code provided in this repository.") 16 | 17 | 18 | datasets = ["bedroom", "bathroom","cafe","bookstore","classroom","computer_lab","conference_room","dentist_office","dining_room","excercise_room","foyer","home_office","kitchen","library","living_room","office_","study_"] 19 | 20 | scan_paths = [] 21 | dirs = set() 22 | 23 | for dirName, subdirList, fileList in os.walk(nyu_root): 24 | if not any([d in dirName for d in datasets]): 25 | continue 26 | 27 | for fname in fileList: 28 | if "scan_" in fname: 29 | dirs.add(os.path.basename(dirName)) 30 | scan_paths.append(os.path.join(dirName, fname)) 31 | 32 | print("found {} scans in {}".format(len(scan_paths), datasets)) 33 | 34 | imgs = [] 35 | depths = [] 36 | scan_paths = sorted(scan_paths) 37 | sz = (160, 128) 38 | for scan_path in tqdm(scan_paths): 39 | try: 40 | f = loadmat(scan_path) 41 | except: 42 | print("failed to load: {}".format(scan_path)) 43 | continue 44 | 45 | img = f['rgb'] 46 | depth = f['disp'] 47 | imgs.append(cv2.resize(img, sz)) 48 | depths.append(np.expand_dims(cv2.resize(depth, sz), -1)) 49 | # if len(imgs)>100: break 50 | 51 | print("saving") 52 | n=len(imgs) 53 | all_idx = np.arange(n) 54 | train_idx = np.random.choice(all_idx, int(n*0.9),replace=False) 55 | train_idx = sorted(train_idx) 56 | test_idx = list(set(all_idx)-set(train_idx)) 57 | test_idx = sorted(test_idx) 58 | 59 | imgs = np.array(imgs) 60 | depths = np.array(depths) 61 | 62 | f = h5py.File(train_path, 'w') 63 | f.create_dataset("image", data=imgs[train_idx], dtype=np.uint8) 64 | f.create_dataset("depth", data=depths[train_idx], dtype=np.uint8) 65 | f.close() 66 | 67 | f = h5py.File(test_path, 'w') 68 | f.create_dataset("image", data=imgs[test_idx], dtype=np.uint8) 69 | f.create_dataset("depth", data=depths[test_idx], dtype=np.uint8) 70 | f.close() 71 | 72 | import pdb; pdb.set_trace() 73 | -------------------------------------------------------------------------------- /evidential_deep_learning/losses/continuous.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | 5 | def MSE(y, y_, reduce=True): 6 | ax = list(range(1, len(y.shape))) 7 | 8 | mse = tf.reduce_mean((y-y_)**2, axis=ax) 9 | return tf.reduce_mean(mse) if reduce else mse 10 | 11 | def RMSE(y, y_): 12 | rmse = tf.sqrt(tf.reduce_mean((y-y_)**2)) 13 | return rmse 14 | 15 | def Gaussian_NLL(y, mu, sigma, reduce=True): 16 | ax = list(range(1, len(y.shape))) 17 | 18 | logprob = -tf.math.log(sigma) - 0.5*tf.math.log(2*np.pi) - 0.5*((y-mu)/sigma)**2 19 | loss = tf.reduce_mean(-logprob, axis=ax) 20 | return tf.reduce_mean(loss) if reduce else loss 21 | 22 | def Gaussian_NLL_logvar(y, mu, logvar, reduce=True): 23 | ax = list(range(1, len(y.shape))) 24 | 25 | log_likelihood = 0.5 * ( 26 | -tf.exp(-logvar)*(mu-y)**2 - tf.math.log(2*tf.constant(np.pi, dtype=logvar.dtype)) - logvar 27 | ) 28 | loss = tf.reduce_mean(-log_likelihood, axis=ax) 29 | return tf.reduce_mean(loss) if reduce else loss 30 | 31 | def NIG_NLL(y, gamma, v, alpha, beta, reduce=True): 32 | twoBlambda = 2*beta*(1+v) 33 | 34 | nll = 0.5*tf.math.log(np.pi/v) \ 35 | - alpha*tf.math.log(twoBlambda) \ 36 | + (alpha+0.5) * tf.math.log(v*(y-gamma)**2 + twoBlambda) \ 37 | + tf.math.lgamma(alpha) \ 38 | - tf.math.lgamma(alpha+0.5) 39 | 40 | return tf.reduce_mean(nll) if reduce else nll 41 | 42 | def KL_NIG(mu1, v1, a1, b1, mu2, v2, a2, b2): 43 | KL = 0.5*(a1-1)/b1 * 
(v2*tf.square(mu2-mu1)) \ 44 | + 0.5*v2/v1 \ 45 | - 0.5*tf.math.log(tf.abs(v2)/tf.abs(v1)) \ 46 | - 0.5 + a2*tf.math.log(b1/b2) \ 47 | - (tf.math.lgamma(a1) - tf.math.lgamma(a2)) \ 48 | + (a1 - a2)*tf.math.digamma(a1) \ 49 | - (b1 - b2)*a1/b1 50 | return KL 51 | 52 | def NIG_Reg(y, gamma, v, alpha, beta, omega=0.01, reduce=True, kl=False): 53 | # error = tf.stop_gradient(tf.abs(y-gamma)) 54 | error = tf.abs(y-gamma) 55 | 56 | if kl: 57 | kl = KL_NIG(gamma, v, alpha, beta, gamma, omega, 1+omega, beta) 58 | reg = error*kl 59 | else: 60 | evi = 2*v+(alpha) 61 | reg = error*evi 62 | 63 | return tf.reduce_mean(reg) if reduce else reg 64 | 65 | def EvidentialRegression(y_true, evidential_output, coeff=1.0): 66 | gamma, v, alpha, beta = tf.split(evidential_output, 4, axis=-1) 67 | loss_nll = NIG_NLL(y_true, gamma, v, alpha, beta) 68 | loss_reg = NIG_Reg(y_true, gamma, v, alpha, beta) 69 | return loss_nll + coeff * loss_reg 70 | -------------------------------------------------------------------------------- /hello_world.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | import evidential_deep_learning as edl 6 | import tensorflow as tf 7 | 8 | 9 | def main(): 10 | # Create some training and testing data 11 | x_train, y_train = my_data(-4, 4, 1000) 12 | x_test, y_test = my_data(-7, 7, 1000, train=False) 13 | 14 | # Define our model with an evidential output 15 | model = tf.keras.Sequential([ 16 | tf.keras.layers.Dense(64, activation="relu"), 17 | tf.keras.layers.Dense(64, activation="relu"), 18 | edl.layers.DenseNormalGamma(1), 19 | ]) 20 | 21 | # Custom loss function to handle the custom regularizer coefficient 22 | def EvidentialRegressionLoss(true, pred): 23 | return edl.losses.EvidentialRegression(true, pred, coeff=1e-2) 24 | 25 | # Compile and fit the model! 26 | model.compile( 27 | optimizer=tf.keras.optimizers.Adam(5e-4), 28 | loss=EvidentialRegressionLoss) 29 | model.fit(x_train, y_train, batch_size=100, epochs=500) 30 | 31 | # Predict and plot using the trained model 32 | y_pred = model(x_test) 33 | plot_predictions(x_train, y_train, x_test, y_test, y_pred) 34 | 35 | # Done!! 36 | 37 | 38 | #### Helper functions #### 39 | def my_data(x_min, x_max, n, train=True): 40 | x = np.linspace(x_min, x_max, n) 41 | x = np.expand_dims(x, -1).astype(np.float32) 42 | 43 | sigma = 3 * np.ones_like(x) if train else np.zeros_like(x) 44 | y = x**3 + np.random.normal(0, sigma).astype(np.float32) 45 | 46 | return x, y 47 | 48 | def plot_predictions(x_train, y_train, x_test, y_test, y_pred, n_stds=4, kk=0): 49 | x_test = x_test[:, 0] 50 | mu, v, alpha, beta = tf.split(y_pred, 4, axis=-1) 51 | mu = mu[:, 0] 52 | var = np.sqrt(beta / (v * (alpha - 1))) 53 | var = np.minimum(var, 1e3)[:, 0] # for visualization 54 | 55 | plt.figure(figsize=(5, 3), dpi=200) 56 | plt.scatter(x_train, y_train, s=1., c='#463c3c', zorder=0, label="Train") 57 | plt.plot(x_test, y_test, 'r--', zorder=2, label="True") 58 | plt.plot(x_test, mu, color='#007cab', zorder=3, label="Pred") 59 | plt.plot([-4, -4], [-150, 150], 'k--', alpha=0.4, zorder=0) 60 | plt.plot([+4, +4], [-150, 150], 'k--', alpha=0.4, zorder=0) 61 | for k in np.linspace(0, n_stds, 4): 62 | plt.fill_between( 63 | x_test, (mu - k * var), (mu + k * var), 64 | alpha=0.3, 65 | edgecolor=None, 66 | facecolor='#00aeef', 67 | linewidth=0, 68 | zorder=1, 69 | label="Unc." 
if k == 0 else None) 70 | plt.gca().set_ylim(-150, 150) 71 | plt.gca().set_xlim(-7, 7) 72 | plt.legend(loc="upper left") 73 | plt.show() 74 | 75 | 76 | if __name__ == "__main__": 77 | main() 78 | -------------------------------------------------------------------------------- /evidential_deep_learning/layers/dense.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tensorflow.keras.layers import Layer, Dense 4 | 5 | 6 | class DenseNormal(Layer): 7 | def __init__(self, units): 8 | super(DenseNormal, self).__init__() 9 | self.units = int(units) 10 | self.dense = Dense(2 * self.units) 11 | 12 | def call(self, x): 13 | output = self.dense(x) 14 | mu, logsigma = tf.split(output, 2, axis=-1) 15 | sigma = tf.nn.softplus(logsigma) + 1e-6 16 | return tf.concat([mu, sigma], axis=-1) 17 | 18 | def compute_output_shape(self, input_shape): 19 | return (input_shape[0], 2 * self.units) 20 | 21 | def get_config(self): 22 | base_config = super(DenseNormal, self).get_config() 23 | base_config['units'] = self.units 24 | return base_config 25 | 26 | 27 | class DenseNormalGamma(Layer): 28 | def __init__(self, units): 29 | super(DenseNormalGamma, self).__init__() 30 | self.units = int(units) 31 | self.dense = Dense(4 * self.units, activation=None) 32 | 33 | def evidence(self, x): 34 | # return tf.exp(x) 35 | return tf.nn.softplus(x) 36 | 37 | def call(self, x): 38 | output = self.dense(x) 39 | mu, logv, logalpha, logbeta = tf.split(output, 4, axis=-1) 40 | v = self.evidence(logv) 41 | alpha = self.evidence(logalpha) + 1 42 | beta = self.evidence(logbeta) 43 | return tf.concat([mu, v, alpha, beta], axis=-1) 44 | 45 | def compute_output_shape(self, input_shape): 46 | return (input_shape[0], 4 * self.units) 47 | 48 | def get_config(self): 49 | base_config = super(DenseNormalGamma, self).get_config() 50 | base_config['units'] = self.units 51 | return base_config 52 | 53 | 54 | class DenseDirichlet(Layer): 55 | def __init__(self, units): 56 | super(DenseDirichlet, self).__init__() 57 | self.units = int(units) 58 | self.dense = Dense(int(units)) 59 | 60 | def call(self, x): 61 | output = self.dense(x) 62 | evidence = tf.exp(output) 63 | alpha = evidence + 1 64 | prob = alpha / tf.reduce_sum(alpha, 1, keepdims=True) 65 | return tf.concat([alpha, prob], axis=-1) 66 | 67 | def compute_output_shape(self, input_shape): 68 | return (input_shape[0], 2 * self.units) 69 | 70 | 71 | class DenseSigmoid(Layer): 72 | def __init__(self, units): 73 | super(DenseSigmoid, self).__init__() 74 | self.units = int(units) 75 | self.dense = Dense(int(units)) 76 | 77 | def call(self, x): 78 | logits = self.dense(x) 79 | prob = tf.nn.sigmoid(logits) 80 | return [logits, prob] 81 | 82 | def compute_output_shape(self, input_shape): 83 | return (input_shape[0], self.units) 84 | -------------------------------------------------------------------------------- /neurips2020/README.md: -------------------------------------------------------------------------------- 1 | # Deep Evidential Regression 2 | *Alexander Amini, Wilko Schwarting, Ava Soleimany, Daniela Rus. NeurIPS 2020* 3 | 4 | This repository contains the code to reproduce all results presented in the NeurIPS submission: "Deep Evidential Regression". 
 5 | 6 | @article{amini2020deep, 7 | title={Deep evidential regression}, 8 | author={Amini, Alexander and Schwarting, Wilko and Soleimany, Ava and Rus, Daniela}, 9 | journal={Advances in Neural Information Processing Systems}, 10 | volume={33}, 11 | year={2020} 12 | } 13 | 14 | 15 | ## Setup 16 | 17 | ### Download datasets 18 | To get started, first download the relevant datasets and some pre-trained models (if you don't want to re-train from scratch). These datasets include: 19 | 1. [NYU Depth v2 dataset](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html "NYU Depth v2 dataset") 20 | 2. [Apolloscapes depth](http://apolloscape.auto/stereo.html "Apolloscapes depth") 21 | 3. UCI regression tasks (already pre-downloaded in `./data`) 22 | 23 | To download, run the following commands from a Unix shell: 24 | ``` 25 | cd evidential_deep_learning/neurips2020 26 | bash ./download_data.sh 27 | ``` 28 | We also include pre-trained models, in case reviewers would like to use these to produce results without retraining from scratch. If you would like to re-train, we provide the code to do so - we include pre-trained models here only for added convenience. 29 | 30 | ### Software environment 31 | We package our codebase into a conda environment, with all dependencies listed in `environment.yml`. To create a local copy of this environment and activate it, run the following commands: 32 | ``` 33 | conda env create -f environment.yml 34 | conda activate evidence 35 | ``` 36 | 37 | 38 | ## Reproducing Results 39 | 40 | ### Monocular Depth 41 | The easiest way to reproduce the depth results presented in the submission would be to run: 42 | ``` 43 | python gen_depth_results.py 44 | ``` 45 | This command will automatically use the pre-trained models downloaded above. If you would like to retrain the depth models from scratch, you can run: 46 | ``` 47 | python train_depth.py [--model {evidential, dropout, ensemble}] 48 | ``` 49 | Note that the path to any newly trained models should be replaced in the `trained_models` parameter within `gen_depth_results.py` if you'd like the plots to reflect changes in this newly trained model. 50 | 51 | 52 | ### UCI and Cubic Examples 53 | Results for the cubic example figures can be reproduced by running: 54 | ``` 55 | python run_cubic_tests.py 56 | ``` 57 | 58 | Results for the UCI benchmarking tasks can be reproduced by running: 59 | ``` 60 | python run_uci_dataset_tests.py [-h] [--num-trials NUM_TRIALS] 61 | [--num-epochs NUM_EPOCHS] 62 | [--datasets {yacht, ...}] 63 | ``` 64 | Enjoy! 65 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | .DS_Store 141 | ._DS_Store 142 | 143 | *.h5 144 | *.pdf 145 | 146 | logs/ 147 | save/ 148 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Evidential Deep Learning 2 | 3 |

"All models are wrong, but some — that know when they can be trusted — are useful!"

4 |

- George Box (Adapted)

4 | 5 | 6 | 7 | ![](assets/banner.png) 8 | 9 | This repository contains the code to reproduce [Deep Evidential Regression](https://proceedings.neurips.cc/paper/2020/file/aab085461de182608ee9f607f3f7d18f-Paper.pdf), as published in [NeurIPS 2020](https://neurips.cc/), as well as more general code to leverage evidential learning to train neural networks to learn their own measures of uncertainty directly from data! 10 | 11 | ## Setup 12 | To use this package, you must install the following dependencies first: 13 | - python (>=3.7) 14 | - tensorflow (>=2.0) 15 | - pytorch (support coming soon) 16 | 17 | Now you can install the package and start adding evidential layers and losses to your models! 18 | ``` 19 | pip install evidential-deep-learning 20 | ``` 21 | Now you're ready to start using this package directly as part of your existing `tf.keras` model pipelines (`Sequential`, `Functional`, or `model-subclassing`): 22 | ``` 23 | >>> import evidential_deep_learning as edl 24 | ``` 25 | 26 | ### Example 27 | To use evidential deep learning, you must edit the last layer of your model to be *evidential* and use a supported loss function to train the system end-to-end. This repository supports evidential layers for both fully connected and convolutional (2D) layers. The evidential prior distribution presented in the paper follows a Normal Inverse-Gamma and can be added to your model: 28 | 29 | ``` 30 | import evidential_deep_learning as edl 31 | import tensorflow as tf 32 | 33 | model = tf.keras.Sequential( 34 | [ 35 | tf.keras.layers.Dense(64, activation="relu"), 36 | tf.keras.layers.Dense(64, activation="relu"), 37 | edl.layers.DenseNormalGamma(1), # Evidential distribution! 38 | ] 39 | ) 40 | model.compile( 41 | optimizer=tf.keras.optimizers.Adam(1e-3), 42 | loss=edl.losses.EvidentialRegression # Evidential loss! 43 | ) 44 | ``` 45 | 46 | ![](assets/animation.gif) 47 | Check out `hello_world.py` for an end-to-end toy example walking through this step-by-step. For more complex examples, scaling up to computer vision problems (where we learn to predict tens of thousands of evidential distributions simultaneously!), please refer to the NeurIPS 2020 paper, and the reproducibility section of this repo to run those examples. 48 | 49 | ## Reproducibility 50 | All of the results published as part of our NeurIPS paper can be reproduced as part of this repository. Please refer to [the reproducibility section](./neurips2020) for details and instructions to obtain each result. 
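As a complementary sketch (illustrative only, not from the repository; the input shape is a placeholder), the convolutional head `edl.layers.Conv2DNormalGamma` follows the same pattern as the dense example above, emitting per-pixel (mu, v, alpha, beta) maps:

```python
import evidential_deep_learning as edl
import tensorflow as tf

# Minimal fully convolutional sketch with a per-pixel evidential head.
inputs = tf.keras.Input(shape=(128, 160, 3))  # placeholder input size
x = tf.keras.layers.Conv2D(32, (3, 3), activation="relu", padding="same")(inputs)
x = tf.keras.layers.Conv2D(32, (3, 3), activation="relu", padding="same")(x)
# One output channel -> 4 maps (mu, v, alpha, beta), concatenated on the last axis.
outputs = edl.layers.Conv2DNormalGamma(1, (1, 1))(x)
model = tf.keras.Model(inputs, outputs)
```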
51 | 52 | ## Citation 53 | If you use this code for evidential learning as part of your project or paper, please cite the following work: 54 | 55 | @article{amini2020deep, 56 | title={Deep evidential regression}, 57 | author={Amini, Alexander and Schwarting, Wilko and Soleimany, Ava and Rus, Daniela}, 58 | journal={Advances in Neural Information Processing Systems}, 59 | volume={33}, 60 | year={2020} 61 | } 62 | -------------------------------------------------------------------------------- /neurips2020/run_uci_dataset_tests.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import tensorflow as tf 5 | import time 6 | from scipy import stats 7 | 8 | import edl 9 | import data_loader 10 | import trainers 11 | import models 12 | from models.toy.h_params import h_params 13 | 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--num-trials", default=20, type=int, 17 | help="Number of trials to repeat training for \ 18 | statistically significant results.") 19 | parser.add_argument("--num-epochs", default=40, type=int) 20 | parser.add_argument('--datasets', nargs='+', default=["yacht"], 21 | choices=['boston', 'concrete', 'energy-efficiency', 22 | 'kin8nm', 'naval', 'power-plant', 'protein', 23 | 'wine', 'yacht']) 24 | args = parser.parse_args() 25 | 26 | """================================================""" 27 | training_schemes = [trainers.Evidential] 28 | datasets = args.datasets 29 | num_trials = args.num_trials 30 | num_epochs = args.num_epochs 31 | dev = "/cpu:0" # for small datasets/models cpu is faster than gpu 32 | """================================================""" 33 | 34 | RMSE = np.zeros((len(datasets), len(training_schemes), num_trials)) 35 | NLL = np.zeros((len(datasets), len(training_schemes), num_trials)) 36 | for di, dataset in enumerate(datasets): 37 | for ti, trainer_obj in enumerate(training_schemes): 38 | for n in range(num_trials): 39 | (x_train, y_train), (x_test, y_test), y_scale = data_loader.load_dataset(dataset, return_as_tensor=False) 40 | batch_size = h_params[dataset]["batch_size"] 41 | num_iterations = num_epochs * x_train.shape[0]//batch_size 42 | done = False 43 | while not done: 44 | with tf.device(dev): 45 | model_generator = models.get_correct_model(dataset="toy", trainer=trainer_obj) 46 | model, opts = model_generator.create(input_shape=x_train.shape[1:]) 47 | trainer = trainer_obj(model, opts, dataset, learning_rate=h_params[dataset]["learning_rate"]) 48 | model, rmse, nll = trainer.train(x_train, y_train, x_test, y_test, y_scale, iters=num_iterations, batch_size=batch_size, verbose=True) 49 | del model 50 | tf.keras.backend.clear_session() 51 | done = False if np.isinf(nll) or np.isnan(nll) else True 52 | print("saving {} {}".format(rmse, nll)) 53 | RMSE[di, ti, n] = rmse 54 | NLL[di, ti, n] = nll 55 | 56 | RESULTS = np.hstack((RMSE, NLL)) 57 | mu = RESULTS.mean(axis=-1) 58 | error = np.std(RESULTS, axis=-1) 59 | 60 | print("==========================") 61 | print("[{}]: {} pm {}".format(dataset, mu, error)) 62 | print("==========================") 63 | 64 | print("TRAINERS: {}\nDATASETS: {}".format([trainer.__name__ for trainer in training_schemes], datasets)) 65 | print("MEAN: \n{}".format(mu)) 66 | print("ERROR: \n{}".format(error)) 67 | 68 | import pdb; pdb.set_trace() 69 | -------------------------------------------------------------------------------- /neurips2020/data/uci/wine-quality/winequality.names: 
-------------------------------------------------------------------------------- 1 | Citation Request: 2 | This dataset is publicly available for research. The details are described in [Cortez et al., 2009]. 3 | Please include this citation if you plan to use this database: 4 | 5 | P. Cortez, A. Cerdeira, F. Almeida, T. Matos and J. Reis. 6 | Modeling wine preferences by data mining from physicochemical properties. 7 | In Decision Support Systems, Elsevier, 47(4):547-553. ISSN: 0167-9236. 8 | 9 | Available at: [@Elsevier] http://dx.doi.org/10.1016/j.dss.2009.05.016 10 | [Pre-press (pdf)] http://www3.dsi.uminho.pt/pcortez/winequality09.pdf 11 | [bib] http://www3.dsi.uminho.pt/pcortez/dss09.bib 12 | 13 | 1. Title: Wine Quality 14 | 15 | 2. Sources 16 | Created by: Paulo Cortez (Univ. Minho), Antonio Cerdeira, Fernando Almeida, Telmo Matos and Jose Reis (CVRVV) @ 2009 17 | 18 | 3. Past Usage: 19 | 20 | P. Cortez, A. Cerdeira, F. Almeida, T. Matos and J. Reis. 21 | Modeling wine preferences by data mining from physicochemical properties. 22 | In Decision Support Systems, Elsevier, 47(4):547-553. ISSN: 0167-9236. 23 | 24 | In the above reference, two datasets were created, using red and white wine samples. 25 | The inputs include objective tests (e.g. pH values) and the output is based on sensory data 26 | (median of at least 3 evaluations made by wine experts). Each expert graded the wine quality 27 | between 0 (very bad) and 10 (very excellent). Several data mining methods were applied to model 28 | these datasets under a regression approach. The support vector machine model achieved the 29 | best results. Several metrics were computed: MAD, confusion matrix for a fixed error tolerance (T), 30 | etc. Also, we plot the relative importances of the input variables (as measured by a sensitivity 31 | analysis procedure). 32 | 33 | 4. Relevant Information: 34 | 35 | The two datasets are related to red and white variants of the Portuguese "Vinho Verde" wine. 36 | For more details, consult: http://www.vinhoverde.pt/en/ or the reference [Cortez et al., 2009]. 37 | Due to privacy and logistic issues, only physicochemical (inputs) and sensory (the output) variables 38 | are available (e.g. there is no data about grape types, wine brand, wine selling price, etc.). 39 | 40 | These datasets can be viewed as classification or regression tasks. 41 | The classes are ordered and not balanced (e.g. there are many more normal wines than 42 | excellent or poor ones). Outlier detection algorithms could be used to detect the few excellent 43 | or poor wines. Also, we are not sure if all input variables are relevant. So 44 | it could be interesting to test feature selection methods. 45 | 46 | 5. Number of Instances: red wine - 1599; white wine - 4898. 47 | 48 | 6. Number of Attributes: 11 + output attribute 49 | 50 | Note: several of the attributes may be correlated, thus it makes sense to apply some sort of 51 | feature selection. 52 | 53 | 7. Attribute information: 54 | 55 | For more information, read [Cortez et al., 2009]. 56 | 57 | Input variables (based on physicochemical tests): 58 | 1 - fixed acidity 59 | 2 - volatile acidity 60 | 3 - citric acid 61 | 4 - residual sugar 62 | 5 - chlorides 63 | 6 - free sulfur dioxide 64 | 7 - total sulfur dioxide 65 | 8 - density 66 | 9 - pH 67 | 10 - sulphates 68 | 11 - alcohol 69 | Output variable (based on sensory data): 70 | 12 - quality (score between 0 and 10) 71 | 72 | 8. 
/neurips2020/models/depth/evidential.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | # tf.enable_eager_execution()
3 | from tensorflow.keras.layers import Conv2D, MaxPooling2D, \
4 |     UpSampling2D, Cropping2D, concatenate, ZeroPadding2D, SpatialDropout2D
5 | import functools
6 | from evidential_deep_learning.layers import Conv2DNormalGamma
7 | 
8 | def create(input_shape, activation=tf.nn.relu, num_class=1):
9 |     opts = locals().copy()
10 | 
11 |     concat_axis = 3
12 |     inputs = tf.keras.layers.Input(shape=input_shape)
13 | 
14 |     Conv2D_ = functools.partial(Conv2D, activation=activation, padding='same')
15 | 
16 |     conv1 = Conv2D_(32, (3, 3), name='conv1_1')(inputs)
17 |     conv1 = Conv2D_(32, (3, 3))(conv1)
18 |     pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
19 | 
20 |     conv2 = Conv2D_(64, (3, 3))(pool1)
21 |     conv2 = Conv2D_(64, (3, 3))(conv2)
22 |     pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
23 | 
24 |     conv3 = Conv2D_(128, (3, 3))(pool2)
25 |     conv3 = Conv2D_(128, (3, 3))(conv3)
26 |     pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
27 | 
28 |     conv4 = Conv2D_(256, (3, 3))(pool3)
29 |     conv4 = Conv2D_(256, (3, 3))(conv4)
30 |     pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
31 | 
32 |     conv5 = Conv2D_(512, (3, 3))(pool4)
33 |     conv5 = Conv2D_(512, (3, 3))(conv5)
34 | 
35 |     up_conv5 = UpSampling2D(size=(2, 2))(conv5)
36 |     ch, cw = get_crop_shape(conv4, up_conv5)
37 |     crop_conv4 = Cropping2D(cropping=(ch,cw))(conv4)
38 |     up6 = concatenate([up_conv5, crop_conv4], axis=concat_axis)
39 |     conv6 = Conv2D_(256, (3, 3))(up6)
40 |     conv6 = Conv2D_(256, (3, 3))(conv6)
41 | 
42 |     up_conv6 = UpSampling2D(size=(2, 2))(conv6)
43 |     ch, cw = get_crop_shape(conv3, up_conv6)
44 |     crop_conv3 = Cropping2D(cropping=(ch,cw))(conv3)
45 |     up7 = concatenate([up_conv6, crop_conv3], axis=concat_axis)
46 |     conv7 = Conv2D_(128, (3, 3))(up7)
47 |     conv7 = Conv2D_(128, (3, 3))(conv7)
48 | 
49 |     up_conv7 = UpSampling2D(size=(2, 2))(conv7)
50 |     ch, cw = get_crop_shape(conv2, up_conv7)
51 |     crop_conv2 = Cropping2D(cropping=(ch,cw))(conv2)
52 |     up8 = concatenate([up_conv7, crop_conv2], axis=concat_axis)
53 |     conv8 = Conv2D_(64, (3, 3))(up8)
54 |     conv8 = Conv2D_(64, (3, 3))(conv8)
55 | 
56 |     up_conv8 = UpSampling2D(size=(2, 2))(conv8)
57 |     ch, cw = get_crop_shape(conv1, up_conv8)
58 |     crop_conv1 = Cropping2D(cropping=(ch,cw))(conv1)
59 |     up9 = concatenate([up_conv8, crop_conv1], axis=concat_axis)
60 |     conv9 = Conv2D_(32, (3, 3))(up9)
61 |     conv9 = Conv2D_(32, (3, 3))(conv9)
62 | 
63 |     ch, cw = get_crop_shape(inputs, conv9)
64 |     conv9 = ZeroPadding2D(padding=((ch[0], ch[1]), (cw[0], cw[1])))(conv9)
65 |     conv10 = Conv2D_(4*num_class, (1, 1))(conv9)
66 |     evidential_output = Conv2DNormalGamma(num_class, (1, 1))(conv10)
67 | 
68 |     model = tf.keras.models.Model(inputs=inputs, outputs=evidential_output)
69 |     return model, opts
70 | 
71 | def get_crop_shape(target, refer):
72 |     # width, the 3rd dimension
73 |     cw = (target.get_shape()[2] - refer.get_shape()[2])
74 |     assert (cw >= 0)
75 |     if cw % 2 != 0:
76 |         cw1, cw2 = int(cw/2), int(cw/2) + 1
77 |     else:
78 |         cw1, cw2 = int(cw/2), int(cw/2)
79 |     # height, the 2nd dimension
80 |     ch = (target.get_shape()[1] - refer.get_shape()[1])
81 |     assert (ch >= 0)
82 |     if ch % 2 != 0:
83 |         ch1, ch2 = int(ch/2), int(ch/2) + 1
84 |     else:
85 |         ch1, ch2 = int(ch/2), int(ch/2)
86 | 
87 |     return (ch1, ch2), (cw1, cw2)
88 | 
89 | # import numpy as np
90 | # 
model = create((64,64,3), 2) 91 | # x = np.ones((1,64,64,3), dtype=np.float32) 92 | # output = model(x) 93 | # import pdb; pdb.set_trace() 94 | -------------------------------------------------------------------------------- /neurips2020/models/depth/bbbp.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow_probability as tfp 3 | from tensorflow.keras.layers import Conv2D, MaxPooling2D, \ 4 | UpSampling2D, Cropping2D, concatenate, ZeroPadding2D, SpatialDropout2D 5 | 6 | import functools 7 | 8 | def create(input_shape, num_class=1, activation=tf.nn.relu): 9 | opts = locals().copy() 10 | 11 | # model = Depth_BBBP(num_class, activation) 12 | # return model, opts 13 | 14 | concat_axis = 3 15 | inputs = tf.keras.layers.Input(shape=input_shape) 16 | 17 | Conv2D_ = functools.partial(tfp.layers.Convolution2DReparameterization, activation=activation, padding='same') 18 | 19 | conv1 = Conv2D_(32, (3, 3))(inputs) 20 | conv1 = Conv2D_(32, (3, 3))(conv1) 21 | pool1 = MaxPooling2D(pool_size=(2, 2))(conv1) 22 | 23 | conv2 = Conv2D_(64, (3, 3))(pool1) 24 | conv2 = Conv2D_(64, (3, 3))(conv2) 25 | pool2 = MaxPooling2D(pool_size=(2, 2))(conv2) 26 | 27 | conv3 = Conv2D_(128, (3, 3))(pool2) 28 | conv3 = Conv2D_(128, (3, 3))(conv3) 29 | pool3 = MaxPooling2D(pool_size=(2, 2))(conv3) 30 | 31 | conv4 = Conv2D_(256, (3, 3))(pool3) 32 | conv4 = Conv2D_(256, (3, 3))(conv4) 33 | pool4 = MaxPooling2D(pool_size=(2, 2))(conv4) 34 | 35 | conv5 = Conv2D_(512, (3, 3))(pool4) 36 | conv5 = Conv2D_(512, (3, 3))(conv5) 37 | 38 | up_conv5 = UpSampling2D(size=(2, 2))(conv5) 39 | ch, cw = get_crop_shape(conv4, up_conv5) 40 | crop_conv4 = Cropping2D(cropping=(ch,cw))(conv4) 41 | up6 = concatenate([up_conv5, crop_conv4], axis=concat_axis) 42 | conv6 = Conv2D_(256, (3, 3))(up6) 43 | conv6 = Conv2D_(256, (3, 3))(conv6) 44 | 45 | up_conv6 = UpSampling2D(size=(2, 2))(conv6) 46 | ch, cw = get_crop_shape(conv3, up_conv6) 47 | crop_conv3 = Cropping2D(cropping=(ch,cw))(conv3) 48 | up7 = concatenate([up_conv6, crop_conv3], axis=concat_axis) 49 | conv7 = Conv2D_(128, (3, 3))(up7) 50 | conv7 = Conv2D_(128, (3, 3))(conv7) 51 | 52 | up_conv7 = UpSampling2D(size=(2, 2))(conv7) 53 | ch, cw = get_crop_shape(conv2, up_conv7) 54 | crop_conv2 = Cropping2D(cropping=(ch,cw))(conv2) 55 | up8 = concatenate([up_conv7, crop_conv2], axis=concat_axis) 56 | conv8 = Conv2D_(64, (3, 3))(up8) 57 | conv8 = Conv2D_(64, (3, 3))(conv8) 58 | 59 | up_conv8 = UpSampling2D(size=(2, 2))(conv8) 60 | ch, cw = get_crop_shape(conv1, up_conv8) 61 | crop_conv1 = Cropping2D(cropping=(ch,cw))(conv1) 62 | up9 = concatenate([up_conv8, crop_conv1], axis=concat_axis) 63 | conv9 = Conv2D_(32, (3, 3))(up9) 64 | conv9 = Conv2D_(32, (3, 3))(conv9) 65 | 66 | ch, cw = get_crop_shape(inputs, conv9) 67 | conv9 = ZeroPadding2D(padding=((ch[0], ch[1]), (cw[0], cw[1])))(conv9) 68 | conv10 = Conv2D(num_class, (1, 1))(conv9) 69 | conv10 = 1e-6 * conv10 70 | 71 | model = tf.keras.models.Model(inputs=inputs, outputs=conv10) 72 | return model, opts 73 | 74 | def get_crop_shape(target, refer): 75 | # width, the 3rd dimension 76 | cw = (target.get_shape()[2] - refer.get_shape()[2]) 77 | assert (cw >= 0) 78 | if cw % 2 != 0: 79 | cw1, cw2 = int(cw/2), int(cw/2) + 1 80 | else: 81 | cw1, cw2 = int(cw/2), int(cw/2) 82 | # height, the 2nd dimension 83 | ch = (target.get_shape()[1] - refer.get_shape()[1]) 84 | assert (ch >= 0) 85 | if ch % 2 != 0: 86 | ch1, ch2 = int(ch/2), int(ch/2) + 1 87 | else: 88 | ch1, ch2 = int(ch/2), 
int(ch/2) 89 | 90 | return (ch1, ch2), (cw1, cw2) 91 | # 92 | # # import numpy as np 93 | # # model = create((64,64,3), 2) 94 | # # x = np.ones((1,64,64,3), dtype=np.float32) 95 | # # output = model(x) 96 | # # import pdb; pdb.set_trace() 97 | -------------------------------------------------------------------------------- /neurips2020/models/depth/dropout.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras.layers import Conv2D, MaxPooling2D, \ 3 | UpSampling2D, Cropping2D, concatenate, ZeroPadding2D, SpatialDropout2D 4 | import functools 5 | from evidential_deep_learning.layers import Conv2DNormal 6 | 7 | def create(input_shape, drop_prob=0.1, reg=None, sigma=False, activation=tf.nn.relu, num_class=1, lam=1e-3, l=0.5): 8 | opts = locals().copy() 9 | 10 | concat_axis = 3 11 | inputs = tf.keras.layers.Input(shape=input_shape) 12 | # inputs_normalized = tf.multiply(inputs, 1/255.) 13 | 14 | Conv2D_ = functools.partial(Conv2D, activation=activation, padding='same', kernel_regularizer=reg, bias_regularizer=reg) 15 | 16 | conv1 = Conv2D_(32, (3, 3))(inputs) 17 | conv1 = Conv2D_(32, (3, 3))(conv1) 18 | pool1 = MaxPooling2D(pool_size=(2, 2))(conv1) 19 | pool1 = SpatialDropout2D(drop_prob)(pool1) 20 | 21 | conv2 = Conv2D_(64, (3, 3))(pool1) 22 | conv2 = Conv2D_(64, (3, 3))(conv2) 23 | pool2 = MaxPooling2D(pool_size=(2, 2))(conv2) 24 | pool2 = SpatialDropout2D(drop_prob)(pool2) 25 | 26 | conv3 = Conv2D_(128, (3, 3))(pool2) 27 | conv3 = Conv2D_(128, (3, 3))(conv3) 28 | pool3 = MaxPooling2D(pool_size=(2, 2))(conv3) 29 | pool3 = SpatialDropout2D(drop_prob)(pool3) 30 | 31 | conv4 = Conv2D_(256, (3, 3))(pool3) 32 | conv4 = Conv2D_(256, (3, 3))(conv4) 33 | pool4 = MaxPooling2D(pool_size=(2, 2))(conv4) 34 | pool4 = SpatialDropout2D(drop_prob)(pool4) 35 | 36 | conv5 = Conv2D_(512, (3, 3))(pool4) 37 | conv5 = Conv2D_(512, (3, 3))(conv5) 38 | 39 | up_conv5 = UpSampling2D(size=(2, 2))(conv5) 40 | ch, cw = get_crop_shape(conv4, up_conv5) 41 | crop_conv4 = Cropping2D(cropping=(ch,cw))(conv4) 42 | up6 = concatenate([up_conv5, crop_conv4], axis=concat_axis) 43 | conv6 = Conv2D_(256, (3, 3))(up6) 44 | conv6 = Conv2D_(256, (3, 3))(conv6) 45 | 46 | up_conv6 = UpSampling2D(size=(2, 2))(conv6) 47 | ch, cw = get_crop_shape(conv3, up_conv6) 48 | crop_conv3 = Cropping2D(cropping=(ch,cw))(conv3) 49 | up7 = concatenate([up_conv6, crop_conv3], axis=concat_axis) 50 | conv7 = Conv2D_(128, (3, 3))(up7) 51 | conv7 = Conv2D_(128, (3, 3))(conv7) 52 | 53 | up_conv7 = UpSampling2D(size=(2, 2))(conv7) 54 | ch, cw = get_crop_shape(conv2, up_conv7) 55 | crop_conv2 = Cropping2D(cropping=(ch,cw))(conv2) 56 | up8 = concatenate([up_conv7, crop_conv2], axis=concat_axis) 57 | conv8 = Conv2D_(64, (3, 3))(up8) 58 | conv8 = Conv2D_(64, (3, 3))(conv8) 59 | 60 | up_conv8 = UpSampling2D(size=(2, 2))(conv8) 61 | ch, cw = get_crop_shape(conv1, up_conv8) 62 | crop_conv1 = Cropping2D(cropping=(ch,cw))(conv1) 63 | up9 = concatenate([up_conv8, crop_conv1], axis=concat_axis) 64 | conv9 = Conv2D_(32, (3, 3))(up9) 65 | conv9 = Conv2D_(32, (3, 3))(conv9) 66 | 67 | ch, cw = get_crop_shape(inputs, conv9) 68 | conv9 = ZeroPadding2D(padding=((ch[0], ch[1]), (cw[0], cw[1])))(conv9) 69 | if sigma: 70 | conv10 = Conv2DNormal(num_class, (1, 1))(conv9) 71 | else: 72 | conv10 = Conv2D(num_class, (1, 1))(conv9) 73 | 74 | # conv10 = tf.multiply(conv10, 255.) 
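    # Note on the two heads above: with sigma=True the Conv2DNormal layer predicts a
    # per-pixel mean and scale, so the Dropout trainer can use a heteroscedastic
    # Gaussian NLL; with sigma=False the plain Conv2D head is trained with MSE and
    # predictive variance comes from MC-dropout sampling alone.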
75 |     model = tf.keras.models.Model(inputs=inputs, outputs=conv10)
76 |     return model, opts
77 | 
78 | def get_crop_shape(target, refer):
79 |     # width, the 3rd dimension
80 |     cw = (target.get_shape()[2] - refer.get_shape()[2])
81 |     assert (cw >= 0)
82 |     if cw % 2 != 0:
83 |         cw1, cw2 = int(cw/2), int(cw/2) + 1
84 |     else:
85 |         cw1, cw2 = int(cw/2), int(cw/2)
86 |     # height, the 2nd dimension
87 |     ch = (target.get_shape()[1] - refer.get_shape()[1])
88 |     assert (ch >= 0)
89 |     if ch % 2 != 0:
90 |         ch1, ch2 = int(ch/2), int(ch/2) + 1
91 |     else:
92 |         ch1, ch2 = int(ch/2), int(ch/2)
93 | 
94 |     return (ch1, ch2), (cw1, cw2)
95 | 
96 | # import numpy as np
97 | # model = create((64,64,3), num_class=2)
98 | # x = np.ones((1,64,64,3), dtype=np.float32)
99 | # output = model(x)
100 | # import pdb; pdb.set_trace()
--------------------------------------------------------------------------------
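The dropout model above is paired with the MC-dropout trainer (neurips2020/trainers/dropout.py, later in this listing), which estimates predictive uncertainty from repeated stochastic forward passes. A minimal sketch of that estimate — not part of the repository — using the same tau correction the trainer computes from the prior length-scale l, keep probability, and weight decay lam:

    import tensorflow as tf

    def mc_dropout_predict(model, x, T=5, l=0.5, drop_prob=0.1, lam=1e-3):
        # T stochastic passes with dropout kept active (training=True)
        preds = tf.stack([model(x, training=True) for _ in range(T)], axis=0)
        mu, var = tf.nn.moments(preds, axes=0)
        tau = l**2 * (1 - drop_prob) / (2. * lam)  # model precision (Gal & Ghahramani)
        var += 1. / tau                            # add the tau^-1 aleatoric term
        return mu, var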
/neurips2020/data/uci/concrete/Concrete_Readme.txt:
--------------------------------------------------------------------------------
1 | Concrete Compressive Strength
2 | 
3 | ---------------------------------
4 | 
5 | Data Type: multivariate
6 | 
7 | Abstract: Concrete is the most important material in civil engineering. The
8 | concrete compressive strength is a highly nonlinear function of age and
9 | ingredients. These ingredients include cement, blast furnace slag, fly ash,
10 | water, superplasticizer, coarse aggregate, and fine aggregate.
11 | 
12 | ---------------------------------
13 | 
14 | Sources:
15 | 
16 | Original Owner and Donor
17 | Prof. I-Cheng Yeh
18 | Department of Information Management
19 | Chung-Hua University,
20 | Hsin Chu, Taiwan 30067, R.O.C.
21 | e-mail: icyeh@chu.edu.tw
22 | TEL: 886-3-5186511
23 | 
24 | Date Donated: August 3, 2007
25 | 
26 | ---------------------------------
27 | 
28 | Data Characteristics:
29 | 
30 | The actual concrete compressive strength (MPa) for a given mixture under a
31 | specific age (days) was determined in the laboratory. Data is in raw form (not scaled).
32 | 
33 | Summary Statistics:
34 | 
35 | Number of instances (observations): 1030
36 | Number of Attributes: 9
37 | Attribute breakdown: 8 quantitative input variables, and 1 quantitative output variable
38 | Missing Attribute Values: None
39 | 
40 | ---------------------------------
41 | 
42 | Variable Information:
43 | 
44 | Given is the variable name, variable type, the measurement unit and a brief description.
45 | The concrete compressive strength is the regression target. The order of this listing
46 | corresponds to the order of numerals along the rows of the database.
47 | 
48 | Name -- Data Type -- Measurement -- Description
49 | 
50 | Cement (component 1) -- quantitative -- kg in a m3 mixture -- Input Variable
51 | Blast Furnace Slag (component 2) -- quantitative -- kg in a m3 mixture -- Input Variable
52 | Fly Ash (component 3) -- quantitative -- kg in a m3 mixture -- Input Variable
53 | Water (component 4) -- quantitative -- kg in a m3 mixture -- Input Variable
54 | Superplasticizer (component 5) -- quantitative -- kg in a m3 mixture -- Input Variable
55 | Coarse Aggregate (component 6) -- quantitative -- kg in a m3 mixture -- Input Variable
56 | Fine Aggregate (component 7) -- quantitative -- kg in a m3 mixture -- Input Variable
57 | Age -- quantitative -- Day (1~365) -- Input Variable
58 | Concrete compressive strength -- quantitative -- MPa -- Output Variable
59 | ---------------------------------
60 | 
61 | Past Usage:
62 | 
63 | Main
64 | 1. I-Cheng Yeh, "Modeling of strength of high performance concrete using artificial
65 | neural networks," Cement and Concrete Research, Vol. 28, No. 12, pp. 1797-1808 (1998).
66 | 
67 | Others
68 | 2. I-Cheng Yeh, "Modeling Concrete Strength with Augment-Neuron Networks," J. of
69 | Materials in Civil Engineering, ASCE, Vol. 10, No. 4, pp. 263-268 (1998).
70 | 
71 | 3. I-Cheng Yeh, "Design of High Performance Concrete Mixture Using Neural Networks,"
72 | J. of Computing in Civil Engineering, ASCE, Vol. 13, No. 1, pp. 36-42 (1999).
73 | 
74 | 4. I-Cheng Yeh, "Prediction of Strength of Fly Ash and Slag Concrete By The Use of
75 | Artificial Neural Networks," Journal of the Chinese Institute of Civil and Hydraulic
76 | Engineering, Vol. 15, No. 4, pp. 659-663 (2003).
77 | 
78 | 5. I-Cheng Yeh, "A mix Proportioning Methodology for Fly Ash and Slag Concrete Using
79 | Artificial Neural Networks," Chung Hua Journal of Science and Engineering, Vol. 1, No.
80 | 1, pp. 77-84 (2003).
81 | 
82 | 6. Yeh, I-Cheng, "Analysis of strength of concrete using design of experiments and
83 | neural networks," Journal of Materials in Civil Engineering, ASCE, Vol. 18, No. 4,
84 | pp. 597-604 (2006).
85 | 
86 | ---------------------------------
87 | 
88 | Acknowledgements, Copyright Information, and Availability:
89 | 
90 | NOTE: Reuse of this database is unlimited with retention of copyright notice for
91 | Prof. I-Cheng Yeh and the following published paper:
92 | 
93 | I-Cheng Yeh, "Modeling of strength of high performance concrete using artificial
94 | neural networks," Cement and Concrete Research, Vol. 28, No. 12, pp. 1797-1808 (1998)
95 | 
96 | 
97 | 
98 | 
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: evidence
2 | channels:
3 |   - plotly
4 |   - conda-forge
5 |   - anaconda
6 |   - defaults
7 | dependencies:
8 |   - _libgcc_mutex=0.1=main
9 |   - _tflow_select=2.1.0=gpu
10 |   - absl-py=0.9.0=py37_0
11 |   - astor=0.8.0=py37_0
12 |   - attrs=19.3.0=py_0
13 |   - blas=1.0=mkl
14 |   - blinker=1.4=py37_0
15 |   - bzip2=1.0.8=h7b6447c_0
16 |   - c-ares=1.15.0=h7b6447c_1001
17 |   - ca-certificates=2020.1.1=0
18 |   - cachetools=3.1.1=py_0
19 |   - cairo=1.16.0=h18b612c_1001
20 |   - certifi=2020.4.5.1=py37_0
21 |   - cffi=1.14.0=py37he30daa8_1
22 |   - cftime=1.1.2=py37heb32a55_0
23 |   - chardet=3.0.4=py37_1003
24 |   - click=7.1.2=py_0
25 |   - cryptography=2.9.2=py37h1ba5d50_0
26 |   - cudatoolkit=10.1.243=h6bb024c_0
27 |   - cudnn=7.6.5=cuda10.1_0
28 |   - cupti=10.1.168=0
29 |   - curl=7.69.1=hbc83047_0
30 |   - cycler=0.10.0=py37_0
31 |   - dbus=1.13.14=hb2f20db_0
32 |   - decorator=4.4.2=py_0
33 |   - dill=0.3.1.1=py37_0
34 |   - expat=2.2.6=he6710b0_0
35 |   - ffmpeg=4.1.3=h167e202_0
36 |   - fontconfig=2.13.1=he4413a7_1000
37 |   - freetype=2.9.1=h8a8886c_1
38 |   - future=0.18.2=py37_0
39 |   - giflib=5.1.4=h14c3975_1
40 |   - glib=2.63.1=h3eb4bd4_1
41 |   - gmp=6.1.2=h6c8ec71_1
42 |   - gnutls=3.6.5=h71b1129_1002
43 |   - google-auth=1.14.1=py_0
44 |   - google-auth-oauthlib=0.4.1=py_2
45 |   - google-pasta=0.2.0=py_0
46 |   - googleapis-common-protos=1.51.0=py37_2
47 |   - graphite2=1.3.13=h23475e2_0
48 |   - grpcio=1.27.2=py37hf8bcb03_0
49 |   - gst-plugins-base=1.14.0=hbbd80ab_1
50 |   - gstreamer=1.14.0=hb31296c_0
51 |   - h5py=2.10.0=py37h7918eee_0
52 |   - harfbuzz=2.4.0=h37c48d4_1
53 |   - hdf4=4.2.13=h3ca952b_2
54 |   - hdf5=1.10.4=hb1b8bf9_0
55 |   - icu=58.2=he6710b0_3
56 |   - idna=2.9=py_1
57 |   - intel-openmp=2020.1=217
58 |   - jasper=1.900.1=hd497a04_4
59 |   - 
jinja2=2.11.2=py_0 60 | - joblib=0.15.1=py_0 61 | - jpeg=9c=h14c3975_1001 62 | - keras-applications=1.0.8=py_0 63 | - keras-preprocessing=1.1.0=py_1 64 | - kiwisolver=1.2.0=py37hfd86e86_0 65 | - krb5=1.17.1=h173b8e3_0 66 | - lame=3.100=h7b6447c_0 67 | - ld_impl_linux-64=2.33.1=h53a641e_7 68 | - libblas=3.8.0=15_mkl 69 | - libcblas=3.8.0=15_mkl 70 | - libcurl=7.69.1=h20c2e04_0 71 | - libedit=3.1.20181209=hc058e9b_0 72 | - libffi=3.3=he6710b0_1 73 | - libgcc-ng=9.1.0=hdf63c60_0 74 | - libgfortran-ng=7.3.0=hdf63c60_0 75 | - libiconv=1.15=h63c8f33_5 76 | - liblapack=3.8.0=15_mkl 77 | - liblapacke=3.8.0=15_mkl 78 | - libnetcdf=4.7.3=hb80b6cc_0 79 | - libpng=1.6.37=hbc83047_0 80 | - libprotobuf=3.11.4=hd408876_0 81 | - libssh2=1.9.0=h1ba5d50_1 82 | - libstdcxx-ng=9.1.0=hdf63c60_0 83 | - libtiff=4.1.0=h2733197_0 84 | - libuuid=2.32.1=h14c3975_1000 85 | - libwebp=1.0.1=h8e7db2f_0 86 | - libxcb=1.13=h1bed415_1 87 | - libxml2=2.9.9=hea5a465_1 88 | - markdown=3.1.1=py37_0 89 | - markupsafe=1.1.1=py37h7b6447c_0 90 | - matplotlib=3.1.3=py37_0 91 | - matplotlib-base=3.1.3=py37hef1b27d_0 92 | - meshio=4.0.13=py_0 93 | - mkl=2020.1=217 94 | - mkl-service=2.3.0=py37he904b0f_0 95 | - mkl_fft=1.0.15=py37ha843d7b_0 96 | - mkl_random=1.1.1=py37h0573a6f_0 97 | - mpld3=0.3=py37_0 98 | - ncurses=6.2=he6710b0_1 99 | - netcdf4=1.5.3=py37hbf33ddf_0 100 | - nettle=3.4.1=hbb512f6_0 101 | - numpy=1.18.1=py37h4f9e942_0 102 | - numpy-base=1.18.1=py37hde5b4d6_1 103 | - oauthlib=3.1.0=py_0 104 | - olefile=0.46=py37_0 105 | - opencv=4.1.0=py37h5517eff_4 106 | - openh264=1.8.0=hd408876_0 107 | - openssl=1.1.1g=h7b6447c_0 108 | - opt_einsum=3.1.0=py_0 109 | - pandas=1.0.3=py37h0573a6f_0 110 | - patsy=0.5.1=py37_0 111 | - pcre=8.43=he6710b0_0 112 | - pillow=7.1.2=py37hb39fc2d_0 113 | - pip=20.0.2=py37_3 114 | - pixman=0.38.0=h7b6447c_0 115 | - plotly=4.6.0=py_0 116 | - plotly-orca=1.3.1=1 117 | - promise=2.2.1=py37_0 118 | - protobuf=3.11.4=py37he6710b0_0 119 | - psutil=5.7.0=py37h7b6447c_0 120 | - pyasn1=0.4.8=py_0 121 | - pyasn1-modules=0.2.7=py_0 122 | - pycparser=2.20=py_0 123 | - pyjwt=1.7.1=py37_0 124 | - pyopenssl=19.1.0=py37_0 125 | - pyparsing=2.4.7=py_0 126 | - pyqt=5.9.2=py37h05f1152_2 127 | - pysocks=1.7.1=py37_0 128 | - python=3.7.7=hcff3b4d_5 129 | - python-dateutil=2.8.1=py_0 130 | - pytz=2020.1=py_0 131 | - qt=5.9.7=h5867ecd_1 132 | - readline=8.0=h7b6447c_0 133 | - requests=2.23.0=py37_0 134 | - requests-oauthlib=1.3.0=py_0 135 | - retrying=1.3.3=py37_2 136 | - rsa=4.0=py_0 137 | - scikit-learn=0.22.1=py37hd81dba3_0 138 | - scipy=1.4.1=py37h0b6359f_0 139 | - seaborn=0.10.1=py_0 140 | - setuptools=46.4.0=py37_0 141 | - sip=4.19.8=py37hf484d3e_0 142 | - six=1.14.0=py37_0 143 | - sqlite=3.31.1=h62c20be_1 144 | - statsmodels=0.11.1=py37h7b6447c_0 145 | - tensorboard=2.2.1=pyh532a8cf_0 146 | - tensorboard-plugin-wit=1.6.0=py_0 147 | - tensorflow=2.1.0=gpu_py37h7a4bb67_0 148 | - tensorflow-base=2.1.0=gpu_py37h6c5654b_0 149 | - tensorflow-datasets=1.2.0=py37_0 150 | - tensorflow-estimator=2.1.0=pyhd54b08b_0 151 | - tensorflow-gpu=2.1.0=h0d30ee6_0 152 | - tensorflow-metadata=0.14.0=pyhe6710b0_1 153 | - tensorflow-probability=0.8.0=py_0 154 | - termcolor=1.1.0=py37_1 155 | - tk=8.6.8=hbc83047_0 156 | - tornado=6.0.4=py37h7b6447c_1 157 | - tqdm=4.46.0=py_0 158 | - urllib3=1.25.8=py37_0 159 | - werkzeug=1.0.1=py_0 160 | - wheel=0.34.2=py37_0 161 | - wrapt=1.12.1=py37h7b6447c_1 162 | - x264=1!152.20180806=h7b6447c_0 163 | - xlrd=1.2.0=py37_0 164 | - xorg-kbproto=1.0.7=h14c3975_1002 165 | - xorg-libice=1.0.10=h516909a_0 
166 |   - xorg-libsm=1.2.3=h84519dc_1000
167 |   - xorg-libx11=1.6.9=h516909a_0
168 |   - xorg-libxext=1.3.4=h516909a_0
169 |   - xorg-libxrender=0.9.10=h516909a_1002
170 |   - xorg-renderproto=0.11.1=h14c3975_1002
171 |   - xorg-xextproto=7.3.0=h14c3975_1002
172 |   - xorg-xproto=7.0.31=h14c3975_1007
173 |   - xz=5.2.5=h7b6447c_0
174 |   - zlib=1.2.11=h7b6447c_3
175 |   - zstd=1.3.7=h0b5b093_0
176 |   - pip:
177 |     - cloudpickle==1.4.1
178 |     - gast==0.3.3
179 |     - importlib-metadata==1.6.0
180 |     - zipp==3.1.0
181 | prefix: /home/amini/miniconda3/envs/evidence
182 | 
183 | 
--------------------------------------------------------------------------------
/neurips2020/trainers/deterministic.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | import time
4 | import datetime
5 | import os
6 | import sys
7 | import h5py
8 | from pathlib import Path
9 | 
10 | import evidential_deep_learning as edl
11 | from .util import normalize, gallery
12 | 
13 | class Deterministic:
14 |     def __init__(self, model, opts, dataset="", learning_rate=1e-3, tag=""):
15 |         self.loss_function = edl.losses.MSE
16 | 
17 |         self.model = model
18 | 
19 |         self.optimizer = tf.optimizers.Adam(learning_rate)
20 | 
21 |         self.min_rmse = float('inf')
22 |         self.min_nll = -float('inf')  # a deterministic point prediction has infinite likelihood, i.e. an NLL of -inf
23 |         self.min_vloss = float('inf')
24 | 
25 |         trainer = self.__class__.__name__
26 |         current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
27 |         self.save_dir = os.path.join('save','{}_{}_{}_{}'.format(current_time, dataset, trainer, tag))
28 |         Path(self.save_dir).mkdir(parents=True, exist_ok=True)
29 | 
30 |         train_log_dir = os.path.join('logs', '{}_{}_{}_{}_train'.format(current_time, dataset, trainer, tag))
31 |         self.train_summary_writer = tf.summary.create_file_writer(train_log_dir)
32 |         val_log_dir = os.path.join('logs', '{}_{}_{}_{}_val'.format(current_time, dataset, trainer, tag))
33 |         self.val_summary_writer = tf.summary.create_file_writer(val_log_dir)
34 | 
35 |     @tf.function
36 |     def run_train_step(self, x, y):
37 |         with tf.GradientTape() as tape:
38 |             y_hat = self.model(x, training=True)  # forward pass
39 |             loss = self.loss_function(y, y_hat)
40 |         grads = tape.gradient(loss, self.model.variables)  # compute gradient
41 |         self.optimizer.apply_gradients(zip(grads, self.model.variables))
42 | 
43 |         return loss, y_hat
44 | 
45 |     @tf.function
46 |     def evaluate(self, x, y):
47 |         y_hat = self.model(x, training=False)  # inference-mode forward pass
48 |         rmse = edl.losses.RMSE(y, y_hat)
49 |         loss = self.loss_function(y, y_hat)
50 | 
51 |         return y_hat, loss, rmse
52 | 
53 |     def save_train_summary(self, loss, x, y, y_hat):
54 |         with self.train_summary_writer.as_default():
55 |             tf.summary.scalar('mse', tf.reduce_mean(edl.losses.MSE(y, y_hat)), step=self.iter)
56 |             tf.summary.scalar('loss', tf.reduce_mean(loss), step=self.iter)
57 | 
58 |             idx = np.random.choice(int(tf.shape(x)[0]), 9)
59 |             if tf.shape(x).shape==4:
60 |                 tf.summary.image("x", [gallery(tf.gather(x,idx).numpy())], max_outputs=1, step=self.iter)
61 | 
62 |             if tf.shape(y).shape==4:
63 |                 tf.summary.image("y", [gallery(tf.gather(y,idx).numpy())], max_outputs=1, step=self.iter)
64 |                 tf.summary.image("y_hat", [gallery(tf.gather(y_hat,idx).numpy())], max_outputs=1, step=self.iter)
65 | 
66 |     def save_val_summary(self, loss, x, y, y_hat):
67 |         with self.val_summary_writer.as_default():
68 |             tf.summary.scalar('mse', tf.reduce_mean(edl.losses.MSE(y, y_hat)), step=self.iter)
69 |             tf.summary.scalar('loss', tf.reduce_mean(self.loss_function(y, y_hat)), step=self.iter)
70 |             idx = np.random.choice(int(tf.shape(x)[0]), 9)
71 |             if tf.shape(x).shape==4:
72 |                 tf.summary.image("x", [gallery(tf.gather(x,idx).numpy())], max_outputs=1, step=self.iter)
73 | 
74 |             if tf.shape(y).shape==4:
75 |                 tf.summary.image("y", [gallery(tf.gather(y,idx).numpy())], max_outputs=1, step=self.iter)
76 |                 tf.summary.image("y_hat", [gallery(tf.gather(y_hat,idx).numpy())], max_outputs=1, step=self.iter)
77 | 
78 |     def get_batch(self, x, y, batch_size):
79 |         idx = np.random.choice(x.shape[0], batch_size, replace=False)
80 |         if isinstance(x, tf.Tensor):
81 |             x_ = x[idx,...]
82 |             y_ = y[idx,...]
83 |         elif isinstance(x, np.ndarray) or isinstance(x, h5py.Dataset):
84 |             idx = np.sort(idx)
85 |             x_ = x[idx,...]
86 |             y_ = y[idx,...]
87 | 
88 |             x_divisor = 255. if x_.dtype == np.uint8 else 1.0
89 |             y_divisor = 255. if y_.dtype == np.uint8 else 1.0
90 | 
91 |             x_ = tf.convert_to_tensor(x_/x_divisor, tf.float32)
92 |             y_ = tf.convert_to_tensor(y_/y_divisor, tf.float32)
93 |         else:
94 |             print("unknown dataset type {} {}".format(type(x), type(y)))
95 |         return x_, y_
96 | 
97 |     def save(self, name):
98 |         self.model.save(os.path.join(self.save_dir, "{}.h5".format(name)))
99 | 
100 |     def train(self, x_train, y_train, x_test, y_test, y_scale, batch_size=128, iters=10000, verbose=True):
101 |         tic = time.time()
102 |         for self.iter in range(iters):
103 |             x_input_batch, y_input_batch = self.get_batch(x_train, y_train, batch_size)
104 |             loss, y_hat = self.run_train_step(x_input_batch, y_input_batch)
105 | 
106 |             if self.iter % 10 == 0:
107 |                 self.save_train_summary(loss, x_input_batch, y_input_batch, y_hat)
108 | 
109 |             if self.iter % 100 == 0:
110 |                 x_test_batch, y_test_batch = self.get_batch(x_test, y_test, min(100, x_test.shape[0]))
111 |                 y_hat, vloss, rmse = self.evaluate(x_test_batch, y_test_batch)
112 |                 rmse *= y_scale[0,0]
113 | 
114 |                 self.save_val_summary(vloss, x_test_batch, y_test_batch, y_hat)
115 | 
116 |                 if rmse.numpy() < self.min_rmse:
117 |                     self.min_rmse = rmse.numpy()
118 |                     self.save(f"model_rmse_{self.iter}")
119 | 
120 |                 if vloss.numpy() < self.min_vloss:
121 |                     self.min_vloss = vloss.numpy()
122 |                     self.save(f"model_vloss_{self.iter}")
123 | 
124 |                 if verbose: print("[{}] \t RMSE: {:.4f} \t NLL: {:.4f} \t train_loss: {:.4f} \t t: {:.2f} sec".format(self.iter, self.min_rmse, self.min_nll, loss, time.time()-tic))
125 |                 tic = time.time()
126 | 
127 | 
128 |         return self.model, self.min_rmse, self.min_nll
--------------------------------------------------------------------------------
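The next trainer fits a heteroscedastic Gaussian by maximum likelihood with edl.losses.Gaussian_NLL. For reference, a sketch of that per-sample objective — the standard Gaussian negative log-likelihood, written out here for clarity; the repository's own implementation presumably lives in evidential_deep_learning/losses/continuous.py:

    import numpy as np

    def gaussian_nll(y, mu, sigma):
        # NLL of y under N(mu, sigma^2): 0.5*log(2*pi*sigma^2) + (y - mu)^2 / (2*sigma^2)
        # Minimizing this jointly fits the mean and lets sigma grow on noisy inputs
        # (loss attenuation), which is how the trainer captures aleatoric uncertainty.
        return 0.5 * np.log(2 * np.pi * sigma**2) + (y - mu)**2 / (2 * sigma**2)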
/neurips2020/trainers/gaussian.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | import time
4 | import datetime
5 | import os
6 | import sys
7 | import h5py
8 | from pathlib import Path
9 | 
10 | import evidential_deep_learning as edl
11 | from .util import normalize, gallery
12 | 
13 | class Gaussian:
14 |     def __init__(self, model, opts, dataset="", learning_rate=1e-3, tag=""):
15 |         self.loss_function = edl.losses.Gaussian_NLL
16 | 
17 |         self.model = model
18 | 
19 |         self.optimizer = tf.optimizers.Adam(learning_rate)
20 | 
21 |         self.min_rmse = float('inf')
22 |         self.min_nll = float('inf')
23 |         self.min_vloss = float('inf')
24 | 
25 |         trainer = self.__class__.__name__
26 |         current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
27 |         self.save_dir = os.path.join('save','{}_{}_{}_{}'.format(current_time, dataset, trainer, tag))
28 | 
Path(self.save_dir).mkdir(parents=True, exist_ok=True) 29 | 30 | train_log_dir = os.path.join('logs', '{}_{}_{}_{}_train'.format(current_time, dataset, trainer, tag)) 31 | self.train_summary_writer = tf.summary.create_file_writer(train_log_dir) 32 | val_log_dir = os.path.join('logs', '{}_{}_{}_{}_val'.format(current_time, dataset, trainer, tag)) 33 | self.val_summary_writer = tf.summary.create_file_writer(val_log_dir) 34 | 35 | @tf.function 36 | def run_train_step(self, x, y): 37 | with tf.GradientTape() as tape: 38 | outputs = self.model(x, training=True) #forward pass 39 | mu, sigma = tf.split(outputs, 2, axis=-1) 40 | loss = self.loss_function(y, mu, sigma) 41 | grads = tape.gradient(loss, self.model.variables) #compute gradient 42 | self.optimizer.apply_gradients(zip(grads, self.model.variables)) 43 | 44 | return loss, mu 45 | 46 | @tf.function 47 | def evaluate(self, x, y): 48 | outputs = self.model(x, training=True) #forward pass 49 | mu, sigma = tf.split(outputs, 2, axis=-1) 50 | rmse = edl.losses.RMSE(y, mu) 51 | nll = edl.losses.Gaussian_NLL(y, mu, sigma) 52 | loss = self.loss_function(y, mu, sigma) 53 | 54 | return mu, sigma, loss, rmse, nll 55 | 56 | def save_train_summary(self, loss, x, y, y_hat): 57 | with self.train_summary_writer.as_default(): 58 | tf.summary.scalar('mse', tf.reduce_mean(edl.losses.MSE(y, y_hat)), step=self.iter) 59 | tf.summary.scalar('loss', tf.reduce_mean(loss), step=self.iter) 60 | idx = np.random.choice(int(tf.shape(x)[0]), 9) 61 | if tf.shape(x).shape==4: 62 | tf.summary.image("x", [gallery(tf.gather(x,idx).numpy())], max_outputs=1, step=self.iter) 63 | 64 | if tf.shape(y).shape==4: 65 | tf.summary.image("y", [gallery(tf.gather(y,idx).numpy())], max_outputs=1, step=self.iter) 66 | tf.summary.image("y_hat", [gallery(tf.gather(y_hat,idx).numpy())], max_outputs=1, step=self.iter) 67 | 68 | def save_val_summary(self, loss, x, y, mu, var): 69 | with self.val_summary_writer.as_default(): 70 | tf.summary.scalar('mse', tf.reduce_mean(edl.losses.MSE(y, mu)), step=self.iter) 71 | tf.summary.scalar('loss', tf.reduce_mean(self.loss_function(y, mu, tf.sqrt(var))), step=self.iter) 72 | idx = np.random.choice(int(tf.shape(x)[0]), 9) 73 | if tf.shape(x).shape==4: 74 | tf.summary.image("x", [gallery(tf.gather(x,idx).numpy())], max_outputs=1, step=self.iter) 75 | 76 | if tf.shape(y).shape==4: 77 | tf.summary.image("y", [gallery(tf.gather(y,idx).numpy())], max_outputs=1, step=self.iter) 78 | tf.summary.image("y_hat", [gallery(tf.gather(mu,idx).numpy())], max_outputs=1, step=self.iter) 79 | tf.summary.image("y_var", [gallery(normalize(tf.gather(var,idx)).numpy())], max_outputs=1, step=self.iter) 80 | 81 | def get_batch(self, x, y, batch_size): 82 | idx = np.random.choice(x.shape[0], batch_size, replace=False) 83 | if isinstance(x, tf.Tensor): 84 | x_ = x[idx,...] 85 | y_ = y[idx,...] 86 | elif isinstance(x, np.ndarray) or isinstance(x, h5py.Dataset): 87 | idx = np.sort(idx) 88 | x_ = x[idx,...] 89 | y_ = y[idx,...] 90 | 91 | x_divisor = 255. if x_.dtype == np.uint8 else 1.0 92 | y_divisor = 255. 
if y_.dtype == np.uint8 else 1.0 93 | 94 | x_ = tf.convert_to_tensor(x_/x_divisor, tf.float32) 95 | y_ = tf.convert_to_tensor(y_/y_divisor, tf.float32) 96 | else: 97 | print("unknown dataset type {} {}".format(type(x), type(y))) 98 | return x_, y_ 99 | 100 | def save(self, name): 101 | self.model.save(os.path.join(self.save_dir, "{}.h5".format(name))) 102 | 103 | def train(self, x_train, y_train, x_test, y_test, y_scale, batch_size=128, iters=10000, verbose=True): 104 | tic = time.time() 105 | for self.iter in range(iters): 106 | x_input_batch, y_input_batch = self.get_batch(x_train, y_train, batch_size) 107 | loss, y_hat = self.run_train_step(x_input_batch, y_input_batch) 108 | 109 | if self.iter % 10 == 0: 110 | self.save_train_summary(loss, x_input_batch, y_input_batch, y_hat) 111 | 112 | if self.iter % 10 == 0: 113 | x_test_batch, y_test_batch = self.get_batch(x_test, y_test, min(100, x_test.shape[0])) 114 | mu, var, vloss, rmse, nll = self.evaluate(x_test_batch, y_test_batch) 115 | nll += np.log(y_scale[0,0]) 116 | rmse *= y_scale[0,0] 117 | 118 | self.save_val_summary(vloss, x_test_batch, y_test_batch, mu, var) 119 | 120 | if rmse.numpy() < self.min_rmse: 121 | self.min_rmse = rmse.numpy() 122 | self.save(f"model_rmse_{self.iter}") 123 | 124 | if nll.numpy() < self.min_nll: 125 | self.min_nll = nll.numpy() 126 | self.save(f"model_nll_{self.iter}") 127 | 128 | if vloss.numpy() < self.min_vloss: 129 | self.min_vloss = vloss.numpy() 130 | self.save(f"model_vloss_{self.iter}") 131 | 132 | if verbose: print("[{}] \t RMSE: {:.4f} \t NLL: {:.4f} \t train_loss: {:.4f} \t t: {:.2f} sec".format(self.iter, self.min_rmse, self.min_nll, loss, time.time()-tic)) 133 | tic = time.time() 134 | 135 | 136 | return self.model, self.min_rmse, self.min_nll 137 | -------------------------------------------------------------------------------- /neurips2020/trainers/bbbp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import tensorflow_probability as tfp 4 | import time 5 | import datetime 6 | import os 7 | import sys 8 | import h5py 9 | from pathlib import Path 10 | 11 | import evidential_deep_learning as edl 12 | from .util import normalize, gallery 13 | 14 | class BBBP: 15 | def __init__(self, model, opts, dataset="", learning_rate=1e-3, tag=""): 16 | self.loss_function = edl.losses.MSE 17 | 18 | self.model = model 19 | 20 | self.optimizer = tf.optimizers.Adam(learning_rate) 21 | 22 | self.min_rmse = float('inf') 23 | self.min_nll = float('inf') 24 | self.min_vloss = float('inf') 25 | 26 | trainer = self.__class__.__name__ 27 | current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 28 | self.save_dir = os.path.join('save','{}_{}_{}_{}'.format(current_time, dataset, trainer, tag)) 29 | Path(self.save_dir).mkdir(parents=True, exist_ok=True) 30 | 31 | 32 | train_log_dir = os.path.join('logs', '{}_{}_{}_{}_train'.format(current_time, dataset, trainer, tag)) 33 | self.train_summary_writer = tf.summary.create_file_writer(train_log_dir) 34 | val_log_dir = os.path.join('logs', '{}_{}_{}_{}_val'.format(current_time, dataset, trainer, tag)) 35 | self.val_summary_writer = tf.summary.create_file_writer(val_log_dir) 36 | 37 | @tf.function 38 | def run_train_step(self, x, y): 39 | with tf.GradientTape() as tape: 40 | y_hat = self.model(x, training=True) #forward pass 41 | loss = self.loss_function(y, y_hat) 42 | loss += tf.reduce_mean(self.model.losses) 43 | 44 | grads = tape.gradient(loss, self.model.variables) 
#compute gradient 45 | self.optimizer.apply_gradients(zip(grads, self.model.variables)) 46 | return loss, y_hat 47 | 48 | @tf.function 49 | def evaluate(self, x, y): 50 | preds = tf.stack([self.model(x, training=True) for _ in range(5)], axis=0) #forward pass 51 | mu, var = tf.nn.moments(preds, axes=0) 52 | rmse = edl.losses.RMSE(y, mu) 53 | nll = edl.losses.Gaussian_NLL(y, mu, tf.sqrt(var)) 54 | loss = self.loss_function(y, mu) 55 | 56 | return mu, var, loss, rmse, nll 57 | 58 | def save_train_summary(self, loss, x, y, y_hat): 59 | with self.train_summary_writer.as_default(): 60 | # tf.summary.scalar('loss', tf.reduce_mean(loss), step=self.iter) 61 | tf.summary.scalar('mse', tf.reduce_mean(edl.losses.MSE(y, y_hat)), step=self.iter) 62 | idx = np.random.choice(int(tf.shape(x)[0]), 9) 63 | if tf.shape(x).shape==4: 64 | tf.summary.image("x", [gallery(tf.gather(x,idx).numpy())], max_outputs=1, step=self.iter) 65 | 66 | if tf.shape(y).shape==4: 67 | tf.summary.image("y", [gallery(tf.gather(y,idx).numpy())], max_outputs=1, step=self.iter) 68 | tf.summary.image("y_hat", [gallery(tf.gather(y_hat,idx).numpy())], max_outputs=1, step=self.iter) 69 | 70 | def save_val_summary(self, loss, x, y, mu, var): 71 | with self.val_summary_writer.as_default(): 72 | tf.summary.scalar('loss', tf.reduce_mean(self.loss_function(y, mu)), step=self.iter) 73 | tf.summary.scalar('mse', tf.reduce_mean(edl.losses.MSE(y, mu)), step=self.iter) 74 | idx = np.random.choice(int(tf.shape(x)[0]), 9) 75 | if tf.shape(x).shape==4: 76 | tf.summary.image("x", [gallery(tf.gather(x,idx).numpy())], max_outputs=1, step=self.iter) 77 | 78 | if tf.shape(y).shape==4: 79 | tf.summary.image("y", [gallery(tf.gather(y,idx).numpy())], max_outputs=1, step=self.iter) 80 | tf.summary.image("y_hat", [gallery(tf.gather(mu,idx).numpy())], max_outputs=1, step=self.iter) 81 | tf.summary.image("y_var", [gallery(normalize(tf.gather(var,idx)).numpy())], max_outputs=1, step=self.iter) 82 | 83 | def get_batch(self, x, y, batch_size): 84 | idx = np.random.choice(x.shape[0], batch_size, replace=False) 85 | if isinstance(x, tf.Tensor): 86 | x_ = x[idx,...] 87 | y_ = y[idx,...] 88 | elif isinstance(x, np.ndarray) or isinstance(x, h5py.Dataset): 89 | idx = np.sort(idx) 90 | x_ = x[idx,...] 91 | y_ = y[idx,...] 92 | 93 | x_divisor = 255. if x_.dtype == np.uint8 else 1.0 94 | y_divisor = 255. 
if y_.dtype == np.uint8 else 1.0 95 | 96 | x_ = tf.convert_to_tensor(x_/x_divisor, tf.float32) 97 | y_ = tf.convert_to_tensor(y_/y_divisor, tf.float32) 98 | else: 99 | print("unknown dataset type {} {}".format(type(x), type(y))) 100 | return x_, y_ 101 | 102 | def save(self, name): 103 | self.model.save(os.path.join(self.save_dir, "{}.h5".format(name))) 104 | # pass 105 | 106 | def train(self, x_train, y_train, x_test, y_test, y_scale, batch_size=128, iters=10000, verbose=True): 107 | tic = time.time() 108 | for self.iter in range(iters): 109 | x_input_batch, y_input_batch = self.get_batch(x_train, y_train, batch_size) 110 | loss, y_hat = self.run_train_step(x_input_batch, y_input_batch) 111 | 112 | if self.iter % 10 == 0: 113 | self.save_train_summary(loss, x_input_batch, y_input_batch, y_hat) 114 | 115 | if self.iter % 100 == 0: 116 | x_test_batch, y_test_batch = self.get_batch(x_test, y_test, min(100, x_test.shape[0])) 117 | mu, var, vloss, rmse, nll = self.evaluate(x_test_batch, y_test_batch) 118 | nll += np.log(y_scale[0,0]) 119 | rmse *= y_scale[0,0] 120 | 121 | self.save_val_summary(vloss, x_test_batch, y_test_batch, mu, var) 122 | 123 | if rmse.numpy() < self.min_rmse: 124 | self.min_rmse = rmse.numpy() 125 | print("SAVING") 126 | self.save("model_rmse") 127 | 128 | if nll.numpy() < self.min_nll: 129 | self.min_nll = nll.numpy() 130 | self.save("model_nll") 131 | 132 | if vloss.numpy() < self.min_vloss: 133 | self.min_vloss = vloss.numpy() 134 | self.save("model_vloss") 135 | 136 | if verbose: print("[{}] \t RMSE: {:.4f} \t NLL: {:.4f} \t train_loss: {:.4f} \t t: {:.2f} sec".format(self.iter, self.min_rmse, self.min_nll, vloss, time.time()-tic)) 137 | tic = time.time() 138 | 139 | return self.model, self.min_rmse, self.min_nll 140 | -------------------------------------------------------------------------------- /neurips2020/trainers/dropout.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import time 4 | import datetime 5 | import os 6 | import sys 7 | import h5py 8 | from pathlib import Path 9 | 10 | import evidential_deep_learning as edl 11 | from .util import normalize, gallery 12 | 13 | class Dropout: 14 | def __init__(self, model, opts, dataset="", learning_rate=1e-3, tag=""): 15 | 16 | self.model = model 17 | 18 | self.l = opts['l'] 19 | self.drop_prob = opts['drop_prob'] 20 | self.mse = not opts['sigma'] 21 | self.lam = opts['lam'] 22 | 23 | self.loss_function = edl.losses.MSE if self.mse else edl.losses.Gaussian_NLL 24 | self.optimizer = tf.optimizers.Adam(learning_rate) 25 | 26 | self.min_rmse = float('inf') 27 | self.min_nll = float('inf') 28 | self.min_vloss = float('inf') 29 | 30 | trainer = self.__class__.__name__ 31 | current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 32 | self.save_dir = os.path.join('save','{}_{}_{}_{}'.format(current_time, dataset, trainer, tag)) 33 | Path(self.save_dir).mkdir(parents=True, exist_ok=True) 34 | 35 | train_log_dir = os.path.join('logs', '{}_{}_{}_{}_train'.format(current_time, dataset, trainer, tag)) 36 | self.train_summary_writer = tf.summary.create_file_writer(train_log_dir) 37 | val_log_dir = os.path.join('logs', '{}_{}_{}_{}_val'.format(current_time, dataset, trainer, tag)) 38 | self.val_summary_writer = tf.summary.create_file_writer(val_log_dir) 39 | 40 | 41 | @tf.function 42 | def run_train_step(self, x, y): 43 | with tf.GradientTape() as tape: 44 | y_hat = self.model(x, training=True) #forward pass 45 | if self.mse: 46 | 
loss = self.loss_function(y, y_hat) 47 | else: 48 | mu, logsigma = tf.split(y_hat, 2, axis=-1) 49 | loss = self.loss_function(y, mu, tf.exp(logsigma)) 50 | loss += tf.reduce_sum(self.model.losses) 51 | 52 | grads = tape.gradient(loss, self.model.trainable_variables) #compute gradient 53 | self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables)) 54 | 55 | return loss, y_hat 56 | 57 | @tf.function 58 | def evaluate(self, x, y): 59 | preds = tf.stack([self.model(x, training=True) for _ in range(5)], axis=0) #forward pass 60 | mu, var = tf.nn.moments(preds, axes=0) 61 | if self.mse: 62 | mean_mu = mu 63 | loss = self.loss_function(y, mean_mu) 64 | else: 65 | mean_mu, mean_sigma = tf.split(mu, 2, axis=1) 66 | loss = self.loss_function(y, mean_mu, mean_sigma) 67 | 68 | rmse = edl.losses.RMSE(y, mean_mu) 69 | 70 | tau = self.l**2 * (1-self.drop_prob) / (2. * self.lam) # https://www.cs.ox.ac.uk/people/yarin.gal/website/blog_3d801aa532c1ce.html 71 | var += tau**-1 72 | nll = edl.losses.Gaussian_NLL(y, mean_mu, tf.sqrt(var)) 73 | 74 | return mu, var, loss, rmse, nll 75 | 76 | def save_train_summary(self, loss, x, y, y_hat): 77 | with self.train_summary_writer.as_default(): 78 | tf.summary.scalar('loss', tf.reduce_mean(loss), step=self.iter) 79 | tf.summary.scalar('mse', tf.reduce_mean(edl.losses.MSE(y, y_hat)), step=self.iter) 80 | idx = np.random.choice(int(tf.shape(x)[0]), 9) 81 | if tf.shape(x).shape==4: 82 | tf.summary.image("x", [gallery(tf.gather(x,idx).numpy())], max_outputs=1, step=self.iter) 83 | 84 | if tf.shape(y).shape==4: 85 | tf.summary.image("y", [gallery(tf.gather(y,idx).numpy())], max_outputs=1, step=self.iter) 86 | tf.summary.image("y_hat", [gallery(tf.gather(y_hat,idx).numpy())], max_outputs=1, step=self.iter) 87 | 88 | def save_val_summary(self, loss, x, y, mu, var): 89 | with self.val_summary_writer.as_default(): 90 | tf.summary.scalar('loss', loss, step=self.iter) 91 | tf.summary.scalar('mse', tf.reduce_mean(edl.losses.MSE(y, mu)), step=self.iter) 92 | idx = np.random.choice(int(tf.shape(x)[0]), 9) 93 | if tf.shape(x).shape==4: 94 | tf.summary.image("x", [gallery(tf.gather(x,idx).numpy())], max_outputs=1, step=self.iter) 95 | 96 | if tf.shape(y).shape==4: 97 | tf.summary.image("y", [gallery(tf.gather(y,idx).numpy())], max_outputs=1, step=self.iter) 98 | tf.summary.image("y_hat", [gallery(tf.gather(mu,idx).numpy())], max_outputs=1, step=self.iter) 99 | tf.summary.image("y_var", [gallery(normalize(tf.gather(var,idx)).numpy())], max_outputs=1, step=self.iter) 100 | 101 | def get_batch(self, x, y, batch_size): 102 | idx = np.random.choice(x.shape[0], batch_size, replace=False) 103 | if isinstance(x, tf.Tensor): 104 | x_ = x[idx,...] 105 | y_ = y[idx,...] 106 | elif isinstance(x, np.ndarray) or isinstance(x, h5py.Dataset): 107 | idx = np.sort(idx) 108 | x_ = x[idx,...] 109 | y_ = y[idx,...] 110 | 111 | x_divisor = 255. if x_.dtype == np.uint8 else 1.0 112 | y_divisor = 255. 
if y_.dtype == np.uint8 else 1.0 113 | 114 | x_ = tf.convert_to_tensor(x_/x_divisor, tf.float32) 115 | y_ = tf.convert_to_tensor(y_/y_divisor, tf.float32) 116 | else: 117 | print("unknown dataset type {} {}".format(type(x), type(y))) 118 | return x_, y_ 119 | 120 | def save(self, name): 121 | self.model.save(os.path.join(self.save_dir, "{}.h5".format(name))) 122 | 123 | def train(self, x_train, y_train, x_test, y_test, y_scale, batch_size=128, iters=10000, verbose=True): 124 | tic = time.time() 125 | for self.iter in range(iters): 126 | x_input_batch, y_input_batch = self.get_batch(x_train, y_train, batch_size) 127 | loss, y_hat = self.run_train_step(x_input_batch, y_input_batch) 128 | 129 | if self.iter % 10 == 0: 130 | self.save_train_summary(loss, x_input_batch, y_input_batch, y_hat) 131 | 132 | if self.iter % 100 == 0: 133 | x_test_batch, y_test_batch = self.get_batch(x_test, y_test, min(100, x_test.shape[0])) 134 | mu, var, vloss, rmse, nll = self.evaluate(x_test_batch, y_test_batch) 135 | nll += np.log(y_scale[0,0]) 136 | rmse *= y_scale[0,0] 137 | 138 | self.save_val_summary(vloss, x_test_batch, y_test_batch, mu, var) 139 | 140 | if rmse.numpy() < self.min_rmse: 141 | self.min_rmse = rmse.numpy() 142 | self.save(f"model_rmse_{self.iter}") 143 | 144 | if nll.numpy() < self.min_nll: 145 | self.min_nll = nll.numpy() 146 | self.save(f"model_nll_{self.iter}") 147 | 148 | if vloss.numpy() < self.min_vloss: 149 | self.min_vloss = vloss.numpy() 150 | self.save(f"model_vloss_{self.iter}") 151 | 152 | if verbose: print("[{}] \t RMSE: {:.4f} \t NLL: {:.4f} \t train_loss: {:.4f} \t t: {:.2f} sec".format(self.iter, self.min_rmse, self.min_nll, loss, time.time()-tic)) 153 | tic = time.time() 154 | 155 | 156 | return self.model, self.min_rmse, self.min_nll 157 | -------------------------------------------------------------------------------- /neurips2020/trainers/ensemble.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import time 4 | import datetime 5 | import os 6 | import sys 7 | import h5py 8 | from pathlib import Path 9 | 10 | import evidential_deep_learning as edl 11 | from .util import normalize, gallery 12 | 13 | class Ensemble: 14 | def __init__(self, models, opts, dataset="", learning_rate=1e-3, tag=""): 15 | self.mse = not opts['sigma'] 16 | self.loss_function = edl.losses.MSE if self.mse else edl.losses.Gaussian_NLL 17 | 18 | self.models = models 19 | 20 | self.num_ensembles = opts['num_ensembles'] 21 | 22 | self.optimizers = [tf.optimizers.Adam(learning_rate) for _ in range(self.num_ensembles)] 23 | 24 | self.min_rmse = float('inf') 25 | self.min_nll = float('inf') 26 | self.min_vloss = float('inf') 27 | 28 | trainer = self.__class__.__name__ 29 | current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 30 | self.save_dir = os.path.join('save','{}_{}_{}_{}'.format(current_time, dataset, trainer, tag)) 31 | Path(self.save_dir).mkdir(parents=True, exist_ok=True) 32 | 33 | train_log_dir = os.path.join('logs', '{}_{}_{}_{}_train'.format(current_time, dataset, trainer, tag)) 34 | self.train_summary_writer = tf.summary.create_file_writer(train_log_dir) 35 | val_log_dir = os.path.join('logs', '{}_{}_{}_{}_val'.format(current_time, dataset, trainer, tag)) 36 | self.val_summary_writer = tf.summary.create_file_writer(val_log_dir) 37 | 38 | 39 | @tf.function 40 | def run_train_step(self, x, y): 41 | losses = [] 42 | y_hats = [] 43 | for (model, optimizer) in zip(self.models, self.optimizers): 
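            # Note: each ensemble member has its own Adam optimizer and takes an
            # independent gradient step on the same batch (the standard deep-ensembles
            # recipe); evaluate() below then moment-matches the per-member Gaussians
            # into a single predictive mean and variance.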
#Autograph unrolls this so make sure ensemble size is not too large 44 | 45 | with tf.GradientTape() as tape: 46 | outputs = model(x, training=True) #forward pass 47 | if self.mse: 48 | mu = outputs 49 | loss = self.loss_function(y, mu) 50 | else: 51 | mu, sigma = tf.split(outputs, 2, axis=-1) 52 | loss = self.loss_function(y, mu, sigma) 53 | y_hats.append(mu) 54 | losses.append(loss) 55 | 56 | grads = tape.gradient(loss, model.variables) #compute gradient 57 | optimizer.apply_gradients(zip(grads, model.variables)) 58 | 59 | return tf.reduce_mean(losses), tf.reduce_mean(y_hats, 0) 60 | 61 | @tf.function 62 | def evaluate(self, x, y): 63 | preds = tf.stack([model(x, training=False) for model in self.models], axis=0) #forward pass 64 | if self.mse: 65 | mean_mu, var = tf.nn.moments(preds, 0) 66 | loss = self.loss_function(y, mean_mu) 67 | 68 | else: 69 | mus, sigmas = tf.split(preds, 2, axis=-1) 70 | mean_mu = tf.reduce_mean(mus, axis=0) 71 | mean_sigma = tf.reduce_mean(sigmas, axis=0) 72 | var = tf.reduce_mean(sigmas**2 + tf.square(mus), axis=0) - tf.square(mean_mu) 73 | loss = self.loss_function(y, mean_mu, tf.sqrt(var)) 74 | 75 | rmse = edl.losses.RMSE(y, mean_mu) 76 | nll = edl.losses.Gaussian_NLL(y, mean_mu, tf.sqrt(var)) 77 | 78 | return mean_mu, var, loss, rmse, nll 79 | 80 | def save_train_summary(self, loss, x, y, y_hat): 81 | with self.train_summary_writer.as_default(): 82 | tf.summary.scalar('mse', tf.reduce_mean(edl.losses.MSE(y, y_hat)), step=self.iter) 83 | tf.summary.scalar('loss', tf.reduce_mean(loss), step=self.iter) 84 | idx = np.random.choice(int(tf.shape(x)[0]), 9) 85 | if tf.shape(x).shape==4: 86 | tf.summary.image("x", [gallery(tf.gather(x,idx).numpy())], max_outputs=1, step=self.iter) 87 | 88 | if tf.shape(y).shape==4: 89 | tf.summary.image("y", [gallery(tf.gather(y,idx).numpy())], max_outputs=1, step=self.iter) 90 | tf.summary.image("y_hat", [gallery(tf.gather(y_hat,idx).numpy())], max_outputs=1, step=self.iter) 91 | 92 | def save_val_summary(self, loss, x, y, mu, var): 93 | with self.val_summary_writer.as_default(): 94 | tf.summary.scalar('mse', tf.reduce_mean(edl.losses.MSE(y, mu)), step=self.iter) 95 | tf.summary.scalar('loss', tf.reduce_mean(loss), step=self.iter) 96 | idx = np.random.choice(int(tf.shape(x)[0]), 9) 97 | if tf.shape(x).shape==4: 98 | tf.summary.image("x", [gallery(tf.gather(x,idx).numpy())], max_outputs=1, step=self.iter) 99 | 100 | if tf.shape(y).shape==4: 101 | tf.summary.image("y", [gallery(tf.gather(y,idx).numpy())], max_outputs=1, step=self.iter) 102 | tf.summary.image("y_hat", [gallery(tf.gather(mu,idx).numpy())], max_outputs=1, step=self.iter) 103 | tf.summary.image("y_var", [gallery(normalize(tf.gather(var,idx)).numpy())], max_outputs=1, step=self.iter) 104 | 105 | def get_batch(self, x, y, batch_size): 106 | idx = np.random.choice(x.shape[0], batch_size, replace=False) 107 | if isinstance(x, tf.Tensor): 108 | x_ = x[idx,...] 109 | y_ = y[idx,...] 110 | elif isinstance(x, np.ndarray) or isinstance(x, h5py.Dataset): 111 | idx = np.sort(idx) 112 | x_ = x[idx,...] 113 | y_ = y[idx,...] 114 | 115 | x_divisor = 255. if x_.dtype == np.uint8 else 1.0 116 | y_divisor = 255. 
if y_.dtype == np.uint8 else 1.0 117 | 118 | x_ = tf.convert_to_tensor(x_/x_divisor, tf.float32) 119 | y_ = tf.convert_to_tensor(y_/y_divisor, tf.float32) 120 | else: 121 | print("unknown dataset type {} {}".format(type(x), type(y))) 122 | return x_, y_ 123 | 124 | def save(self, name): 125 | for i, model in enumerate(self.models): 126 | model.save(os.path.join(self.save_dir, "{}_{}.h5".format(name, i))) 127 | 128 | def train(self, x_train, y_train, x_test, y_test, y_scale, batch_size=128, iters=10000, verbose=True): 129 | tic = time.time() 130 | for self.iter in range(iters): 131 | x_input_batch, y_input_batch = self.get_batch(x_train, y_train, batch_size) 132 | loss, y_hat = self.run_train_step(x_input_batch, y_input_batch) 133 | 134 | if self.iter % 10 == 0: 135 | self.save_train_summary(loss, x_input_batch, y_input_batch, y_hat) 136 | 137 | if self.iter % 100 == 0: 138 | x_test_batch, y_test_batch = self.get_batch(x_test, y_test, min(100, x_test.shape[0])) 139 | mu, var, vloss, rmse, nll = self.evaluate(x_test_batch, y_test_batch) 140 | nll += np.log(y_scale[0,0]) 141 | rmse *= y_scale[0,0] 142 | 143 | self.save_val_summary(vloss, x_test_batch, y_test_batch, mu, var) 144 | 145 | if rmse.numpy() < self.min_rmse: 146 | self.min_rmse = rmse.numpy() 147 | self.save(f"model_rmse_{self.iter}") 148 | 149 | if nll.numpy() < self.min_nll: 150 | self.min_nll = nll.numpy() 151 | self.save(f"model_nll_{self.iter}") 152 | 153 | if vloss.numpy() < self.min_vloss: 154 | self.min_vloss = vloss.numpy() 155 | self.save(f"model_vloss_{self.iter}") 156 | 157 | if verbose: print("[{}] \t RMSE: {:.4f} \t NLL: {:.4f} \t train_loss: {:.4f} \t t: {:.2f} sec".format(self.iter, self.min_rmse, self.min_nll, loss, time.time()-tic)) 158 | tic = time.time() 159 | 160 | return self.models, self.min_rmse, self.min_nll 161 | -------------------------------------------------------------------------------- /neurips2020/trainers/evidential.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import time 4 | import datetime 5 | import os 6 | import sys 7 | import h5py 8 | from pathlib import Path 9 | 10 | import evidential_deep_learning as edl 11 | from .util import normalize, gallery 12 | 13 | class Evidential: 14 | def __init__(self, model, opts, dataset="", learning_rate=1e-3, lam=0.0, epsilon=1e-2, maxi_rate=1e-4, tag=""): 15 | self.nll_loss_function = edl.losses.NIG_NLL 16 | self.reg_loss_function = edl.losses.NIG_Reg 17 | 18 | self.model = model 19 | self.learning_rate = learning_rate 20 | self.maxi_rate = maxi_rate 21 | 22 | self.optimizer = tf.optimizers.Adam(self.learning_rate) 23 | self.lam = tf.Variable(lam) 24 | 25 | self.epsilon = epsilon 26 | 27 | self.min_rmse = self.running_rmse = float('inf') 28 | self.min_nll = self.running_nll = float('inf') 29 | self.min_vloss = self.running_vloss = float('inf') 30 | 31 | trainer = self.__class__.__name__ 32 | current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 33 | self.save_dir = os.path.join('save','{}_{}_{}_{}'.format(current_time, dataset, trainer, tag)) 34 | Path(self.save_dir).mkdir(parents=True, exist_ok=True) 35 | 36 | train_log_dir = os.path.join('logs', '{}_{}_{}_{}_train'.format(current_time, dataset, trainer, tag)) 37 | self.train_summary_writer = tf.summary.create_file_writer(train_log_dir) 38 | val_log_dir = os.path.join('logs', '{}_{}_{}_{}_val'.format(current_time, dataset, trainer, tag)) 39 | self.val_summary_writer = 
40 | 41 | def loss_function(self, y, mu, v, alpha, beta, reduce=True, return_comps=False): 42 | nll_loss = self.nll_loss_function(y, mu, v, alpha, beta, reduce=reduce) 43 | reg_loss = self.reg_loss_function(y, mu, v, alpha, beta, reduce=reduce) 44 | loss = nll_loss + self.lam * (reg_loss - self.epsilon) 45 | # loss = nll_loss 46 | 47 | return (loss, (nll_loss, reg_loss)) if return_comps else loss 48 | 49 | @tf.function 50 | def run_train_step(self, x, y): 51 | with tf.GradientTape() as tape: 52 | outputs = self.model(x, training=True) 53 | mu, v, alpha, beta = tf.split(outputs, 4, axis=-1) 54 | loss, (nll_loss, reg_loss) = self.loss_function(y, mu, v, alpha, beta, return_comps=True) 55 | 56 | grads = tape.gradient(loss, self.model.trainable_variables) # compute gradients 57 | self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables)) 58 | self.lam.assign_add(self.maxi_rate * (reg_loss - self.epsilon)) # update lambda in place (assign_add already mutates the tf.Variable) 59 | 60 | return loss, nll_loss, reg_loss, mu, v, alpha, beta 61 | 62 | @tf.function 63 | def evaluate(self, x, y): 64 | outputs = self.model(x, training=False) 65 | mu, v, alpha, beta = tf.split(outputs, 4, axis=-1) 66 | 67 | rmse = edl.losses.RMSE(y, mu) 68 | loss, (nll, reg_loss) = self.loss_function(y, mu, v, alpha, beta, return_comps=True) 69 | 70 | return mu, v, alpha, beta, loss, rmse, nll, reg_loss 71 | 72 | def normalize(self, x): 73 | return tf.divide(tf.subtract(x, tf.reduce_min(x)), 74 | tf.subtract(tf.reduce_max(x), tf.reduce_min(x))) 75 | 76 | 77 | def save_train_summary(self, loss, x, y, y_hat, v, alpha, beta): 78 | with self.train_summary_writer.as_default(): 79 | tf.summary.scalar('mse', tf.reduce_mean(edl.losses.MSE(y, y_hat)), step=self.iter) 80 | tf.summary.scalar('loss', tf.reduce_mean(self.loss_function(y, y_hat, v, alpha, beta)), step=self.iter) 81 | idx = np.random.choice(int(tf.shape(x)[0]), 9) 82 | if len(x.shape) == 4: # only log image summaries for rank-4 (image) batches 83 | tf.summary.image("x", [gallery(tf.gather(x,idx).numpy())], max_outputs=1, step=self.iter) 84 | 85 | if len(y.shape) == 4: 86 | tf.summary.image("y", [gallery(tf.gather(y,idx).numpy())], max_outputs=1, step=self.iter) 87 | tf.summary.image("y_hat", [gallery(tf.gather(y_hat,idx).numpy())], max_outputs=1, step=self.iter) 88 | 89 | def save_val_summary(self, loss, x, y, mu, v, alpha, beta): 90 | with self.val_summary_writer.as_default(): 91 | tf.summary.scalar('mse', tf.reduce_mean(edl.losses.MSE(y, mu)), step=self.iter) 92 | tf.summary.scalar('loss', tf.reduce_mean(self.loss_function(y, mu, v, alpha, beta)), step=self.iter) 93 | idx = np.random.choice(int(tf.shape(x)[0]), 9) 94 | if len(x.shape) == 4: 95 | tf.summary.image("x", [gallery(tf.gather(x,idx).numpy())], max_outputs=1, step=self.iter) 96 | 97 | if len(y.shape) == 4: 98 | tf.summary.image("y", [gallery(tf.gather(y,idx).numpy())], max_outputs=1, step=self.iter) 99 | tf.summary.image("y_hat", [gallery(tf.gather(mu,idx).numpy())], max_outputs=1, step=self.iter) 100 | var = beta/(v*(alpha-1)) 101 | tf.summary.image("y_var", [gallery(normalize(tf.gather(var,idx)).numpy())], max_outputs=1, step=self.iter) 102 | 103 | def get_batch(self, x, y, batch_size): 104 | idx = np.random.choice(x.shape[0], batch_size, replace=False) 105 | if isinstance(x, tf.Tensor): 106 | x_ = x[idx,...] 107 | y_ = y[idx,...] 108 | elif isinstance(x, np.ndarray) or isinstance(x, h5py.Dataset): 109 | idx = np.sort(idx) 110 | x_ = x[idx,...] 111 | y_ = y[idx,...] 112 | 113 | x_divisor = 255. if x_.dtype == np.uint8 else 1.0 114 | y_divisor = 255. if y_.dtype == np.uint8 else 1.0 115 | 116 | x_ = tf.convert_to_tensor(x_/x_divisor, tf.float32) 117 | y_ = tf.convert_to_tensor(y_/y_divisor, tf.float32) 118 | else: 119 | raise TypeError("unknown dataset type {} {}".format(type(x), type(y))) 120 | return x_, y_ 121 | 122 | def save(self, name): 123 | self.model.save(os.path.join(self.save_dir, "{}.h5".format(name))) 124 | 125 | def update_running(self, previous, current, alpha=0.0): # alpha=0.0 keeps no history: the running value is just the latest sample 126 | if previous == float('inf'): 127 | new = current 128 | else: 129 | new = alpha*previous + (1-alpha)*current 130 | return new 131 | 132 | def train(self, x_train, y_train, x_test, y_test, y_scale, batch_size=128, iters=10000, verbose=True): 133 | tic = time.time() 134 | for self.iter in range(iters): 135 | x_input_batch, y_input_batch = self.get_batch(x_train, y_train, batch_size) 136 | loss, nll_loss, reg_loss, y_hat, v, alpha, beta = self.run_train_step(x_input_batch, y_input_batch) 137 | 138 | if self.iter % 10 == 0: 139 | self.save_train_summary(loss, x_input_batch, y_input_batch, y_hat, v, alpha, beta) 140 | 141 | if self.iter % 100 == 0: 142 | x_test_batch, y_test_batch = self.get_batch(x_test, y_test, min(100, x_test.shape[0])) 143 | mu, v, alpha, beta, vloss, rmse, nll, reg_loss = self.evaluate(x_test_batch, y_test_batch) 144 | 145 | nll += np.log(y_scale[0,0]) 146 | rmse *= y_scale[0,0] 147 | 148 | self.save_val_summary(vloss, x_test_batch, y_test_batch, mu, v, alpha, beta) 149 | 150 | self.running_rmse = self.update_running(self.running_rmse, rmse.numpy()) 151 | if self.running_rmse < self.min_rmse: 152 | self.min_rmse = self.running_rmse 153 | self.save(f"model_rmse_{self.iter}") 154 | 155 | self.running_nll = self.update_running(self.running_nll, nll.numpy()) 156 | if self.running_nll < self.min_nll: 157 | self.min_nll = self.running_nll 158 | self.save(f"model_nll_{self.iter}") 159 | 160 | self.running_vloss = self.update_running(self.running_vloss, vloss.numpy()) 161 | if self.running_vloss < self.min_vloss: 162 | self.min_vloss = self.running_vloss 163 | self.save(f"model_vloss_{self.iter}") 164 | 165 | if verbose: print("[{}] RMSE: {:.4f} \t NLL: {:.4f} \t loss: {:.4f} \t reg_loss: {:.4f} \t lambda: {:.2f} \t t: {:.2f} sec".format(self.iter, self.min_rmse, self.min_nll, vloss, reg_loss.numpy().mean(), self.lam.numpy(), time.time()-tic)) 166 | tic = time.time() 167 | 168 | return self.model, self.min_rmse, self.min_nll 169 | -------------------------------------------------------------------------------- /neurips2020/run_cubic_tests.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import random 4 | import tensorflow as tf 5 | import tensorflow_probability as tfp 6 | from pathlib import Path 7 | import os 8 | 9 | import evidential_deep_learning as edl 10 | import data_loader 11 | import trainers 12 | import models 13 | 14 | seed = 1234 15 | random.seed(seed) 16 | np.random.seed(seed) 17 | tf.random.set_seed(seed) 18 | 19 | save_fig_dir = "./figs/toy" 20 | batch_size = 128 21 | iterations = 5000 22 | show = True 23 | 24 | noise_changing = False 25 | train_bounds = [[-4, 4]] 26 | x_train = np.concatenate([np.linspace(xmin, xmax, 1000) for (xmin, xmax) in train_bounds]).reshape(-1,1) 27 | y_train, sigma_train = data_loader.generate_cubic(x_train, noise=True) 28 | 29 | test_bounds = [[-7,+7]] 30 | x_test = np.concatenate([np.linspace(xmin, xmax, 1000) for (xmin, xmax) in test_bounds]).reshape(-1,1) 31 |
y_test, sigma_test = data_loader.generate_cubic(x_test, noise=False) 32 | 33 | ### Plotting helper functions ### 34 | def plot_scatter_with_var(mu, std, path, n_stds=3): # "std" is a standard deviation; bands are mu +/- k*std 35 | plt.scatter(x_train, y_train, s=1., c='#463c3c', zorder=0) 36 | for k in np.linspace(0, n_stds, 4): 37 | plt.fill_between(x_test[:,0], (mu-k*std)[:,0], (mu+k*std)[:,0], alpha=0.3, edgecolor=None, facecolor='#00aeef', linewidth=0, antialiased=True, zorder=1) 38 | 39 | plt.plot(x_test, y_test, 'r--', zorder=2) 40 | plt.plot(x_test, mu, color='#007cab', zorder=3) 41 | plt.gca().set_xlim(*test_bounds) 42 | plt.gca().set_ylim(-150,150) 43 | plt.title(path) 44 | plt.savefig(path, transparent=True) 45 | if show: 46 | plt.show() 47 | plt.clf() 48 | 49 | def plot_ng(model, save="ng", ext=".pdf"): 50 | x_test_input = tf.convert_to_tensor(x_test, tf.float32) 51 | outputs = model(x_test_input) 52 | mu, v, alpha, beta = tf.split(outputs, 4, axis=1) 53 | 54 | epistemic = np.sqrt(beta/(v*(alpha-1))) 55 | epistemic = np.minimum(epistemic, 1e3) # clip the uncertainty for visualization 56 | plot_scatter_with_var(mu, epistemic, path=save+ext, n_stds=3) 57 | 58 | def plot_ensemble(models, save="ensemble", ext=".pdf"): 59 | x_test_input = tf.convert_to_tensor(x_test, tf.float32) 60 | preds = tf.stack([model(x_test_input, training=False) for model in models], axis=0) # forward pass 61 | mus, sigmas = tf.split(preds, 2, axis=-1) 62 | 63 | mean_mu = tf.reduce_mean(mus, axis=0) 64 | epistemic = tf.math.reduce_std(mus, axis=0) + tf.reduce_mean(sigmas, axis=0) 65 | plot_scatter_with_var(mean_mu, epistemic, path=save+ext, n_stds=3) 66 | 67 | def plot_dropout(model, save="dropout", ext=".pdf"): 68 | x_test_input = tf.convert_to_tensor(x_test, tf.float32) 69 | preds = tf.stack([model(x_test_input, training=True) for _ in range(15)], axis=0) # forward pass 70 | mus, logvar = tf.split(preds, 2, axis=-1) 71 | var = tf.exp(logvar) 72 | 73 | mean_mu = tf.reduce_mean(mus, axis=0) 74 | epistemic = tf.math.reduce_std(mus, axis=0) + tf.reduce_mean(var**0.5, axis=0) 75 | plot_scatter_with_var(mean_mu, epistemic, path=save+ext, n_stds=3) 76 | 77 | def plot_bbbp(model, save="bbbp", ext=".pdf"): 78 | x_test_input = tf.convert_to_tensor(x_test, tf.float32) 79 | preds = tf.stack([model(x_test_input, training=True) for _ in range(15)], axis=0) # forward pass 80 | 81 | mean_mu = tf.reduce_mean(preds, axis=0) 82 | epistemic = tf.math.reduce_std(preds, axis=0) 83 | plot_scatter_with_var(mean_mu, epistemic, path=save+ext, n_stds=3) 84 | 85 | def plot_gaussian(model, save="gaussian", ext=".pdf"): 86 | x_test_input = tf.convert_to_tensor(x_test, tf.float32) 87 | preds = model(x_test_input, training=False) # forward pass 88 | mu, sigma = tf.split(preds, 2, axis=-1) 89 | plot_scatter_with_var(mu, sigma, path=save+ext, n_stds=3) 90 | 91 | 92 | 93 | #### Different toy configurations to train and plot 94 | def evidence_reg_2_layers_50_neurons(): 95 | trainer_obj = trainers.Evidential 96 | model_generator = models.get_correct_model(dataset="toy", trainer=trainer_obj) 97 | model, opts = model_generator.create(input_shape=1, num_neurons=50, num_layers=2) 98 | trainer = trainer_obj(model, opts, learning_rate=5e-3, lam=1e-2, maxi_rate=0.)
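# maxi_rate=0. freezes the regularizer weight for these toy runs: with a
# positive maxi_rate, Evidential.run_train_step would adapt it every step as
#
#     lam <- lam + maxi_rate * (reg_loss - epsilon)
#
# raising lam while the evidence penalty exceeds the epsilon target and
# lowering it otherwise; here the fixed lam=1e-2 set above is used throughout.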
99 | model, rmse, nll = trainer.train(x_train, y_train, x_train, y_train, np.array([[1.]]), iters=iterations, batch_size=batch_size, verbose=True) 100 | plot_ng(model, os.path.join(save_fig_dir,"evidence_reg_2_layer_50_neurons")) 101 | 102 | def evidence_reg_2_layers_100_neurons(): 103 | trainer_obj = trainers.Evidential 104 | model_generator = models.get_correct_model(dataset="toy", trainer=trainer_obj) 105 | model, opts = model_generator.create(input_shape=1, num_neurons=100, num_layers=2) 106 | trainer = trainer_obj(model, opts, learning_rate=5e-3, lam=1e-2, maxi_rate=0.) 107 | model, rmse, nll = trainer.train(x_train, y_train, x_train, y_train, np.array([[1.]]), iters=iterations, batch_size=batch_size, verbose=True) 108 | plot_ng(model, os.path.join(save_fig_dir,"evidence_reg_2_layers_100_neurons")) 109 | 110 | def evidence_reg_4_layers_50_neurons(): 111 | trainer_obj = trainers.Evidential 112 | model_generator = models.get_correct_model(dataset="toy", trainer=trainer_obj) 113 | model, opts = model_generator.create(input_shape=1, num_neurons=50, num_layers=4) 114 | trainer = trainer_obj(model, opts, learning_rate=5e-3, lam=1e-2, maxi_rate=0.) 115 | model, rmse, nll = trainer.train(x_train, y_train, x_train, y_train, np.array([[1.]]), iters=iterations, batch_size=batch_size, verbose=True) 116 | plot_ng(model, os.path.join(save_fig_dir,"evidence_reg_4_layers_50_neurons")) 117 | 118 | def evidence_reg_4_layers_100_neurons(): 119 | trainer_obj = trainers.Evidential 120 | model_generator = models.get_correct_model(dataset="toy", trainer=trainer_obj) 121 | model, opts = model_generator.create(input_shape=1, num_neurons=100, num_layers=4) 122 | trainer = trainer_obj(model, opts, learning_rate=5e-3, lam=1e-2, maxi_rate=0.) 123 | model, rmse, nll = trainer.train(x_train, y_train, x_train, y_train, np.array([[1.]]), iters=iterations, batch_size=batch_size, verbose=True) 124 | plot_ng(model, os.path.join(save_fig_dir,"evidence_reg_4_layers_100_neurons")) 125 | 126 | def evidence_noreg_4_layers_50_neurons(): 127 | trainer_obj = trainers.Evidential 128 | model_generator = models.get_correct_model(dataset="toy", trainer=trainer_obj) 129 | model, opts = model_generator.create(input_shape=1, num_neurons=50, num_layers=4) 130 | trainer = trainer_obj(model, opts, learning_rate=5e-3, lam=0., maxi_rate=0.) 131 | model, rmse, nll = trainer.train(x_train, y_train, x_train, y_train, np.array([[1.]]), iters=iterations, batch_size=batch_size, verbose=True) 132 | plot_ng(model, os.path.join(save_fig_dir,"evidence_noreg_4_layers_50_neurons")) 133 | 134 | def evidence_noreg_4_layers_100_neurons(): 135 | trainer_obj = trainers.Evidential 136 | model_generator = models.get_correct_model(dataset="toy", trainer=trainer_obj) 137 | model, opts = model_generator.create(input_shape=1, num_neurons=100, num_layers=4) 138 | trainer = trainer_obj(model, opts, learning_rate=5e-3, lam=0., maxi_rate=0.) 
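# lam=0. ablates the evidence regularizer: Evidential.loss_function then
# reduces to
#
#     loss = nll_loss + 0 * (reg_loss - epsilon) == nll_loss
#
# so the "noreg" variants fit the NIG likelihood alone, giving a direct
# comparison against the regularized configurations above.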
139 | model, rmse, nll = trainer.train(x_train, y_train, x_train, y_train, np.array([[1.]]), iters=iterations, batch_size=batch_size, verbose=True) 140 | plot_ng(model, os.path.join(save_fig_dir,"evidence_noreg_4_layers_100_neurons")) 141 | 142 | def ensemble_4_layers_100_neurons(): 143 | trainer_obj = trainers.Ensemble 144 | model_generator = models.get_correct_model(dataset="toy", trainer=trainer_obj) 145 | model, opts = model_generator.create(input_shape=1, num_neurons=100, num_layers=4) 146 | trainer = trainer_obj(model, opts, learning_rate=5e-3) 147 | model, rmse, nll = trainer.train(x_train, y_train, x_train, y_train, np.array([[1.]]), iters=iterations, batch_size=batch_size, verbose=True) 148 | plot_ensemble(model, os.path.join(save_fig_dir,"ensemble_4_layers_100_neurons")) 149 | 150 | def gaussian_4_layers_100_neurons(): 151 | trainer_obj = trainers.Gaussian 152 | model_generator = models.get_correct_model(dataset="toy", trainer=trainer_obj) 153 | model, opts = model_generator.create(input_shape=1, num_neurons=100, num_layers=4) 154 | trainer = trainer_obj(model, opts, learning_rate=5e-3) 155 | model, rmse, nll = trainer.train(x_train, y_train, x_train, y_train, np.array([[1.]]), iters=iterations, batch_size=batch_size, verbose=True) 156 | plot_gaussian(model, os.path.join(save_fig_dir,"gaussian_4_layers_100_neurons")) 157 | 158 | def dropout_4_layers_100_neurons(): 159 | trainer_obj = trainers.Dropout 160 | model_generator = models.get_correct_model(dataset="toy", trainer=trainer_obj) 161 | model, opts = model_generator.create(input_shape=1, num_neurons=100, num_layers=4, sigma=True) 162 | trainer = trainer_obj(model, opts, learning_rate=5e-3) 163 | model, rmse, nll = trainer.train(x_train, y_train, x_train, y_train, np.array([[1.]]), iters=iterations, batch_size=batch_size, verbose=True) 164 | plot_dropout(model, os.path.join(save_fig_dir,"dropout_4_layers_100_neurons")) 165 | 166 | def bbbp_4_layers_100_neurons(): 167 | trainer_obj = trainers.BBBP 168 | model_generator = models.get_correct_model(dataset="toy", trainer=trainer_obj) 169 | model, opts = model_generator.create(input_shape=1, num_neurons=100, num_layers=4) 170 | trainer = trainer_obj(model, opts, learning_rate=1e-3) 171 | model, rmse, nll = trainer.train(x_train, y_train, x_train, y_train, np.array([[1.]]), iters=iterations, batch_size=batch_size, verbose=True) 172 | plot_bbbp(model, os.path.join(save_fig_dir,"bbbp_4_layers_100_neurons")) 173 | 174 | 175 | ### Main file to run the different methods and compare results ### 176 | if __name__ == "__main__": 177 | Path(save_fig_dir).mkdir(parents=True, exist_ok=True) 178 | 179 | evidence_reg_4_layers_100_neurons() 180 | # evidence_noreg_4_layers_100_neurons() 181 | 182 | # ensemble_4_layers_100_neurons() 183 | # gaussian_4_layers_100_neurons() 184 | # dropout_4_layers_100_neurons() 185 | # bbbp_4_layers_100_neurons() 186 | 187 | print(f"Done! Figures saved to {save_fig_dir}") 188 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 
12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright 2019 The Keras Tuner Authors. 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /neurips2020/data/uci/yacht/yacht_hydrodynamics.data: -------------------------------------------------------------------------------- 1 | -2.3 0.568 4.78 3.99 3.17 0.125 0.11 2 | -2.3 0.568 4.78 3.99 3.17 0.150 0.27 3 | -2.3 0.568 4.78 3.99 3.17 0.175 0.47 4 | -2.3 0.568 4.78 3.99 3.17 0.200 0.78 5 | -2.3 0.568 4.78 3.99 3.17 0.225 1.18 6 | -2.3 0.568 4.78 3.99 3.17 0.250 1.82 7 | -2.3 0.568 4.78 3.99 3.17 0.275 2.61 8 | -2.3 0.568 4.78 3.99 3.17 0.300 3.76 9 | -2.3 0.568 4.78 3.99 3.17 0.325 4.99 10 | -2.3 0.568 4.78 3.99 3.17 0.350 7.16 11 | -2.3 0.568 4.78 3.99 3.17 0.375 11.93 12 | -2.3 0.568 4.78 3.99 3.17 0.400 20.11 13 | -2.3 0.568 4.78 3.99 3.17 0.425 32.75 14 | -2.3 0.568 4.78 3.99 3.17 0.450 49.49 15 | -2.3 0.569 4.78 3.04 3.64 0.125 0.04 16 | -2.3 0.569 4.78 3.04 3.64 0.150 0.17 17 | -2.3 0.569 4.78 3.04 3.64 0.175 0.37 18 | -2.3 0.569 4.78 3.04 3.64 0.200 0.66 19 | -2.3 0.569 4.78 3.04 3.64 0.225 1.06 20 | -2.3 0.569 4.78 3.04 3.64 0.250 1.59 21 | -2.3 0.569 4.78 3.04 3.64 0.275 2.33 22 | -2.3 0.569 4.78 3.04 3.64 0.300 3.29 23 | -2.3 0.569 4.78 3.04 3.64 0.325 4.61 24 | -2.3 0.569 4.78 3.04 3.64 0.350 7.11 25 | -2.3 0.569 4.78 3.04 3.64 0.375 11.99 26 | -2.3 0.569 4.78 3.04 3.64 0.400 21.09 27 | -2.3 0.569 4.78 3.04 3.64 0.425 35.01 28 | -2.3 0.569 4.78 3.04 3.64 0.450 51.80 29 | -2.3 0.565 4.78 5.35 2.76 0.125 0.09 30 | -2.3 0.565 4.78 5.35 2.76 0.150 0.29 31 | -2.3 0.565 4.78 5.35 2.76 0.175 0.56 32 | -2.3 0.565 4.78 5.35 2.76 0.200 0.86 33 | -2.3 0.565 4.78 5.35 2.76 0.225 1.31 34 | -2.3 0.565 4.78 5.35 2.76 0.250 1.99 35 | -2.3 0.565 4.78 5.35 2.76 0.275 2.94 36 | -2.3 0.565 4.78 5.35 2.76 0.300 4.21 37 | -2.3 0.565 4.78 5.35 2.76 0.325 5.54 38 | -2.3 0.565 4.78 5.35 2.76 0.350 8.25 39 | -2.3 0.565 4.78 5.35 2.76 0.375 13.08 40 | -2.3 0.565 4.78 5.35 2.76 0.400 21.40 41 | -2.3 0.565 4.78 5.35 2.76 0.425 33.14 42 | -2.3 0.565 4.78 5.35 2.76 0.450 50.14 43 | -2.3 0.564 5.10 3.95 3.53 0.125 0.20 44 | -2.3 0.564 5.10 3.95 3.53 0.150 0.35 45 | -2.3 0.564 5.10 3.95 3.53 0.175 0.65 46 | -2.3 0.564 5.10 3.95 3.53 0.200 0.93 47 | -2.3 0.564 5.10 3.95 3.53 0.225 1.37 48 | -2.3 0.564 5.10 3.95 3.53 0.250 1.97 49 | -2.3 0.564 5.10 3.95 3.53 0.275 2.83 50 | -2.3 0.564 5.10 3.95 3.53 0.300 3.99 51 | -2.3 0.564 5.10 3.95 3.53 0.325 5.19 52 | -2.3 0.564 5.10 3.95 3.53 0.350 8.03 53 | -2.3 0.564 5.10 3.95 3.53 0.375 12.86 54 | -2.3 0.564 5.10 3.95 3.53 0.400 21.51 55 | -2.3 0.564 5.10 3.95 3.53 0.425 33.97 56 | -2.3 0.564 5.10 3.95 3.53 0.450 50.36 57 | -2.4 0.574 4.36 3.96 2.76 0.125 0.20 58 | -2.4 0.574 4.36 3.96 2.76 0.150 0.35 59 | -2.4 0.574 4.36 3.96 2.76 0.175 0.65 60 | -2.4 0.574 4.36 3.96 2.76 0.200 0.93 61 | -2.4 0.574 4.36 3.96 2.76 0.225 1.37 62 | -2.4 0.574 4.36 3.96 2.76 0.250 1.97 63 | -2.4 0.574 4.36 3.96 2.76 0.275 2.83 64 | -2.4 0.574 4.36 3.96 2.76 0.300 3.99 65 | -2.4 0.574 4.36 3.96 2.76 0.325 5.19 66 | -2.4 0.574 4.36 3.96 2.76 0.350 8.03 67 | -2.4 0.574 4.36 3.96 2.76 0.375 12.86 68 | -2.4 0.574 4.36 3.96 2.76 0.400 21.51 69 | -2.4 0.574 4.36 
3.96 2.76 0.425 33.97 70 | -2.4 0.574 4.36 3.96 2.76 0.450 50.36 71 | -2.4 0.568 4.34 2.98 3.15 0.125 0.12 72 | -2.4 0.568 4.34 2.98 3.15 0.150 0.26 73 | -2.4 0.568 4.34 2.98 3.15 0.175 0.43 74 | -2.4 0.568 4.34 2.98 3.15 0.200 0.69 75 | -2.4 0.568 4.34 2.98 3.15 0.225 1.09 76 | -2.4 0.568 4.34 2.98 3.15 0.250 1.67 77 | -2.4 0.568 4.34 2.98 3.15 0.275 2.46 78 | -2.4 0.568 4.34 2.98 3.15 0.300 3.43 79 | -2.4 0.568 4.34 2.98 3.15 0.325 4.62 80 | -2.4 0.568 4.34 2.98 3.15 0.350 6.86 81 | -2.4 0.568 4.34 2.98 3.15 0.375 11.56 82 | -2.4 0.568 4.34 2.98 3.15 0.400 20.63 83 | -2.4 0.568 4.34 2.98 3.15 0.425 34.50 84 | -2.4 0.568 4.34 2.98 3.15 0.450 54.23 85 | -2.3 0.562 5.14 4.95 3.17 0.125 0.28 86 | -2.3 0.562 5.14 4.95 3.17 0.150 0.44 87 | -2.3 0.562 5.14 4.95 3.17 0.175 0.70 88 | -2.3 0.562 5.14 4.95 3.17 0.200 1.07 89 | -2.3 0.562 5.14 4.95 3.17 0.225 1.57 90 | -2.3 0.562 5.14 4.95 3.17 0.250 2.23 91 | -2.3 0.562 5.14 4.95 3.17 0.275 3.09 92 | -2.3 0.562 5.14 4.95 3.17 0.300 4.09 93 | -2.3 0.562 5.14 4.95 3.17 0.325 5.82 94 | -2.3 0.562 5.14 4.95 3.17 0.350 8.28 95 | -2.3 0.562 5.14 4.95 3.17 0.375 12.80 96 | -2.3 0.562 5.14 4.95 3.17 0.400 20.41 97 | -2.3 0.562 5.14 4.95 3.17 0.425 32.34 98 | -2.3 0.562 5.14 4.95 3.17 0.450 47.29 99 | -2.4 0.585 4.78 3.84 3.32 0.125 0.20 100 | -2.4 0.585 4.78 3.84 3.32 0.150 0.38 101 | -2.4 0.585 4.78 3.84 3.32 0.175 0.64 102 | -2.4 0.585 4.78 3.84 3.32 0.200 0.97 103 | -2.4 0.585 4.78 3.84 3.32 0.225 1.36 104 | -2.4 0.585 4.78 3.84 3.32 0.250 1.98 105 | -2.4 0.585 4.78 3.84 3.32 0.275 2.91 106 | -2.4 0.585 4.78 3.84 3.32 0.300 4.35 107 | -2.4 0.585 4.78 3.84 3.32 0.325 5.79 108 | -2.4 0.585 4.78 3.84 3.32 0.350 8.04 109 | -2.4 0.585 4.78 3.84 3.32 0.375 12.15 110 | -2.4 0.585 4.78 3.84 3.32 0.400 19.18 111 | -2.4 0.585 4.78 3.84 3.32 0.425 30.09 112 | -2.4 0.585 4.78 3.84 3.32 0.450 44.38 113 | -2.2 0.546 4.78 4.13 3.07 0.125 0.15 114 | -2.2 0.546 4.78 4.13 3.07 0.150 0.32 115 | -2.2 0.546 4.78 4.13 3.07 0.175 0.55 116 | -2.2 0.546 4.78 4.13 3.07 0.200 0.86 117 | -2.2 0.546 4.78 4.13 3.07 0.225 1.24 118 | -2.2 0.546 4.78 4.13 3.07 0.250 1.76 119 | -2.2 0.546 4.78 4.13 3.07 0.275 2.49 120 | -2.2 0.546 4.78 4.13 3.07 0.300 3.45 121 | -2.2 0.546 4.78 4.13 3.07 0.325 4.83 122 | -2.2 0.546 4.78 4.13 3.07 0.350 7.37 123 | -2.2 0.546 4.78 4.13 3.07 0.375 12.76 124 | -2.2 0.546 4.78 4.13 3.07 0.400 21.99 125 | -2.2 0.546 4.78 4.13 3.07 0.425 35.64 126 | -2.2 0.546 4.78 4.13 3.07 0.450 53.07 127 | 0.0 0.565 4.77 3.99 3.15 0.125 0.11 128 | 0.0 0.565 4.77 3.99 3.15 0.150 0.24 129 | 0.0 0.565 4.77 3.99 3.15 0.175 0.49 130 | 0.0 0.565 4.77 3.99 3.15 0.200 0.79 131 | 0.0 0.565 4.77 3.99 3.15 0.225 1.28 132 | 0.0 0.565 4.77 3.99 3.15 0.250 1.96 133 | 0.0 0.565 4.77 3.99 3.15 0.275 2.88 134 | 0.0 0.565 4.77 3.99 3.15 0.300 4.14 135 | 0.0 0.565 4.77 3.99 3.15 0.325 5.96 136 | 0.0 0.565 4.77 3.99 3.15 0.350 9.07 137 | 0.0 0.565 4.77 3.99 3.15 0.375 14.93 138 | 0.0 0.565 4.77 3.99 3.15 0.400 24.13 139 | 0.0 0.565 4.77 3.99 3.15 0.425 38.12 140 | 0.0 0.565 4.77 3.99 3.15 0.450 55.44 141 | -5.0 0.565 4.77 3.99 3.15 0.125 0.07 142 | -5.0 0.565 4.77 3.99 3.15 0.150 0.18 143 | -5.0 0.565 4.77 3.99 3.15 0.175 0.40 144 | -5.0 0.565 4.77 3.99 3.15 0.200 0.70 145 | -5.0 0.565 4.77 3.99 3.15 0.225 1.14 146 | -5.0 0.565 4.77 3.99 3.15 0.250 1.83 147 | -5.0 0.565 4.77 3.99 3.15 0.275 2.77 148 | -5.0 0.565 4.77 3.99 3.15 0.300 4.12 149 | -5.0 0.565 4.77 3.99 3.15 0.325 5.41 150 | -5.0 0.565 4.77 3.99 3.15 0.350 7.87 151 | -5.0 0.565 4.77 3.99 3.15 0.375 12.71 152 | -5.0 0.565 4.77 3.99 
3.15 0.400 21.02 153 | -5.0 0.565 4.77 3.99 3.15 0.425 34.58 154 | -5.0 0.565 4.77 3.99 3.15 0.450 51.77 155 | 0.0 0.565 5.10 3.94 3.51 0.125 0.08 156 | 0.0 0.565 5.10 3.94 3.51 0.150 0.26 157 | 0.0 0.565 5.10 3.94 3.51 0.175 0.50 158 | 0.0 0.565 5.10 3.94 3.51 0.200 0.83 159 | 0.0 0.565 5.10 3.94 3.51 0.225 1.28 160 | 0.0 0.565 5.10 3.94 3.51 0.250 1.90 161 | 0.0 0.565 5.10 3.94 3.51 0.275 2.68 162 | 0.0 0.565 5.10 3.94 3.51 0.300 3.76 163 | 0.0 0.565 5.10 3.94 3.51 0.325 5.57 164 | 0.0 0.565 5.10 3.94 3.51 0.350 8.76 165 | 0.0 0.565 5.10 3.94 3.51 0.375 14.24 166 | 0.0 0.565 5.10 3.94 3.51 0.400 23.05 167 | 0.0 0.565 5.10 3.94 3.51 0.425 35.46 168 | 0.0 0.565 5.10 3.94 3.51 0.450 51.99 169 | -5.0 0.565 5.10 3.94 3.51 0.125 0.08 170 | -5.0 0.565 5.10 3.94 3.51 0.150 0.24 171 | -5.0 0.565 5.10 3.94 3.51 0.175 0.45 172 | -5.0 0.565 5.10 3.94 3.51 0.200 0.77 173 | -5.0 0.565 5.10 3.94 3.51 0.225 1.19 174 | -5.0 0.565 5.10 3.94 3.51 0.250 1.76 175 | -5.0 0.565 5.10 3.94 3.51 0.275 2.59 176 | -5.0 0.565 5.10 3.94 3.51 0.300 3.85 177 | -5.0 0.565 5.10 3.94 3.51 0.325 5.27 178 | -5.0 0.565 5.10 3.94 3.51 0.350 7.74 179 | -5.0 0.565 5.10 3.94 3.51 0.375 12.40 180 | -5.0 0.565 5.10 3.94 3.51 0.400 20.91 181 | -5.0 0.565 5.10 3.94 3.51 0.425 33.23 182 | -5.0 0.565 5.10 3.94 3.51 0.450 49.14 183 | -2.3 0.530 5.11 3.69 3.51 0.125 0.08 184 | -2.3 0.530 5.11 3.69 3.51 0.150 0.25 185 | -2.3 0.530 5.11 3.69 3.51 0.175 0.46 186 | -2.3 0.530 5.11 3.69 3.51 0.200 0.75 187 | -2.3 0.530 5.11 3.69 3.51 0.225 1.11 188 | -2.3 0.530 5.11 3.69 3.51 0.250 1.57 189 | -2.3 0.530 5.11 3.69 3.51 0.275 2.17 190 | -2.3 0.530 5.11 3.69 3.51 0.300 2.98 191 | -2.3 0.530 5.11 3.69 3.51 0.325 4.42 192 | -2.3 0.530 5.11 3.69 3.51 0.350 7.84 193 | -2.3 0.530 5.11 3.69 3.51 0.375 14.11 194 | -2.3 0.530 5.11 3.69 3.51 0.400 24.14 195 | -2.3 0.530 5.11 3.69 3.51 0.425 37.95 196 | -2.3 0.530 5.11 3.69 3.51 0.450 55.17 197 | -2.3 0.530 4.76 3.68 3.16 0.125 0.10 198 | -2.3 0.530 4.76 3.68 3.16 0.150 0.23 199 | -2.3 0.530 4.76 3.68 3.16 0.175 0.47 200 | -2.3 0.530 4.76 3.68 3.16 0.200 0.76 201 | -2.3 0.530 4.76 3.68 3.16 0.225 1.15 202 | -2.3 0.530 4.76 3.68 3.16 0.250 1.65 203 | -2.3 0.530 4.76 3.68 3.16 0.275 2.28 204 | -2.3 0.530 4.76 3.68 3.16 0.300 3.09 205 | -2.3 0.530 4.76 3.68 3.16 0.325 4.41 206 | -2.3 0.530 4.76 3.68 3.16 0.350 7.51 207 | -2.3 0.530 4.76 3.68 3.16 0.375 13.77 208 | -2.3 0.530 4.76 3.68 3.16 0.400 23.96 209 | -2.3 0.530 4.76 3.68 3.16 0.425 37.38 210 | -2.3 0.530 4.76 3.68 3.16 0.450 56.46 211 | -2.3 0.530 4.34 2.81 3.15 0.125 0.05 212 | -2.3 0.530 4.34 2.81 3.15 0.150 0.17 213 | -2.3 0.530 4.34 2.81 3.15 0.175 0.35 214 | -2.3 0.530 4.34 2.81 3.15 0.200 0.63 215 | -2.3 0.530 4.34 2.81 3.15 0.225 1.01 216 | -2.3 0.530 4.34 2.81 3.15 0.250 1.43 217 | -2.3 0.530 4.34 2.81 3.15 0.275 2.05 218 | -2.3 0.530 4.34 2.81 3.15 0.300 2.73 219 | -2.3 0.530 4.34 2.81 3.15 0.325 3.87 220 | -2.3 0.530 4.34 2.81 3.15 0.350 7.19 221 | -2.3 0.530 4.34 2.81 3.15 0.375 13.96 222 | -2.3 0.530 4.34 2.81 3.15 0.400 25.18 223 | -2.3 0.530 4.34 2.81 3.15 0.425 41.34 224 | -2.3 0.530 4.34 2.81 3.15 0.450 62.42 225 | 0.0 0.600 4.78 4.24 3.15 0.125 0.03 226 | 0.0 0.600 4.78 4.24 3.15 0.150 0.18 227 | 0.0 0.600 4.78 4.24 3.15 0.175 0.40 228 | 0.0 0.600 4.78 4.24 3.15 0.200 0.73 229 | 0.0 0.600 4.78 4.24 3.15 0.225 1.30 230 | 0.0 0.600 4.78 4.24 3.15 0.250 2.16 231 | 0.0 0.600 4.78 4.24 3.15 0.275 3.35 232 | 0.0 0.600 4.78 4.24 3.15 0.300 5.06 233 | 0.0 0.600 4.78 4.24 3.15 0.325 7.14 234 | 0.0 0.600 4.78 4.24 3.15 0.350 10.36 235 | 0.0 
0.600 4.78 4.24 3.15 0.375 15.25 236 | 0.0 0.600 4.78 4.24 3.15 0.400 23.15 237 | 0.0 0.600 4.78 4.24 3.15 0.425 34.62 238 | 0.0 0.600 4.78 4.24 3.15 0.450 51.50 239 | -5.0 0.600 4.78 4.24 3.15 0.125 0.06 240 | -5.0 0.600 4.78 4.24 3.15 0.150 0.15 241 | -5.0 0.600 4.78 4.24 3.15 0.175 0.34 242 | -5.0 0.600 4.78 4.24 3.15 0.200 0.63 243 | -5.0 0.600 4.78 4.24 3.15 0.225 1.13 244 | -5.0 0.600 4.78 4.24 3.15 0.250 1.85 245 | -5.0 0.600 4.78 4.24 3.15 0.275 2.84 246 | -5.0 0.600 4.78 4.24 3.15 0.300 4.34 247 | -5.0 0.600 4.78 4.24 3.15 0.325 6.20 248 | -5.0 0.600 4.78 4.24 3.15 0.350 8.62 249 | -5.0 0.600 4.78 4.24 3.15 0.375 12.49 250 | -5.0 0.600 4.78 4.24 3.15 0.400 20.41 251 | -5.0 0.600 4.78 4.24 3.15 0.425 32.46 252 | -5.0 0.600 4.78 4.24 3.15 0.450 50.94 253 | 0.0 0.530 4.78 3.75 3.15 0.125 0.16 254 | 0.0 0.530 4.78 3.75 3.15 0.150 0.32 255 | 0.0 0.530 4.78 3.75 3.15 0.175 0.59 256 | 0.0 0.530 4.78 3.75 3.15 0.200 0.92 257 | 0.0 0.530 4.78 3.75 3.15 0.225 1.37 258 | 0.0 0.530 4.78 3.75 3.15 0.250 1.94 259 | 0.0 0.530 4.78 3.75 3.15 0.275 2.62 260 | 0.0 0.530 4.78 3.75 3.15 0.300 3.70 261 | 0.0 0.530 4.78 3.75 3.15 0.325 5.45 262 | 0.0 0.530 4.78 3.75 3.15 0.350 9.45 263 | 0.0 0.530 4.78 3.75 3.15 0.375 16.31 264 | 0.0 0.530 4.78 3.75 3.15 0.400 27.34 265 | 0.0 0.530 4.78 3.75 3.15 0.425 41.77 266 | 0.0 0.530 4.78 3.75 3.15 0.450 60.85 267 | -5.0 0.530 4.78 3.75 3.15 0.125 0.09 268 | -5.0 0.530 4.78 3.75 3.15 0.150 0.24 269 | -5.0 0.530 4.78 3.75 3.15 0.175 0.47 270 | -5.0 0.530 4.78 3.75 3.15 0.200 0.78 271 | -5.0 0.530 4.78 3.75 3.15 0.225 1.21 272 | -5.0 0.530 4.78 3.75 3.15 0.250 1.85 273 | -5.0 0.530 4.78 3.75 3.15 0.275 2.62 274 | -5.0 0.530 4.78 3.75 3.15 0.300 3.69 275 | -5.0 0.530 4.78 3.75 3.15 0.325 5.07 276 | -5.0 0.530 4.78 3.75 3.15 0.350 7.95 277 | -5.0 0.530 4.78 3.75 3.15 0.375 13.73 278 | -5.0 0.530 4.78 3.75 3.15 0.400 23.55 279 | -5.0 0.530 4.78 3.75 3.15 0.425 37.14 280 | -5.0 0.530 4.78 3.75 3.15 0.450 55.87 281 | -2.3 0.600 5.10 4.17 3.51 0.125 0.01 282 | -2.3 0.600 5.10 4.17 3.51 0.150 0.16 283 | -2.3 0.600 5.10 4.17 3.51 0.175 0.39 284 | -2.3 0.600 5.10 4.17 3.51 0.200 0.73 285 | -2.3 0.600 5.10 4.17 3.51 0.225 1.24 286 | -2.3 0.600 5.10 4.17 3.51 0.250 1.96 287 | -2.3 0.600 5.10 4.17 3.51 0.275 3.04 288 | -2.3 0.600 5.10 4.17 3.51 0.300 4.46 289 | -2.3 0.600 5.10 4.17 3.51 0.325 6.31 290 | -2.3 0.600 5.10 4.17 3.51 0.350 8.68 291 | -2.3 0.600 5.10 4.17 3.51 0.375 12.39 292 | -2.3 0.600 5.10 4.17 3.51 0.400 20.14 293 | -2.3 0.600 5.10 4.17 3.51 0.425 31.77 294 | -2.3 0.600 5.10 4.17 3.51 0.450 47.13 295 | -2.3 0.600 4.34 4.23 2.73 0.125 0.04 296 | -2.3 0.600 4.34 4.23 2.73 0.150 0.17 297 | -2.3 0.600 4.34 4.23 2.73 0.175 0.36 298 | -2.3 0.600 4.34 4.23 2.73 0.200 0.64 299 | -2.3 0.600 4.34 4.23 2.73 0.225 1.02 300 | -2.3 0.600 4.34 4.23 2.73 0.250 1.62 301 | -2.3 0.600 4.34 4.23 2.73 0.275 2.63 302 | -2.3 0.600 4.34 4.23 2.73 0.300 4.15 303 | -2.3 0.600 4.34 4.23 2.73 0.325 6.00 304 | -2.3 0.600 4.34 4.23 2.73 0.350 8.47 305 | -2.3 0.600 4.34 4.23 2.73 0.375 12.27 306 | -2.3 0.600 4.34 4.23 2.73 0.400 19.59 307 | -2.3 0.600 4.34 4.23 2.73 0.425 30.48 308 | -2.3 0.600 4.34 4.23 2.73 0.450 46.66 309 | 310 | -------------------------------------------------------------------------------- /neurips2020/data_loader.py: -------------------------------------------------------------------------------- 1 | """ 2 | IO module for train/test regression datasets 3 | """ 4 | import numpy as np 5 | import pandas as pd 6 | import os 7 | import h5py 8 | import tensorflow as tf 9 | 
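# The 1-D toy problem used by run_cubic_tests.py comes from generate_cubic
# below: y = x**3 + eps with eps ~ N(0, 3^2) when noise=True, and the exact
# noiseless cubic otherwise. An illustrative usage sketch:
#
#     x = np.linspace(-4, 4, 1000).reshape(-1, 1)
#     y_noisy, sigma = generate_cubic(x, noise=True)   # sigma == 3 everywhere
#     y_clean, _ = generate_cubic(x, noise=False)      # exactly x**3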
10 | 11 | 12 | def generate_cubic(x, noise=False): 13 | x = x.astype(np.float32) 14 | y = x**3 15 | 16 | if noise: 17 | sigma = 3 * np.ones_like(x) 18 | else: 19 | sigma = np.zeros_like(x) 20 | r = np.random.normal(0, sigma).astype(np.float32) 21 | return y+r, sigma 22 | 23 | 24 | ##################################### 25 | # individual data files # 26 | ##################################### 27 | vb_dir = os.path.dirname(__file__) 28 | data_dir = os.path.join(vb_dir, "data/uci") 29 | 30 | def _load_boston(): 31 | """ 32 | Attribute Information: 33 | 1. CRIM: per capita crime rate by town 34 | 2. ZN: proportion of residential land zoned for lots over 25,000 sq.ft. 35 | 3. INDUS: proportion of non-retail business acres per town 36 | 4. CHAS: Charles River dummy variable (= 1 if tract bounds river; 0 otherwise) 37 | 5. NOX: nitric oxides concentration (parts per 10 million) 38 | 6. RM: average number of rooms per dwelling 39 | 7. AGE: proportion of owner-occupied units built prior to 1940 40 | 8. DIS: weighted distances to five Boston employment centres 41 | 9. RAD: index of accessibility to radial highways 42 | 10. TAX: full-value property-tax rate per $10,000 43 | 11. PTRATIO: pupil-teacher ratio by town 44 | 12. B: 1000(Bk - 0.63)^2 where Bk is the proportion of blacks by town 45 | 13. LSTAT: % lower status of the population 46 | 14. MEDV: Median value of owner-occupied homes in $1000's 47 | """ 48 | data = np.loadtxt(os.path.join(data_dir, 49 | 'boston-housing/boston_housing.txt')) 50 | X = data[:, :-1] 51 | y = data[:, -1] 52 | return X, y 53 | 54 | 55 | def _load_powerplant(): 56 | """ 57 | attribute information: 58 | features consist of hourly average ambient variables 59 | - temperature (t) in the range 1.81 c to 37.11 c, 60 | - ambient pressure (ap) in the range 992.89-1033.30 millibar, 61 | - relative humidity (rh) in the range 25.56% to 100.16% 62 | - exhaust vacuum (v) in the range 25.36-81.56 cm hg 63 | - net hourly electrical energy output (ep) 420.26-495.76 mw 64 | the averages are taken from various sensors located around the 65 | plant that record the ambient variables every second. 66 | the variables are given without normalization.
67 | """ 68 | data_file = os.path.join(data_dir, 'power-plant/Folds5x2_pp.xlsx') 69 | data = pd.read_excel(data_file) 70 | x = data.values[:, :-1] 71 | y = data.values[:, -1] 72 | return x, y 73 | 74 | 75 | def _load_concrete(): 76 | """ 77 | Summary Statistics: 78 | Number of instances (observations): 1030 79 | Number of Attributes: 9 80 | Attribute breakdown: 8 quantitative input variables, and 1 quantitative output variable 81 | Missing Attribute Values: None 82 | Name -- Data Type -- Measurement -- Description 83 | Cement (component 1) -- quantitative -- kg in a m3 mixture -- Input Variable 84 | Blast Furnace Slag (component 2) -- quantitative -- kg in a m3 mixture -- Input Variable 85 | Fly Ash (component 3) -- quantitative -- kg in a m3 mixture -- Input Variable 86 | Water (component 4) -- quantitative -- kg in a m3 mixture -- Input Variable 87 | Superplasticizer (component 5) -- quantitative -- kg in a m3 mixture -- Input Variable 88 | Coarse Aggregate (component 6) -- quantitative -- kg in a m3 mixture -- Input Variable 89 | Fine Aggregate (component 7) -- quantitative -- kg in a m3 mixture -- Input Variable 90 | Age -- quantitative -- Day (1~365) -- Input Variable 91 | Concrete compressive strength -- quantitative -- MPa -- Output Variable 92 | --------------------------------- 93 | """ 94 | data_file = os.path.join(data_dir, 'concrete/Concrete_Data.xls') 95 | data = pd.read_excel(data_file) 96 | X = data.values[:, :-1] 97 | y = data.values[:, -1] 98 | return X, y 99 | 100 | 101 | def _load_yacht(): 102 | """ 103 | Attribute Information: 104 | Variations concern hull geometry coefficients and the Froude number: 105 | 1. Longitudinal position of the center of buoyancy, adimensional. 106 | 2. Prismatic coefficient, adimensional. 107 | 3. Length-displacement ratio, adimensional. 108 | 4. Beam-draught ratio, adimensional. 109 | 5. Length-beam ratio, adimensional. 110 | 6. Froude number, adimensional. 111 | The measured variable is the residuary resistance per unit weight of displacement: 112 | 7. Residuary resistance per unit weight of displacement, adimensional. 113 | """ 114 | data_file = os.path.join(data_dir, 'yacht/yacht_hydrodynamics.data') 115 | data = pd.read_csv(data_file, delim_whitespace=True) 116 | X = data.values[:, :-1] 117 | y = data.values[:, -1] 118 | return X, y 119 | 120 | 121 | def _load_energy_efficiency(): 122 | """ 123 | Data Set Information: 124 | We perform energy analysis using 12 different building shapes simulated in 125 | Ecotect. The buildings differ with respect to the glazing area, the 126 | glazing area distribution, and the orientation, amongst other parameters. 127 | We simulate various settings as functions of the afore-mentioned 128 | characteristics to obtain 768 building shapes. The dataset comprises 129 | 768 samples and 8 features, aiming to predict two real valued responses. 130 | It can also be used as a multi-class classification problem if the 131 | response is rounded to the nearest integer. 132 | Attribute Information: 133 | The dataset contains eight attributes (or features, denoted by X1...X8) and two responses (or outcomes, denoted by y1 and y2). The aim is to use the eight features to predict each of the two responses. 
134 | Specifically: 135 | X1 Relative Compactness 136 | X2 Surface Area 137 | X3 Wall Area 138 | X4 Roof Area 139 | X5 Overall Height 140 | X6 Orientation 141 | X7 Glazing Area 142 | X8 Glazing Area Distribution 143 | y1 Heating Load 144 | y2 Cooling Load 145 | """ 146 | data_file = os.path.join(data_dir, 'energy-efficiency/ENB2012_data.xlsx') 147 | data = pd.read_excel(data_file) 148 | X = data.values[:, :-2] 149 | y_heating = data.values[:, -2] 150 | y_cooling = data.values[:, -1] 151 | return X, y_cooling 152 | 153 | 154 | def _load_wine(): 155 | """ 156 | Attribute Information: 157 | For more information, read [Cortez et al., 2009]. 158 | Input variables (based on physicochemical tests): 159 | 1 - fixed acidity 160 | 2 - volatile acidity 161 | 3 - citric acid 162 | 4 - residual sugar 163 | 5 - chlorides 164 | 6 - free sulfur dioxide 165 | 7 - total sulfur dioxide 166 | 8 - density 167 | 9 - pH 168 | 10 - sulphates 169 | 11 - alcohol 170 | Output variable (based on sensory data): 171 | 12 - quality (score between 0 and 10) 172 | """ 173 | # data_file = os.path.join(data_dir, 'wine-quality/winequality-red.csv') 174 | data_file = os.path.join(data_dir, 'wine-quality/wine_data_new.txt') 175 | data = pd.read_csv(data_file, sep=' ', header=None) 176 | X = data.values[:, :-1] 177 | y = data.values[:, -1] 178 | return X, y 179 | 180 | def _load_kin8nm(): 181 | """ 182 | This data set is concerned with the forward kinematics of an 8 link robot arm. Among the existing variants of 183 | this data set we have used the variant 8nm, which is known to be highly non-linear and medium noisy. 184 | 185 | Original source: DELVE repository of data. Source: collection of regression datasets by Luis Torgo 186 | (ltorgo@ncc.up.pt) at http://www.ncc.up.pt/~ltorgo/Regression/DataSets.html Characteristics: 8192 cases, 187 | 9 attributes (0 nominal, 9 continuous). 188 | 189 | Input variables: 190 | 1 - theta1 191 | 2 - theta2 192 | ... 193 | 8 - theta8 194 | Output variable: 195 | 9 - target 196 | """ 197 | data_file = os.path.join(data_dir, 'kin8nm/dataset_2175_kin8nm.csv') 198 | data = pd.read_csv(data_file, sep=',') 199 | X = data.values[:, :-1] 200 | y = data.values[:, -1] 201 | return X, y 202 | 203 | 204 | def _load_naval(): 205 | """ 206 | http://archive.ics.uci.edu/ml/datasets/Condition+Based+Maintenance+of+Naval+Propulsion+Plants 207 | 208 | Input variables: 209 | 1 - Lever position(lp)[] 210 | 2 - Ship speed(v)[knots] 211 | 3 - Gas Turbine shaft torque(GTT)[kNm] 212 | 4 - Gas Turbine rate of revolutions(GTn)[rpm] 213 | 5 - Gas Generator rate of revolutions(GGn)[rpm] 214 | 6 - Starboard Propeller Torque(Ts)[kN] 215 | 7 - Port Propeller Torque(Tp)[kN] 216 | 8 - HP Turbine exit temperature(T48)[C] 217 | 9 - GT Compressor inlet air temperature(T1)[C] 218 | 10 - GT Compressor outlet air temperature(T2)[C] 219 | 11 - HP Turbine exit pressure(P48)[bar] 220 | 12 - GT Compressor inlet air pressure(P1)[bar] 221 | 13 - GT Compressor outlet air pressure(P2)[bar] 222 | 14 - Gas Turbine exhaust gas pressure(Pexh)[bar] 223 | 15 - Turbine Injection Control(TIC)[%] 224 | 16 - Fuel flow(mf)[kg / s] 225 | Output variables: 226 | 17 - GT Compressor decay state coefficient. 227 | 18 - GT Turbine decay state coefficient. 228 | """ 229 | data = np.loadtxt(os.path.join(data_dir, 'naval/data.txt')) 230 | X = data[:, :-2] 231 | y_compressor = data[:, -2] 232 | y_turbine = data[:, -1] 233 | return X, y_turbine 234 | 235 | def _load_protein(): 236 | """ 237 | Physicochemical Properties of Protein Tertiary Structure Data Set 238 | Abstract: This is a data set of Physicochemical Properties of Protein Tertiary Structure. 239 | The data set is taken from CASP 5-9. There are 45730 decoys with sizes varying from 0 to 21 angstroms. 240 | 241 | TODO: Check that the output is correct 242 | 243 | Input variables: 244 | RMSD-Size of the residue. 245 | F1 - Total surface area. 246 | F2 - Non polar exposed area. 247 | F3 - Fractional area of exposed non polar residue. 248 | F4 - Fractional area of exposed non polar part of residue. 249 | F5 - Molecular mass weighted exposed area. 250 | F6 - Average deviation from standard exposed area of residue. 251 | F7 - Euclidian distance. 252 | F8 - Secondary structure penalty. 253 | Output variable: 254 | F9 - Spatial Distribution constraints (N,K Value). 255 | """ 256 | data_file = os.path.join(data_dir, 'protein/CASP.csv') 257 | data = pd.read_csv(data_file, sep=',') 258 | X = data.values[:, 1:] 259 | y = data.values[:, 0] 260 | return X, y 261 | 262 | def _load_song(): 263 | """ 264 | INSTRUCTIONS: 265 | 1) Download from http://archive.ics.uci.edu/ml/datasets/YearPredictionMSD 266 | 2) Place YearPredictionMSD.txt in data/uci/song/ 267 | 268 | The loader is slow since the file is large. 269 | 270 | YearPredictionMSD Data Set 271 | Prediction of the release year of a song from audio features. Songs are mostly western, commercial tracks ranging 272 | from 1922 to 2011, with a peak in the 2000s. 273 | 274 | 90 attributes, 12 = timbre average, 78 = timbre covariance 275 | The first value is the year (target), ranging from 1922 to 2011. 276 | Features extracted from the 'timbre' features from The Echo Nest API. 277 | We take the average and covariance over all 'segments', each segment 278 | being described by a 12-dimensional timbre vector.
279 | 280 | """ 281 | data = np.loadtxt(os.path.join(data_dir, 282 | 'song/YearPredictionMSD.txt'), delimiter=',') 283 | X = data[:, :-1] 284 | y = data[:, -1] 285 | return X, y 286 | 287 | 288 | def _load_depth(): 289 | train = h5py.File("data/depth_train.h5", "r") 290 | test = h5py.File("data/depth_test.h5", "r") 291 | return (train["image"], train["depth"]), (test["image"], test["depth"]) 292 | 293 | def load_depth(): 294 | return _load_depth() 295 | 296 | def load_apollo(): 297 | test = h5py.File("data/apolloscape_test.h5", "r") 298 | return (None, None), (test["image"], test["depth"]) 299 | 300 | def load_dataset(name, split_seed=0, test_fraction=.1, return_as_tensor=False): 301 | # load full dataset 302 | load_funs = { "wine" : _load_wine, 303 | "boston" : _load_boston, 304 | "concrete" : _load_concrete, 305 | "power-plant" : _load_powerplant, 306 | "yacht" : _load_yacht, 307 | "energy-efficiency" : _load_energy_efficiency, 308 | "kin8nm" : _load_kin8nm, 309 | "naval" : _load_naval, 310 | "protein" : _load_protein, 311 | "depth" : _load_depth, 312 | "song" : _load_song} 313 | 314 | print("Loading dataset {}....".format(name)) 315 | if name == "depth": 316 | (X_train, y_train), (X_test, y_test) = load_funs[name]() 317 | y_scale = np.array([[1.0]]) 318 | return (X_train, y_train), (X_test, y_test), y_scale 319 | 320 | X, y = load_funs[name]() 321 | X = X.astype(np.float32) 322 | y = y.astype(np.float32) 323 | def standardize(data): 324 | mu = data.mean(axis=0, keepdims=1) 325 | scale = data.std(axis=0, keepdims=1) 326 | scale[scale<1e-10] = 1.0 327 | 328 | data = (data - mu) / scale 329 | return data, mu, scale 330 | 331 | 332 | 333 | # We create the train and test sets with 90% and 10% of the data 334 | 335 | if split_seed == -1: # Do not shuffle! 
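# split_seed == -1 keeps the rows in file order (a deterministic split);
# any other seed draws the same shuffle on every run, e.g. (illustrative):
#
#     rs = np.random.RandomState(split_seed)
#     permutation = rs.permutation(X.shape[0])   # reproducible for a fixed seed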
336 | permutation = range(X.shape[0]) 337 | else: 338 | rs = np.random.RandomState(split_seed) 339 | permutation = rs.permutation(X.shape[0]) 340 | 341 | if name == "boston" or name == "wine": 342 | test_fraction = 0.2 343 | size_train = int(np.round(X.shape[ 0 ] * (1 - test_fraction))) 344 | index_train = permutation[ 0 : size_train ] 345 | index_test = permutation[ size_train : ] 346 | 347 | X_train = X[ index_train, : ] 348 | X_test = X[ index_test, : ] 349 | 350 | if name == "depth": 351 | y_train = y[index_train] 352 | y_test = y[index_test] 353 | else: 354 | y_train = y[index_train, None] 355 | y_test = y[index_test, None] 356 | 357 | 358 | X_train, x_train_mu, x_train_scale = standardize(X_train) 359 | X_test = (X_test - x_train_mu) / x_train_scale 360 | 361 | y_train, y_train_mu, y_train_scale = standardize(y_train) 362 | y_test = (y_test - y_train_mu) / y_train_scale 363 | 364 | if return_as_tensor: 365 | X_train = tf.convert_to_tensor(X_train, tf.float32) 366 | X_test = tf.convert_to_tensor(X_test, tf.float32) 367 | y_train = tf.convert_to_tensor(y_train, tf.float32) 368 | y_test = tf.convert_to_tensor(y_test, tf.float32) 369 | 370 | print("Done loading dataset {}".format(name)) 371 | return (X_train, y_train), (X_test, y_test), y_train_scale 372 | 373 | 374 | 375 | 376 | def load_flight_delay(): 377 | 378 | # Download from here: http://staffwww.dcs.shef.ac.uk/people/N.Lawrence/dataset_mirror/airline_delay/ 379 | data = pd.read_pickle("data/flight-delay/filtered_data.pickle") 380 | y = np.array(data['ArrDelay']) 381 | data.pop('ArrDelay') 382 | X = np.array(data[:]) 383 | 384 | def standardize(data): 385 | data -= data.mean(axis=0, keepdims=1) 386 | scale = data.std(axis=0, keepdims=1) 387 | data /= scale 388 | return data, scale 389 | 390 | X = X[:, np.where(data.var(axis=0) > 0)[0]] 391 | X, _ = standardize(X) 392 | y, y_scale = standardize(y.reshape(-1,1)) 393 | y = np.squeeze(y) 394 | # y_scale = np.array([[1.0]]) 395 | 396 | N = 700000 397 | S = 100000 398 | X_train = X[:N,:] 399 | X_test = X[N:N + S, :] 400 | y_train = y[:N] 401 | y_test = y[N:N + S] 402 | 403 | 404 | return (X_train, y_train), (X_test, y_test), y_scale 405 | 406 | 407 | # (X_train, y_train), (X_test, y_test) = load_dataset('boston') 408 | # (X_train, y_train), (X_test, y_test) = load_dataset('concrete') 409 | # (X_train, y_train), (X_test, y_test) = load_dataset('energy-efficiency') 410 | # (X_train, y_train), (X_test, y_test) = load_dataset('kin8nm') 411 | # (X_train, y_train), (X_test, y_test) = load_dataset('naval') 412 | # (X_train, y_train), (X_test, y_test) = load_dataset('power-plant') 413 | # (X_train, y_train), (X_test, y_test) = load_dataset('protein') 414 | # (X_train, y_train), (X_test, y_test) = load_dataset('wine') 415 | # (X_train, y_train), (X_test, y_test) = load_dataset('yacht', split_seed=-1) 416 | # (X_train, y_train), (X_test, y_test) = load_dataset('song', split_seed=-1) 417 | # (X_train, y_train), (X_test, y_test) = load_dataset('depth') 418 | 419 | # import pdb; pdb.set_trace() 420 | -------------------------------------------------------------------------------- /neurips2020/gen_depth_results.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import defaultdict 3 | import cv2 4 | from enum import Enum 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import os 8 | import pandas as pd 9 | from pathlib import Path 10 | import seaborn as sns 11 | import scipy.stats 12 | import tensorflow as tf 13 | 
from tqdm import tqdm 14 | 15 | import evidential_deep_learning as edl 16 | import models 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--load-pkl", action='store_true', 20 | help="Load predictions from a cached pickle file or \ 21 | recompute from scratch by feeding the data through \ 22 | trained models") 23 | args = parser.parse_args() 24 | 25 | 26 | class Model(Enum): 27 | GroundTruth = "GroundTruth" 28 | Dropout = "Dropout" 29 | Ensemble = "Ensemble" 30 | Evidential = "Evidential" 31 | 32 | save_dir = "pretrained_models" 33 | trained_models = { 34 | Model.Dropout: [ 35 | "dropout/trial1.h5", 36 | "dropout/trial2.h5", 37 | "dropout/trial3.h5", 38 | ], 39 | Model.Ensemble: [ 40 | "ensemble/trial1_*.h5", 41 | "ensemble/trial2_*.h5", 42 | "ensemble/trial3_*.h5", 43 | ], 44 | Model.Evidential: [ 45 | "evidence/trial1.h5", 46 | "evidence/trial2.h5", 47 | "evidence/trial3.h5", 48 | ], 49 | } 50 | output_dir = "figs/depth" 51 | 52 | def compute_predictions(batch_size=50, n_adv=9): 53 | (x_in, y_in), (x_ood, y_ood) = load_data() 54 | datasets = [(x_in, y_in, False), (x_ood, y_ood, True)] 55 | 56 | df_pred_image = pd.DataFrame( 57 | columns=["Method", "Model Path", "Input", 58 | "Target", "Mu", "Sigma", "Adv. Mask", "Epsilon", "OOD"]) 59 | 60 | adv_eps = np.linspace(0, 0.04, n_adv) 61 | 62 | for method, model_path_list in trained_models.items(): 63 | for model_i, model_path in enumerate(model_path_list): 64 | full_path = os.path.join(save_dir, model_path) 65 | model = models.load_depth_model(full_path, compile=False) 66 | 67 | model_log = defaultdict(list) 68 | print(f"Running {model_path}") 69 | 70 | for x, y, ood in datasets: 71 | # max(10,x.shape[0]//500-1) 72 | for start_i in tqdm(np.arange(0, 3*batch_size, batch_size)): 73 | inds = np.arange(start_i, min(start_i+batch_size, x.shape[0]-1)) 74 | x_batch = x[inds]/np.float32(255.) 75 | y_batch = y[inds]/np.float32(255.) 76 | 77 | if ood: 78 | ### Compute predictions and save 79 | summary_to_add = get_prediction_summary( 80 | method, model_path, model, x_batch, y_batch, ood) 81 | df_pred_image = df_pred_image.append(summary_to_add, ignore_index=True) 82 | 83 | else: 84 | ### Compute adversarial mask 85 | # mask_batch = create_adversarial_pattern(model, tf.convert_to_tensor(x_batch), tf.convert_to_tensor(y_batch)) 86 | mask_batch = create_adversarial_pattern(model, x_batch, y_batch) 87 | mask_batch = mask_batch.numpy().astype(np.int8) 88 | 89 | for eps in adv_eps: 90 | ### Apply adversarial noise 91 | x_batch += (eps * mask_batch.astype(np.float32)) 92 | x_batch = np.clip(x_batch, 0, 1) 93 | 94 | ### Compute predictions and save 95 | summary_to_add = get_prediction_summary( 96 | method, model_path, model, x_batch, y_batch, ood, mask_batch, eps) 97 | df_pred_image = df_pred_image.append(summary_to_add, ignore_index=True) 98 | 99 | 100 | 101 | return df_pred_image 102 | 103 | 104 | def get_prediction_summary(method, model_path, model, x_batch, y_batch, ood, mask_batch=None, eps=0.0): 105 | if mask_batch is None: 106 | mask_batch = np.zeros_like(x_batch) 107 | 108 | ### Collect the predictions 109 | mu_batch, sigma_batch = predict(method, model, x_batch) 110 | mu_batch = np.clip(mu_batch, 0, 1) 111 | sigma_batch = sigma_batch.numpy() 112 | 113 | ### Save the predictions to some dataframes for later analysis 114 | summary = [{"Method": method.value, "Model Path": model_path, 115 | "Input": x, "Target": y, "Mu": mu, "Sigma": sigma, 116 | "Adv. 
Mask": mask, "Epsilon": eps, "OOD": ood} 117 | for x,y,mu,sigma,mask in zip(x_batch, y_batch, mu_batch, sigma_batch, mask_batch)] 118 | return summary 119 | 120 | 121 | def df_image_to_pixels(df, keys=["Target", "Mu", "Sigma"]): 122 | required_keys = ["Method", "Model Path"] 123 | keys = required_keys + keys 124 | key_types = {key: type(df[key].iloc[0]) for key in keys} 125 | max_shape = max([np.prod(np.shape(df[key].iloc[0])) for key in keys]) 126 | 127 | contents = {} 128 | for key in keys: 129 | if np.prod(np.shape(df[key].iloc[0])) == 1: 130 | contents[key] = np.repeat(df[key], max_shape) 131 | else: 132 | contents[key] = np.stack(df[key], axis=0).flatten() 133 | 134 | df_pixel = pd.DataFrame(contents) 135 | return df_pixel 136 | 137 | 138 | def gen_cutoff_plot(df_image, eps=0.0, ood=False, plot=True): 139 | print(f"Generating cutoff plot with eps={eps}, ood={ood}") 140 | 141 | df = df_image[(df_image["Epsilon"]==eps) & (df_image["OOD"]==ood)] 142 | df_pixel = df_image_to_pixels(df, keys=["Target", "Mu", "Sigma"]) 143 | 144 | df_cutoff = pd.DataFrame( 145 | columns=["Method", "Model Path", "Percentile", "Error"]) 146 | 147 | for method, model_path_list in trained_models.items(): 148 | for model_i, model_path in enumerate(tqdm(model_path_list)): 149 | 150 | df_model = df_pixel[(df_pixel["Method"]==method.value) & (df_pixel["Model Path"]==model_path)] 151 | df_model = df_model.sort_values("Sigma", ascending=False) 152 | percentiles = np.arange(100)/100. 153 | cutoff_inds = (percentiles * df_model.shape[0]).astype(int) 154 | 155 | df_model["Error"] = np.abs(df_model["Mu"] - df_model["Target"]) 156 | mean_error = [df_model[cutoff:]["Error"].mean() 157 | for cutoff in cutoff_inds] 158 | df_single_cutoff = pd.DataFrame({'Method': method.value, 'Model Path': model_path, 159 | 'Percentile': percentiles, 'Error': mean_error}) 160 | 161 | df_cutoff = df_cutoff.append(df_single_cutoff) 162 | 163 | df_cutoff["Epsilon"] = eps 164 | 165 | if plot: 166 | print("Plotting cutoffs") 167 | sns.lineplot(x="Percentile", y="Error", hue="Method", data=df_cutoff) 168 | plt.savefig(os.path.join(output_dir, f"cutoff_eps-{eps}_ood-{ood}.pdf")) 169 | plt.show() 170 | 171 | sns.lineplot(x="Percentile", y="Error", hue="Model Path", style="Method", data=df_cutoff) 172 | plt.savefig(os.path.join(output_dir, f"cutoff_eps-{eps}_ood-{ood}_trial.pdf")) 173 | plt.show() 174 | 175 | g = sns.FacetGrid(df_cutoff, col="Method", legend_out=False) 176 | g = g.map_dataframe(sns.lineplot, x="Percentile", y="Error", hue="Model Path")#.add_legend() 177 | plt.savefig(os.path.join(output_dir, f"cutoff_eps-{eps}_ood-{ood}_trial_panel.pdf")) 178 | plt.show() 179 | 180 | 181 | return df_cutoff 182 | 183 | 184 | def gen_calibration_plot(df_image, eps=0.0, ood=False, plot=True): 185 | print(f"Generating calibration plot with eps={eps}, ood={ood}") 186 | df = df_image[(df_image["Epsilon"]==eps) & (df_image["OOD"]==ood)] 187 | # df = df.iloc[::10] 188 | df_pixel = df_image_to_pixels(df, keys=["Target", "Mu", "Sigma"]) 189 | 190 | df_calibration = pd.DataFrame( 191 | columns=["Method", "Model Path", "Expected Conf.", "Observed Conf."]) 192 | 193 | for method, model_path_list in trained_models.items(): 194 | for model_i, model_path in enumerate(tqdm(model_path_list)): 195 | 196 | df_model = df_pixel[(df_pixel["Method"]==method.value) & (df_pixel["Model Path"]==model_path)] 197 | expected_p = np.arange(41)/40. 


def gen_calibration_plot(df_image, eps=0.0, ood=False, plot=True):
    print(f"Generating calibration plot with eps={eps}, ood={ood}")
    df = df_image[(df_image["Epsilon"]==eps) & (df_image["OOD"]==ood)]
    # df = df.iloc[::10]
    df_pixel = df_image_to_pixels(df, keys=["Target", "Mu", "Sigma"])

    df_calibration = pd.DataFrame(
        columns=["Method", "Model Path", "Expected Conf.", "Observed Conf."])

    for method, model_path_list in trained_models.items():
        for model_i, model_path in enumerate(tqdm(model_path_list)):

            df_model = df_pixel[(df_pixel["Method"]==method.value) & (df_pixel["Model Path"]==model_path)]
            expected_p = np.arange(41)/40.

            observed_p = []
            for p in expected_p:
                ppf = scipy.stats.norm.ppf(p, loc=df_model["Mu"], scale=df_model["Sigma"])
                obs_p = (df_model["Target"] < ppf).mean()
                observed_p.append(obs_p)

            df_single = pd.DataFrame({'Method': method.value, 'Model Path': model_path,
                                      'Expected Conf.': expected_p, 'Observed Conf.': observed_p})
            df_calibration = df_calibration.append(df_single)

    df_truth = pd.DataFrame({'Method': Model.GroundTruth.value, 'Model Path': "",
                             'Expected Conf.': expected_p, 'Observed Conf.': expected_p})
    df_calibration = df_calibration.append(df_truth)

    df_calibration['Calibration Error'] = np.abs(df_calibration['Expected Conf.'] - df_calibration['Observed Conf.'])
    df_calibration["Epsilon"] = eps
    table = df_calibration.groupby(["Method", "Model Path"])["Calibration Error"].mean().reset_index()
    table = pd.pivot_table(table, values="Calibration Error", index="Method", aggfunc=[np.mean, np.std, scipy.stats.sem])

    if plot:
        print(table)
        table.to_csv(os.path.join(output_dir, "calib_errors.csv"))

        print("Plotting confidence plots")
        sns.lineplot(x="Expected Conf.", y="Observed Conf.", hue="Method", data=df_calibration)
        plt.savefig(os.path.join(output_dir, f"calib_eps-{eps}_ood-{ood}.pdf"))
        plt.show()

        g = sns.FacetGrid(df_calibration, col="Method", legend_out=False)
        g = g.map_dataframe(sns.lineplot, x="Expected Conf.", y="Observed Conf.", hue="Model Path")  # .add_legend()
        plt.savefig(os.path.join(output_dir, f"calib_eps-{eps}_ood-{ood}_panel.pdf"))
        plt.show()

    return df_calibration, table
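# Calibration is measured per pixel: for each expected confidence level p, the
# observed confidence is the fraction of targets falling below the predicted
# Gaussian's p-quantile.  Sketch for a single level (hypothetical flat arrays):
#
#   p = 0.9
#   ppf = scipy.stats.norm.ppf(p, loc=mu_flat, scale=sigma_flat)
#   observed = (target_flat < ppf).mean()  # well calibrated when observed is close to p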


def gen_adv_plots(df_image, ood=False):
    print(f"Generating adversarial plots with ood={ood}")
    df = df_image[df_image["OOD"]==ood]
    # df = df.iloc[::10]
    df_pixel = df_image_to_pixels(df, keys=["Target", "Mu", "Sigma", "Epsilon"])
    df_pixel["Error"] = np.abs(df_pixel["Mu"] - df_pixel["Target"])
    df_pixel["Entropy"] = 0.5*np.log(2*np.pi*np.exp(1.)*(df_pixel["Sigma"]**2))

    ### Plot epsilon vs error per method
    df = df_pixel.groupby([df_pixel.index, "Method", "Model Path", "Epsilon"]).mean().reset_index()
    df_by_method = df_pixel.groupby(["Method", "Model Path", "Epsilon"]).mean().reset_index()
    sns.lineplot(x="Epsilon", y="Error", hue="Method", data=df_by_method)
    plt.savefig(os.path.join(output_dir, f"adv_ood-{ood}_method_error.pdf"))
    plt.show()

    ### Plot epsilon vs uncertainty per method
    sns.lineplot(x="Epsilon", y="Sigma", hue="Method", data=df_by_method)
    plt.savefig(os.path.join(output_dir, f"adv_ood-{ood}_method_sigma.pdf"))
    plt.show()
    # df_by_method["Entropy"] = 0.5*np.log(2*np.pi*np.exp(1.)*(df_by_method["Sigma"]**2))
    # sns.lineplot(x="Epsilon", y="Entropy", hue="Method", data=df_by_method)
    # plt.savefig(os.path.join(output_dir, f"adv_ood-{ood}_method_entropy.pdf"))
    # plt.show()

    ### Plot entropy cdf for different epsilons
    df_cumdf = pd.DataFrame(columns=["Method", "Model Path", "Epsilon", "Entropy", "CDF"])
    unc_ = np.linspace(df["Entropy"].min(), df["Entropy"].max(), 100)

    for method in df["Method"].unique():
        for model_path in df["Model Path"].unique():
            for eps in df["Epsilon"].unique():
                df_subset = df[
                    (df["Method"]==method) &
                    (df["Model Path"]==model_path) &
                    (df["Epsilon"]==eps)]
                if len(df_subset) == 0:
                    continue
                unc = np.sort(df_subset["Entropy"])
                prob = np.linspace(0, 1, unc.shape[0])
                f_cdf = scipy.interpolate.interp1d(unc, prob, fill_value=(0., 1.), bounds_error=False)
                prob_ = f_cdf(unc_)

                df_single = pd.DataFrame({'Method': method, 'Model Path': model_path,
                                          'Epsilon': eps, "Entropy": unc_, 'CDF': prob_})
                df_cumdf = df_cumdf.append(df_single)

    g = sns.FacetGrid(df_cumdf, col="Method")
    g = g.map_dataframe(sns.lineplot, x="Entropy", y="CDF", hue="Epsilon", ci=None).add_legend()
    plt.savefig(os.path.join(output_dir, f"adv_ood-{ood}_cdf_method.pdf"))
    plt.show()

    # NOT USED FOR THE FINAL PAPER, BUT FEEL FREE TO UNCOMMENT AND RUN
    # ### Plot calibration for different epsilons/methods
    # print("Computing calibration plots per epsilon")
    # calibrations = []
    # tables = []
    # for eps in tqdm(df["Epsilon"].unique()):
    #     df_calibration, table = gen_calibration_plot(df_image.copy(), eps, plot=False)
    #     calibrations.append(df_calibration)
    #     tables.append(table)
    # df_calibration = pd.concat(calibrations, ignore_index=True)
    # df_table = pd.concat(tables, ignore_index=True)
    # df_table.to_csv(os.path.join(output_dir, f"adv_ood-{ood}_calib_error.csv"))
    #
    # sns.catplot(x="Method", y="Calibration Error", hue="Epsilon", data=df_calibration, kind="bar")
    # plt.savefig(os.path.join(output_dir, f"adv_ood-{ood}_calib_error_method.pdf"))
    # plt.show()
    #
    # sns.catplot(x="Epsilon", y="Calibration Error", hue="Method", data=df_calibration, kind="bar")
    # plt.savefig(os.path.join(output_dir, f"adv_ood-{ood}_calib_error_epsilon.pdf"))
    # plt.show()
    #
    # g = sns.FacetGrid(df_calibration, col="Method")
    # g = g.map_dataframe(sns.lineplot, x="Expected Conf.", y="Observed Conf.", hue="Epsilon")
    # g = g.add_legend()
    # plt.savefig(os.path.join(output_dir, f"adv_ood-{ood}_calib_method.pdf"))
    # plt.show()
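# The predictive entropy used throughout is the differential entropy of a
# Gaussian, H = 0.5 * log(2 * pi * e * sigma^2), so it grows monotonically
# with sigma.  Equivalent one-liner (hypothetical flat array sigma_flat):
#
#   entropy = 0.5 * np.log(2 * np.pi * np.e * sigma_flat ** 2)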


def gen_ood_comparison(df_image, unc_key="Entropy"):
    print(f"Generating OOD plots with unc_key={unc_key}")

    df = df_image[df_image["Epsilon"]==0.0]  # Remove adversarial noise experiments
    # df = df.iloc[::5]
    df_pixel = df_image_to_pixels(df, keys=["Target", "Mu", "Sigma", "OOD"])
    df_pixel["Entropy"] = 0.5*np.log(2*np.pi*np.exp(1.)*(df_pixel["Sigma"]**2))

    df_by_method = df_pixel.groupby(["Method", "Model Path", "OOD"])
    df_by_image = df_pixel.groupby([df_pixel.index, "Method", "Model Path", "OOD"])

    df_mean_unc = df_by_method[unc_key].mean().reset_index()  # mean over all pixels per method
    df_mean_unc_img = df_by_image[unc_key].mean().reset_index()  # mean over all pixels per method and image

    sns.catplot(x="Method", y=unc_key, hue="OOD", data=df_mean_unc_img, kind="violin")
    plt.savefig(os.path.join(output_dir, f"ood_{unc_key}_violin.pdf"))
    plt.show()

    sns.catplot(x="Method", y=unc_key, hue="OOD", data=df_mean_unc_img, kind="box", whis=0.5, showfliers=False)
    plt.savefig(os.path.join(output_dir, f"ood_{unc_key}_box.pdf"))
    plt.show()

    ### Plot PDF for each Method on both OOD and IN
    g = sns.FacetGrid(df_mean_unc_img, col="Method", hue="OOD")
    g.map(sns.distplot, "Entropy").add_legend()
    plt.savefig(os.path.join(output_dir, f"ood_{unc_key}_pdf_per_method.pdf"))
    plt.show()

    ### Grab some sample images of most and least uncertainty
    for method in df_mean_unc_img["Method"].unique():
        imgs_max = dict()
        imgs_min = dict()
        for ood in df_mean_unc_img["OOD"].unique():
            df_subset = df_mean_unc_img[
                (df_mean_unc_img["Method"]==method) &
                (df_mean_unc_img["OOD"]==ood)]
            if len(df_subset) == 0:
                continue

            def get_imgs_from_idx(idx):
                i_img = df_subset.loc[idx]["level_0"]
                img_data = df_image.loc[i_img]
                sigma = np.array(img_data["Sigma"])
                entropy = np.log(sigma**2)

                ret = [img_data["Input"], img_data["Mu"], entropy]
                return list(map(trim, ret))

            def idxquantile(s, q=0.5, *args, **kwargs):
                qv = s.quantile(q, *args, **kwargs)
                return (s.sort_values()[::-1] <= qv).idxmax()

            imgs_max[ood] = get_imgs_from_idx(idx=idxquantile(df_subset["Entropy"], 0.95))
            imgs_min[ood] = get_imgs_from_idx(idx=idxquantile(df_subset["Entropy"], 0.05))

        all_entropy_imgs = np.array([[d[ood][2] for ood in d.keys()] for d in (imgs_max, imgs_min)])
        entropy_bounds = (all_entropy_imgs.min(), all_entropy_imgs.max())

        Path(os.path.join(output_dir, "images")).mkdir(parents=True, exist_ok=True)
        for d in (imgs_max, imgs_min):
            for ood, (x, y, entropy) in d.items():
                id = os.path.join(output_dir, f"images/method_{method}_ood_{ood}_entropy_{entropy.mean()}")
                cv2.imwrite(f"{id}_0.png", 255*x)
                cv2.imwrite(f"{id}_1.png", apply_cmap(y, cmap=cv2.COLORMAP_JET))
                entropy = (entropy - entropy_bounds[0]) / (entropy_bounds[1]-entropy_bounds[0])
                cv2.imwrite(f"{id}_2.png", apply_cmap(entropy))

    ### Plot CDFs for every method on both OOD and IN
    df_cumdf = pd.DataFrame(columns=["Method", "Model Path", "OOD", unc_key, "CDF"])
    unc_ = np.linspace(df_mean_unc_img[unc_key].min(), df_mean_unc_img[unc_key].max(), 200)

    for method in df_mean_unc_img["Method"].unique():
        for model_path in df_mean_unc_img["Model Path"].unique():
            for ood in df_mean_unc_img["OOD"].unique():
                df = df_mean_unc_img[
                    (df_mean_unc_img["Method"]==method) &
                    (df_mean_unc_img["Model Path"]==model_path) &
                    (df_mean_unc_img["OOD"]==ood)]
                if len(df) == 0:
                    continue
                unc = np.sort(df[unc_key])
                prob = np.linspace(0, 1, unc.shape[0])
                f_cdf = scipy.interpolate.interp1d(unc, prob, fill_value=(0., 1.), bounds_error=False)
                prob_ = f_cdf(unc_)

                df_single = pd.DataFrame({'Method': method, 'Model Path': model_path,
                                          'OOD': ood, unc_key: unc_, 'CDF': prob_})
                df_cumdf = df_cumdf.append(df_single)

    sns.lineplot(data=df_cumdf, x=unc_key, y="CDF", hue="Method", style="OOD")
    plt.savefig(os.path.join(output_dir, f"ood_{unc_key}_cdfs.pdf"))
    plt.show()


def load_data():
    import data_loader
    _, (x_test, y_test) = data_loader.load_depth()
    _, (x_ood_test, y_ood_test) = data_loader.load_apollo()
    print("Loaded data:", x_test.shape, x_ood_test.shape)
    return (x_test, y_test), (x_ood_test, y_ood_test)
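# gen_ood_comparison contrasts per-image mean uncertainty on in-distribution
# test data (load_depth) against out-of-distribution data (load_apollo).  The
# core quantity, sketched on hypothetical per-pixel entropy maps of shape
# (n_images, H, W):
#
#   ent_in = entropy_in.reshape(len(entropy_in), -1).mean(axis=1)
#   ent_ood = entropy_ood.reshape(len(entropy_ood), -1).mean(axis=1)
#   # a well-behaved model assigns systematically higher entropy to OOD inputs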

def predict(method, model, x, n_samples=10):

    if method == Model.Dropout:
        preds = tf.stack([model(x, training=True) for _ in range(n_samples)], axis=0)  # forward passes with dropout active
        mu, var = tf.nn.moments(preds, axes=0)
        return mu, tf.sqrt(var)

    elif method == Model.Evidential:
        outputs = model(x, training=False)
        mu, v, alpha, beta = tf.split(outputs, 4, axis=-1)
        sigma = tf.sqrt(beta/(v*(alpha-1)))  # epistemic std: Var[mu] = beta / (v * (alpha - 1))
        return mu, sigma

    elif method == Model.Ensemble:
        # preds = tf.stack([f(x) for f in model], axis=0)
        # y, _ = tf.split(preds, 2, axis=-1)
        # mu = tf.reduce_mean(y, axis=0)
        # sigma = tf.math.reduce_std(y, axis=0)
        preds = tf.stack([f(x) for f in model], axis=0)
        mu, var = tf.nn.moments(preds, 0)
        return mu, tf.sqrt(var)

    else:
        raise ValueError("Unknown model")

def apply_cmap(gray, cmap=cv2.COLORMAP_MAGMA):
    if gray.dtype == np.float32:
        gray = np.clip(255*gray, 0, 255).astype(np.uint8)
    im_color = cv2.applyColorMap(gray, cmap)
    return im_color

def trim(img, k=10):
    return img[k:-k, k:-k]

def normalize(x, t_min=0, t_max=1):
    return ((x-x.min())/(x.max()-x.min())) * (t_max-t_min) + t_min


@tf.function
def create_adversarial_pattern(model, x, y):
    x_ = tf.convert_to_tensor(x)
    with tf.GradientTape() as tape:
        tape.watch(x_)
        if isinstance(model, list):
            preds = tf.stack([model_(x_, training=False) for model_ in model], axis=0)  # forward pass
            pred, _ = tf.nn.moments(preds, axes=0)
        else:
            pred = model(x_, training=True)
        if pred.shape[-1] == 4:
            pred = tf.split(pred, 4, axis=-1)[0]  # evidential head: keep only mu
        loss = edl.losses.MSE(y, pred)
    # Get the gradients of the loss w.r.t. the input image.
    gradient = tape.gradient(loss, x_)
    # Get the sign of the gradients to create the perturbation
    signed_grad = tf.sign(gradient)
    return signed_grad


if args.load_pkl:
    print("Loading!")
    df_image = pd.read_pickle("cached_depth_results.pkl")
else:
    df_image = compute_predictions()
    df_image.to_pickle("cached_depth_results.pkl")


""" ================================================== """
Path(output_dir).mkdir(parents=True, exist_ok=True)
gen_cutoff_plot(df_image)
gen_calibration_plot(df_image)
gen_adv_plots(df_image)
gen_ood_comparison(df_image)
""" ================================================== """
--------------------------------------------------------------------------------