├── lib ├── __init__.py ├── transform.py ├── utils.py ├── visualize_flow.py ├── dataloader.py └── toy_data.py ├── models ├── Conditionners │ ├── strADEConditioner.py │ ├── __init__.py │ ├── Conditioner.py │ ├── CouplingConditioner.py │ ├── AutoregressiveConditioner.py │ └── DAGConditioner.py ├── Normalizers │ ├── __init__.py │ ├── AffineNormalizer.py │ ├── Normalizer.py │ └── MonotonicNormalizer.py ├── __init__.py ├── MLP.py ├── NormalizingFlowFactories.py └── NormalizingFlow.py ├── requirements.txt ├── UCIdatasets ├── data │ └── Human_Protein_Network │ │ ├── X_train.pkt │ │ └── X_valid.pkt ├── __init__.py ├── bsds300.py ├── proteins.py ├── gas.py ├── miniboone.py ├── digits.py ├── power.py ├── hepmass.py └── download_dataset.py ├── scripts └── ImageExperiments4gpus.sh ├── NoStepsUCIExperimentsConfigurations.yml ├── README.md ├── UCIExperimentsConfigurations.yml ├── ToyExperiments.py ├── license ├── UCIExperiments.py ├── ImageExperimentsTest.py └── ImageExperiments.py /lib/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/Conditionners/strADEConditioner.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.5.0 2 | torchvision==0.6.0 3 | UMNN==1.0 4 | h5py 5 | networkx==2.4 6 | pandas 7 | PyYAML==5.3.1 8 | scikit-learn 9 | 10 | -------------------------------------------------------------------------------- /UCIdatasets/data/Human_Protein_Network/X_train.pkt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AWehenkel/Graphical-Normalizing-Flows/HEAD/UCIdatasets/data/Human_Protein_Network/X_train.pkt 
-------------------------------------------------------------------------------- /UCIdatasets/data/Human_Protein_Network/X_valid.pkt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AWehenkel/Graphical-Normalizing-Flows/HEAD/UCIdatasets/data/Human_Protein_Network/X_valid.pkt -------------------------------------------------------------------------------- /models/Normalizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .Normalizer import Normalizer 2 | from .AffineNormalizer import AffineNormalizer 3 | from .MonotonicNormalizer import MonotonicNormalizer -------------------------------------------------------------------------------- /models/Conditionners/__init__.py: -------------------------------------------------------------------------------- 1 | from .Conditioner import Conditioner 2 | from .AutoregressiveConditioner import AutoregressiveConditioner 3 | from .CouplingConditioner import CouplingConditioner 4 | from .DAGConditioner import DAGConditioner -------------------------------------------------------------------------------- /UCIdatasets/__init__.py: -------------------------------------------------------------------------------- 1 | root = 'UCIdatasets/data/' 2 | 3 | from .power import POWER 4 | from .gas import GAS 5 | from .hepmass import HEPMASS 6 | from .miniboone import MINIBOONE 7 | from .bsds300 import BSDS300 8 | from .digits import DIGITS 9 | from .proteins import PROTEINS, get_shd 10 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .MLP import MLP, MNISTCNN, CIFAR10CNN 2 | from .NormalizingFlowFactories import buildFCNormalizingFlow 3 | from .Conditionners import AutoregressiveConditioner, DAGConditioner, CouplingConditioner, Conditioner 4 | from .Normalizers import 
class AffineNormalizer(Normalizer):
    """Element-wise affine normalizer: ``z = x * sigma + mu``.

    The shift ``mu`` and the log-scale are read from the last axis of the
    embedding ``h`` (index 0 and index 1 respectively). The shift is clamped
    to [-5, 5] and the log-scale to [-5, 2] before exponentiation, so
    ``sigma`` always lies in ``[exp(-5), exp(2)]``.
    """

    def __init__(self):
        super(AffineNormalizer, self).__init__()

    @staticmethod
    def _affine_params(h):
        """Return ``(mu, sigma)`` decoded from the embedding ``h``.

        Uses the out-of-place ``clamp`` so the caller's ``h`` is left
        untouched; the in-place ``clamp_`` on a slice would silently
        overwrite the conditioner's output.
        """
        mu = h[:, :, 0].clamp(-5., 5.)
        sigma = torch.exp(h[:, :, 1].clamp(-5., 2.))
        return mu, sigma

    def forward(self, x, h, context=None):
        """Map ``x`` to ``z`` and return the per-element scale.

        :param x: tensor to transform; must broadcast against ``h[:, :, 0]``.
        :param h: embedding indexed as ``h[:, :, 0]`` (shift) and
            ``h[:, :, 1]`` (log-scale).
        :param context: unused by this normalizer.
        :return: ``(z, sigma)`` where ``sigma`` is dz/dx for each element.
        """
        mu, sigma = self._affine_params(h)
        z = x * sigma + mu
        return z, sigma

    def inverse_transform(self, z, h, context=None):
        """Invert :meth:`forward`: return ``x = (z - mu) / sigma``.

        :param z: tensor produced by :meth:`forward` for the same ``h``.
        :param h: embedding, read exactly as in :meth:`forward`.
        :param context: unused by this normalizer.
        """
        mu, sigma = self._affine_params(h)
        x = (z - mu) / sigma
        return x
class BSDS300:
    """
    A dataset of patches from BSDS300.

    After construction the instance exposes ``trn``, ``val`` and ``tst``
    (each a :class:`BSDS300.Data`), ``n_dims`` (second dimension of the
    training array) and ``image_size`` (a two-element list).
    """

    class Data:
        """
        Constructs the dataset.

        ``x`` holds ``data[:]`` and ``N`` is the size of its first axis.
        """

        def __init__(self, data):
            # Slice while the source is still readable: for an h5py dataset
            # this is expected to materialise the values in memory, so ``x``
            # stays usable after the file is closed.
            self.x = data[:]
            self.N = self.x.shape[0]

    def __init__(self):
        # load dataset; the context manager closes the HDF5 handle even if
        # one of the splits is missing or fails to load.
        with h5py.File(datasets.root + 'BSDS300/BSDS300.hdf5', 'r') as f:
            self.trn = self.Data(f['train'])
            self.val = self.Data(f['validation'])
            self.tst = self.Data(f['test'])

        self.n_dims = self.trn.x.shape[1]
        # NOTE(review): the "+ 1" presumably accounts for one pixel dropped
        # from each square patch during preprocessing -- confirm.
        self.image_size = [int(np.sqrt(self.n_dims + 1))] * 2
class CouplingConditioner(Conditioner):
    """Coupling-layer conditioner.

    The first ``indep_size`` variables receive a learned constant embedding
    (``constants``); the remaining ``cond_size`` variables receive an
    embedding computed by ``embeding_net`` from the first ``indep_size``
    variables and, when ``cond_in > 0``, from the context.
    """

    def __init__(self, in_size, hidden, out_size, cond_in=0):
        """
        :param in_size: number of input variables.
        :param hidden: list of hidden layer widths for the embedding MLP.
        :param out_size: embedding size produced for each variable.
        :param cond_in: number of context features fed to the embedding MLP.
        """
        super(CouplingConditioner, self).__init__()
        self.in_size = in_size
        self.out_size = out_size
        self.cond_in = cond_in
        self.cond_size = int(in_size/2)
        self.indep_size = in_size - self.cond_size
        self.embeding_net = CouplingMLP(in_size, hidden, out_size, cond_in)
        self.constants = nn.Parameter(torch.randn(self.indep_size, out_size))

    def forward(self, x, context=None):
        """
        :param x: A tensor [B, d]
        :param context: A tensor [B, c] or None
        :return: conditioning factors [B, d, out_size]
        """
        # The MLP's first layer expects indep_size + cond_in features, so the
        # context has to be appended AFTER selecting the independent
        # variables. Concatenating first and slicing afterwards drops the
        # context again and feeds the wrong width to the network.
        net_in = x[:, :self.indep_size]
        if context is not None and self.cond_in > 0:
            net_in = torch.cat((net_in, context), 1)
        h1 = self.constants.unsqueeze(0).expand(x.shape[0], -1, -1)
        h2 = self.embeding_net(net_in).view(x.shape[0], self.cond_size, self.out_size)
        return torch.cat((h1, h2), 1)

    def depth(self):
        """Return 1."""
        return 1
def load_data():
    """Load and standardise the Human Protein Network splits.

    The training and validation sets are read from ``X_train.pkt`` and
    ``X_valid.pkt``; both are standardised with the training mean and
    standard deviation (computed along axis 0).

    :return: ``(train, valid, test)``. No separate test file is read, so the
        standardised validation set is returned in the test slot as well.
    """
    # NOTE(review): the repository tree lists these files under
    # "UCIdatasets/data/Human_Protein_Network/" -- confirm this directory is
    # correct for the working directory the experiments are launched from.
    dir_f = "Datasets/Human_Protein_Network/"
    train = torch.load(dir_f + "X_train.pkt")
    mu, sigma = train.mean(0), train.std(0)
    # Validation data must come from its own file; loading X_train.pkt here
    # made validation metrics identical to training metrics.
    valid = torch.load(dir_f + "X_valid.pkt")
    return (train - mu)/sigma, (valid - mu)/sigma, (valid - mu)/sigma
class Crop(object):
    """Remove a border of ``num`` pixels from the first two axes of an image.

    The input is converted to a float32 NumPy array before cropping.
    ``num == 0`` returns the full (converted) image.
    """

    def __init__(self, num):
        self.num = num

    def __call__(self, samples):
        samples = np.array(samples, dtype=np.float32)
        rows, cols = samples.shape[:2]
        # Use explicit end indices: a negative stop of "-0" is just 0, which
        # would turn a zero-width crop into an empty array.
        return samples[self.num:rows - self.num, self.num:cols - self.num]
def get_correlation_numbers(data):
    """Count, for every column, how many columns correlate with it above 0.98.

    :param data: object exposing ``corr()`` (used here as a pandas DataFrame).
    :return: 1-D array with one count per column; each column's correlation
        with itself is part of the count.
    """
    strongly_correlated = data.corr() > 0.98
    return strongly_correlated.values.sum(axis=1)
13 | power-affine-autoregressive: 14 | dataset: 'power' 15 | b_size: 2500 16 | nb_epoch: 10000 17 | conditioner: 'Autoregressive' 18 | emb_net: [60, 60, 60, 2] 19 | normalizer: 'affine' 20 | 21 | 22 | gas-affine-coupling: 23 | dataset: 'gas' 24 | b_size: 10000 25 | nb_epoch: 10000 26 | conditioner: 'Coupling' 27 | emb_net: [80, 80, 80, 30] 28 | normalizer: 'affine' 29 | weight_decay: 1e-3 30 | 31 | 32 | 33 | gas-affine-autoregressive: 34 | dataset: 'gas' 35 | b_size: 10000 36 | nb_epoch: 10000 37 | conditioner: 'Autoregressive' 38 | emb_net: [80, 80, 80, 30] 39 | normalizer: 'affine' 40 | weight_decay: 1e-3 41 | 42 | hepmass-affine-coupling: 43 | dataset: 'hepmass' 44 | b_size: 100 45 | nb_epoch: 10000 46 | conditioner: 'Coupling' 47 | emb_net: [210, 210, 210, 30] 48 | normalizer: 'affine' 49 | weight_decay: 1e-4 50 | 51 | hepmass-affine-autoregressive: 52 | dataset: 'hepmass' 53 | b_size: 100 54 | nb_epoch: 10000 55 | conditioner: 'Autoregressive' 56 | emb_net: [210, 210, 210, 30] 57 | normalizer: 'affine' 58 | weight_decay: 1e-4 59 | 60 | miniboone-affine-coupling: 61 | dataset: 'miniboone' 62 | b_size: 100 63 | nb_epoch: 10000 64 | conditioner: 'Coupling' 65 | emb_net: [430, 430, 430, 30] 66 | normalizer: 'affine' 67 | weight_decay: 1e-2 68 | 69 | miniboone-affine-autoregressive: 70 | dataset: 'miniboone' 71 | b_size: 100 72 | nb_epoch: 10000 73 | conditioner: 'Autoregressive' 74 | emb_net: [430, 430, 430, 30] 75 | normalizer: 'affine' 76 | weight_decay: 1e-2 77 | 78 | bsds300-affine-coupling: 79 | dataset: 'bsds300' 80 | b_size: 100 81 | nb_epoch: 10000 82 | conditioner: 'Coupling' 83 | emb_net: [630, 630, 630, 30] 84 | normalizer: 'affine' 85 | weight_decay: 1e-4 86 | 87 | bsds300-affine-autoregressive: 88 | dataset: 'bsds300' 89 | b_size: 100 90 | nb_epoch: 10000 91 | conditioner: 'Autoregressive' 92 | emb_net: [630, 630, 630, 30] 93 | normalizer: 'affine' 94 | weight_decay: 1e-4 -------------------------------------------------------------------------------- 
class MINIBOONE:
    """MINIBOONE dataset wrapper exposing ``trn``, ``val``, ``tst`` and ``n_dims``."""

    class Data:
        """One split: ``x`` is the float32 data and ``N`` its number of rows."""

        def __init__(self, data):
            samples = data.astype(np.float32)
            self.x = samples
            self.N = samples.shape[0]

    def __init__(self):
        # The normalised loader returns the three splits in train/val/test order.
        splits = load_data_normalised(datasets.root + 'miniboone/data.npy')
        self.trn, self.val, self.tst = (self.Data(part) for part in splits)

        self.n_dims = self.trn.x.shape[1]
def load_data_normalised(root_path):
    """Load the MINIBOONE splits and standardise them.

    The mean and standard deviation are computed per feature over the
    training and validation rows stacked together, then applied to all three
    splits.

    :param root_path: path passed unchanged to ``load_data``.
    :return: ``(data_train, data_validate, data_test)`` after standardisation.
    """
    data_train, data_validate, data_test = load_data(root_path)
    data = np.vstack((data_train, data_validate))
    mu = data.mean(axis=0)
    s = data.std(axis=0)
    # A constant feature has zero standard deviation; dividing by it would
    # fill the column with NaN/inf. Leave such features centred but unscaled.
    s[s == 0.] = 1.
    data_train = (data_train - mu) / s
    data_validate = (data_validate - mu) / s
    data_test = (data_test - mu) / s

    return data_train, data_validate, data_test
def load_data_normalised():
    """Return the digits splits standardised with train+validation statistics.

    Per-feature mean and standard deviation are computed over the stacked
    training and validation rows; features with zero standard deviation are
    only centred (their divisor is set to 1).

    :return: ``(data_train, data_validate, data_test)`` after standardisation.
    """
    data_train, data_validate, data_test = load_data_split_with_noise()
    fit_rows = np.vstack((data_train, data_validate))
    mu, s = fit_rows.mean(axis=0), fit_rows.std(axis=0)
    s[s == 0.] = 1.

    return (
        (data_train - mu) / s,
        (data_validate - mu) / s,
        (data_test - mu) / s,
    )
4 | > [[arxiv]](https://arxiv.org/abs/2006.02548) 5 | # Dependencies 6 | The list of dependencies can be found in the requirements.txt file and installed with the following command: 7 | ```bash 8 | pip install -r requirements.txt 9 | ``` 10 | # Code architecture 11 | This repository provides code to build diverse types of normalizing flow models in PyTorch. The core components are located in the **models** folder. The different flow models are described in the file **NormalizingFlow.py** and they all follow the structure of the parent **class NormalizingFlow**. 12 | A flow step is usually designed as a combination of a **normalizer** (such as the ones described in the Normalizers sub-folder) with a **conditioner** (such as the ones described in the Conditioners sub-folder). Following the provided code hierarchy makes the implementation of new conditioners, normalizers or even complete flow architectures very easy. 13 | # Paper's experiments 14 | ## UCI Datasets 15 | You first have to download the datasets with the following command: 16 | ```bash 17 | python UCIdatasets/download_dataset.py 18 | ``` 19 | Then you can run the experiment of your choice with the following command: 20 | ```bash 21 | python UCIExperiments.py -load_config CONFIG_NAME 22 | ``` 23 | where *CONFIG_NAME* is the name of an experimental configuration defined in the *UCIExperimentsConfigurations.yml* file, e.g. *power-mono-DAG*. 24 | See also UCIExperiments.py for other optional arguments. 25 | ## MNIST 26 | ### Affine Normalizers 27 | ##### Graphical Conditioner 28 | ```bash 29 | python ImageExperiments.py -dataset MNIST -b_size 100 -normalizer Affine -conditioner DAG -nb_flow 1 -nb_steps_dual 10 -l1 0. 
class MLP(nn.Module):
    """Fully connected network: Linear layers separated by ``act_f``.

    No activation is applied after the last Linear layer.
    """

    def __init__(self, in_d, hidden, out_d, act_f=nn.ReLU()):
        """
        :param in_d: input width.
        :param hidden: list of hidden layer widths.
        :param out_d: output width.
        :param act_f: activation module placed between consecutive Linear layers.
        """
        super().__init__()
        self.in_d = in_d
        self.hiddens = hidden
        self.out_d = out_d
        self.act_f = act_f
        widths = [in_d] + hidden + [out_d]
        modules = []
        for position, (n_in, n_out) in enumerate(zip(widths, widths[1:])):
            # The activation goes between layers, never after the final one.
            if position > 0:
                modules.append(act_f)
            modules.append(nn.Linear(n_in, n_out))
        self.net = nn.Sequential(*modules)

    def forward(self, x, context=None):
        """Apply the network to ``x``; ``context`` is accepted but ignored."""
        return self.net(x)
class CIFAR10CNN(nn.Module):
    """Small CNN: two conv + max-pool stages followed by three Linear layers.

    The input is reshaped to ``[-1, *size_img]`` before the first
    convolution, and the output is returned as ``[batch, -1]``.
    """

    def __init__(self, out_d=10, fc_l=[400, 128, 84], size_img=[3, 32, 32], k_size=5):
        """
        :param out_d: width of the final Linear layer.
        :param fc_l: ``[flattened conv output size, fc1 width, fc2 width]``.
        :param size_img: ``[channels, height, width]`` used to reshape inputs.
        :param k_size: kernel size of both convolutions.
        """
        super(CIFAR10CNN, self).__init__()
        # Keep a private copy: storing the argument itself would make every
        # default-constructed instance share (and be able to corrupt) the one
        # default list object.
        size_img = list(size_img)
        self.conv1 = nn.Conv2d(size_img[0], 6, k_size)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, k_size)
        self.fc1 = nn.Linear(fc_l[0], fc_l[1])
        self.fc2 = nn.Linear(fc_l[1], fc_l[2])
        self.fc3 = nn.Linear(fc_l[2], out_d)

        self.out_d = out_d
        self.size_img = size_img

    def forward(self, x, context=None):
        """Return the network output for ``x``; ``context`` is ignored."""
        b_size = x.shape[0]
        x = self.pool(F.relu(self.conv1(x.view(-1, self.size_img[0], self.size_img[1], self.size_img[2]))))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(b_size, -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x).view(b_size, -1)
        return x
class POWER:
    """Holds the train / validation / test splits returned by ``load_data_normalised``."""

    class Data:
        """One split: ``x`` is the data cast to float32, ``N`` its number of rows."""

        def __init__(self, data):
            values = data.astype(np.float32)
            self.x = values
            self.N = values.shape[0]

    def __init__(self):
        splits = load_data_normalised()
        # Unpacking a generator into three targets requires exactly three splits.
        self.trn, self.val, self.tst = (self.Data(split) for split in splits)

        # Feature count is read off the training split.
        self.n_dims = self.trn.x.shape[1]
def load_data_normalised():
    """Return the (train, validation, test) splits standardised per column.

    The mean and standard deviation are computed over the training and
    validation rows stacked together, then applied to all three splits.
    """
    train, validate, test = load_data_split_with_noise()

    fit_rows = np.vstack((train, validate))
    mu, s = fit_rows.mean(axis=0), fit_rows.std(axis=0)

    return tuple((split - mu) / s for split in (train, validate, test))
def load_data(path):
    """Read ``1000_train.csv`` and ``1000_test.csv`` from ``path``.

    Returns:
        ``(data_train, data_test)`` as pandas DataFrames, read with
        ``index_col=False``.
    """
    file_names = ("1000_train.csv", "1000_test.csv")
    data_train, data_test = [
        pd.read_csv(filepath_or_buffer=join(path, file_name), index_col=False)
        for file_name in file_names
    ]
    return data_train, data_test
def load_data_no_discrete_normalised_as_array(path):
    """Load the normalised data as arrays, drop flagged features, and split off validation.

    Returns:
        ``(data_train, data_validate, data_test)`` as NumPy arrays. The
        validation split is the last 10% of the training rows.
    """
    frame_train, frame_test = load_data_no_discrete_normalised(path)
    data_train = frame_train.values
    data_test = frame_test.values

    # Remove any features that have too many re-occurring real values.
    features_to_remove = []
    for column, feature in enumerate(data_train.T):
        occurrences = Counter(feature)
        # NOTE(review): despite its name, this is the count of the *smallest*
        # value in the column (first item after sorting by value), not the
        # largest count. Kept as is because it decides which features are kept.
        max_count = sorted(occurrences.items())[0][1]
        if max_count > 5:
            features_to_remove.append(column)

    kept_train = np.array([c for c in range(data_train.shape[1]) if c not in features_to_remove])
    data_train = data_train[:, kept_train]
    kept_test = np.array([c for c in range(data_test.shape[1]) if c not in features_to_remove])
    data_test = data_test[:, kept_test]

    n_validate = int(data_train.shape[0] * 0.1)
    data_validate = data_train[-n_validate:]
    data_train = data_train[0:-n_validate]

    return data_train, data_validate, data_test
class IntegrandNet(nn.Module):
    """MLP applied independently to each (scalar input, embedding) pair, ending in ``ELUPlus``.

    Args:
        hidden: list of hidden layer sizes.
        cond_in: size of the embedding concatenated to each scalar input.
    """

    def __init__(self, hidden, cond_in):
        super().__init__()
        widths = [1 + cond_in] + hidden + [1]
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.append(nn.Linear(fan_in, fan_out))
            modules.append(nn.ReLU())
        # The output layer is followed by ELUPlus rather than ReLU.
        modules[-1] = ELUPlus()
        self.net = nn.Sequential(*modules)

    def forward(self, x, h):
        """Evaluate the network for every entry of ``x`` paired with its slice of ``h``."""
        nb_batch, in_d = x.shape
        stacked = torch.cat((x, h), 1)
        # Regroup to one row per (sample, dimension) before applying the MLP.
        per_dim = stacked.view(nb_batch, -1, in_d).transpose(1, 2).contiguous()
        flat_rows = per_dim.view(nb_batch * in_d, -1)
        return self.net(flat_rows).view(nb_batch, -1)
def makedirs(dirname):
    """Create ``dirname`` (and any missing parents) if it does not exist yet.

    Uses ``exist_ok=True`` instead of a separate existence check, so a
    directory created by another process between the check and the creation
    no longer makes the call fail.

    Raises:
        FileExistsError: if ``dirname`` exists but is not a directory.
    """
    os.makedirs(dirname, exist_ok=True)
def logsumexp(value, dim=None, keepdim=False):
    """Numerically stable implementation of the operation
    value.exp().sum(dim, keepdim).log()

    Args:
        value: input tensor.
        dim: dimension to reduce; when None, all elements are reduced to a
            0-dim tensor and ``keepdim`` is ignored.
        keepdim: whether the reduced dimension is kept (only used with ``dim``).

    Delegates to ``torch.logsumexp``. The previous hand-written version
    subtracted the maximum unconditionally, which produced NaN (``inf - inf``)
    whenever that maximum was infinite, e.g. for a slice of all ``-inf``.
    """
    if dim is not None:
        return torch.logsumexp(value, dim=dim, keepdim=bool(keepdim))
    # Global reduction: flatten first so a single dimension can be reduced.
    return torch.logsumexp(value.reshape(-1), dim=0)
def plt_potential_func(potential, ax, npts=100, title="$p(x)$"):
    """Plot ``exp(-U(z))`` on a square ``npts`` x ``npts`` grid over [LOW, HIGH].

    Args:
        potential: computes U(z_k) given z_k
        ax: matplotlib axes to draw on.
        npts: number of grid points per side.
        title: title set on ``ax``.
    """
    xside = np.linspace(LOW, HIGH, npts)
    yside = np.linspace(LOW, HIGH, npts)
    xx, yy = np.meshgrid(xside, yside)
    z = np.hstack([xx.reshape(-1, 1), yy.reshape(-1, 1)])

    z = torch.Tensor(z)
    # detach() so the conversion also works when the potential's output
    # is attached to an autograd graph.
    u = potential(z).detach().cpu().numpy()
    p = np.exp(-u).reshape(npts, npts)

    # Draw on the axes that were passed in rather than on pyplot's current
    # axes, which is only the same object when the caller just created it.
    ax.pcolormesh(xx, yy, p)
    ax.invert_yaxis()
    ax.get_xaxis().set_ticks([])
    ax.get_yaxis().set_ticks([])
    ax.set_title(title)
def plt_flow_samples(prior_sample, transform, ax, npts=200, memory=100, title="$x ~ q(x)$", device="cpu"):
    """Draw a 2D histogram of prior samples pushed through ``transform``.

    Args:
        prior_sample: called as ``prior_sample(npts * npts, 2)`` to get samples.
        transform: applied to each chunk of samples; its outputs are
            concatenated along dimension 0.
        ax: matplotlib axes to draw on.
        npts: number of histogram bins per side; ``npts**2`` samples are drawn.
        memory: samples are processed in chunks of ``memory**2``.
        title: title set on ``ax``.
        device: device the samples are moved to before the transform.
    """
    z = prior_sample(npts * npts, 2).type(torch.float32).to(device)
    zk = []
    inds = torch.arange(0, z.shape[0]).to(torch.int64)
    for ii in torch.split(inds, int(memory**2)):
        zk.append(transform(z[ii]))
    # detach() so the conversion works when the transform tracks gradients.
    zk = torch.cat(zk, 0).detach().cpu().numpy()
    ax.hist2d(zk[:, 0], zk[:, 1], range=[[LOW, HIGH], [LOW, HIGH]], bins=npts)
    # Symmetric ticks around the origin; the previous [2, 0, 2] repeated the
    # same tick twice and left the negative side unlabeled.
    ax.get_xaxis().set_ticks([-2, 0, 2])
    ax.get_yaxis().set_ticks([-2, 0, 2])
    ax.set_title(title)
def visualize_transform(
    potential_or_samples, prior_sample, prior_density, transform=None, inverse_transform=None, samples=True, npts=100,
    memory=100, device="cpu"
):
    """Produces visualization for the model density and samples from the model.

    Three side-by-side panels are drawn on the current figure:
      1. the target, as samples (``samples=True``) or as a potential function;
      2. the model density, via ``plt_flow`` when no ``inverse_transform`` is
         given, otherwise via ``plt_flow_density``;
      3. model samples, only when ``transform`` is provided.
    """
    plt.clf()
    ax = plt.subplot(1, 3, 1, aspect="equal")
    if samples:
        plt_samples(potential_or_samples, ax, npts=npts)
    else:
        plt_potential_func(potential_or_samples, ax, npts=npts)

    ax = plt.subplot(1, 3, 2, aspect="equal")
    if inverse_transform is None:
        # plt_flow takes (transform, ax, npts, title, device) and no prior
        # density. Passing prior_density first shifted every positional
        # argument and raised TypeError (multiple values for 'npts').
        plt_flow(transform, ax, npts=npts, device=device)
    else:
        plt_flow_density(prior_density, inverse_transform, ax, npts=npts, memory=memory, device=device)

    ax = plt.subplot(1, 3, 3, aspect="equal")
    if transform is not None:
        plt_flow_samples(prior_sample, transform, ax, npts=npts, memory=memory, device=device)
torch.utils.data.random_split(data, [90000, 10000]) 25 | 26 | test_data = datasets.CIFAR10('./CIFAR10', train=False, download=True, 27 | transform=transforms.Compose([ 28 | AddUniformNoise(0.05), 29 | Transpose(), 30 | ToTensor() 31 | ])) 32 | 33 | elif dataset == 'MNIST': 34 | data = datasets.MNIST('./MNIST', train=True, download=True, 35 | transform=transforms.Compose([ 36 | AddUniformNoise(), 37 | ToTensor() 38 | ])) 39 | 40 | 41 | train_data, valid_data = torch.utils.data.random_split(data, [50000, 10000]) 42 | 43 | test_data = datasets.MNIST('./MNIST', train=False, download=True, 44 | transform=transforms.Compose([ 45 | AddUniformNoise(), 46 | ToTensor() 47 | ])) 48 | 49 | 50 | 51 | elif len(dataset) == 6 and dataset[:5] == 'MNIST': 52 | data = datasets.MNIST('./MNIST', train=True, download=True, 53 | transform=transforms.Compose([ 54 | AddUniformNoise(), 55 | ToTensor() 56 | ])) 57 | label = int(dataset[5]) 58 | idx = data.train_labels == label 59 | data.train_labels = data.train_labels[idx] 60 | data.train_data = data.train_data[idx] 61 | 62 | train_data, valid_data = torch.utils.data.random_split(data, [5000, idx.sum() - 5000]) 63 | 64 | test_data = datasets.MNIST('./MNIST', train=False, download=True, 65 | transform=transforms.Compose([ 66 | AddUniformNoise(), 67 | ToTensor() 68 | ])) 69 | idx = test_data.test_labels == label 70 | test_data.test_labels = test_data.test_labels[idx] 71 | test_data.test_data = test_data.test_data[idx] 72 | elif dataset == 'MNIST32': 73 | data = datasets.MNIST('./MNIST', train=True, download=True, 74 | transform=transforms.Compose([ 75 | Resize(), 76 | AddUniformNoise(), 77 | ToTensor() 78 | ])) 79 | 80 | train_data, valid_data = torch.utils.data.random_split(data, [50000, 10000]) 81 | 82 | 83 | test_data = datasets.MNIST('./MNIST', train=False, download=True, 84 | transform=transforms.Compose([ 85 | Resize(), 86 | AddUniformNoise(), 87 | ToTensor() 88 | ])) 89 | elif len(dataset) == 8 and dataset[:7] == 'MNIST32': 90 | data = 
datasets.MNIST('./MNIST', train=True, download=True, 91 | transform=transforms.Compose([ 92 | Resize(), 93 | AddUniformNoise(), 94 | ToTensor() 95 | ])) 96 | 97 | label = int(dataset[7]) 98 | idx = data.train_labels == label 99 | data.train_labels = data.train_labels[idx] 100 | data.train_data = data.train_data[idx] 101 | 102 | train_data, valid_data = torch.utils.data.random_split(data, [5000, idx.sum() - 5000]) 103 | 104 | test_data = datasets.MNIST('./MNIST', train=False, download=True, 105 | transform=transforms.Compose([ 106 | Resize(), 107 | AddUniformNoise(), 108 | ToTensor() 109 | ])) 110 | idx = test_data.test_labels == label 111 | test_data.test_labels = test_data.test_labels[idx] 112 | test_data.test_data = test_data.test_data[idx] 113 | else: 114 | print ('what network ?', args.net) 115 | sys.exit(1) 116 | 117 | #load data 118 | kwargs = {'num_workers': 0, 'pin_memory': True} if cuda > -1 else {} 119 | 120 | train_loader = torch.utils.data.DataLoader( 121 | train_data, 122 | batch_size=batch_size, shuffle=True, **kwargs) 123 | 124 | valid_loader = torch.utils.data.DataLoader( 125 | valid_data, 126 | batch_size=batch_size, shuffle=True, **kwargs) 127 | 128 | test_loader = torch.utils.data.DataLoader(test_data, 129 | batch_size=batch_size, shuffle=True, **kwargs) 130 | 131 | return train_loader, valid_loader, test_loader 132 | -------------------------------------------------------------------------------- /models/NormalizingFlowFactories.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .Normalizers import * 4 | from .Conditionners import * 5 | from .NormalizingFlow import NormalizingFlowStep, FCNormalizingFlow, CNNormalizingFlow 6 | from math import pi 7 | from .MLP import MNISTCNN, CIFAR10CNN 8 | 9 | 10 | class NormalLogDensity(nn.Module): 11 | def __init__(self): 12 | super(NormalLogDensity, self).__init__() 13 | self.register_buffer("pi", torch.tensor(pi)) 14 | 15 
def MNIST_A_prior(in_size, kernel):
    """Build a neighbourhood adjacency prior for an ``in_size`` x ``in_size`` image.

    Returns an ``(in_size**2, in_size**2)`` float tensor whose entry is 1 when
    two distinct pixels lie within ``kernel`` steps of each other along both
    image axes, and 0 otherwise (the diagonal is always 0).
    """
    n_pix = in_size ** 2
    A = torch.zeros(n_pix, n_pix)
    flat_A = A.view(-1)  # shares storage with A

    grid = torch.arange(in_size)
    row_pix = grid.repeat(in_size).view(-1, 1)             # 0,1,..,0,1,..
    col_pix = grid.repeat_interleave(in_size).view(-1, 1)  # 0,0,..,1,1,..
    source_offset = (row_pix * in_size + col_pix) * n_pix

    for d_col in range(-kernel, kernel + 1):
        shifted_col = col_pix + d_col
        for d_row in range(-kernel, kernel + 1):
            shifted_row = row_pix + d_row
            inside = (shifted_col < in_size) * (shifted_col >= 0) * (shifted_row < in_size) * (shifted_row >= 0)
            # Out-of-image neighbours collapse to flat index 0, which lies on
            # the diagonal and is cleared below.
            targets = (source_offset + shifted_col + in_size * shifted_row) * inside
            flat_A[targets] = 1.

    diagonal = torch.arange(0, in_size ** 4, n_pix + 1)
    flat_A[diagonal] = 0
    return A
normalizer_type is MonotonicNormalizer: 88 | emb_s = 30 + 28*28 if hot_encoding else 30 89 | norm = normalizer_type(**normalizer_args, cond_size=emb_s) 90 | else: 91 | norm = normalizer_type(**normalizer_args) 92 | flow_step = NormalizingFlowStep(cond, norm) 93 | inner_steps.append(flow_step) 94 | flow = FCNormalizingFlow(inner_steps, NormalLogDensity()) 95 | return flow 96 | else: 97 | return None 98 | 99 | 100 | def buildCIFAR10NormalizingFlow(nb_inner_steps, normalizer_type, normalizer_args, l1=0., nb_epoch_update=5): 101 | if len(nb_inner_steps) == 4: 102 | img_sizes = [[3, 32, 32], [1, 32, 32], [1, 16, 16], [1, 8, 8]] 103 | dropping_factors = [[3, 1, 1], [1, 2, 2], [1, 2, 2]] 104 | fc_l = [[400, 128, 84], [576, 128, 32], [64, 32, 32], [16, 32, 32]] 105 | k_sizes = [5, 3, 3, 2] 106 | 107 | outter_steps = [] 108 | for i, fc in zip(range(len(fc_l)), fc_l): 109 | in_size = img_sizes[i][0] * img_sizes[i][1] * img_sizes[i][2] 110 | inner_steps = [] 111 | for step in range(nb_inner_steps[i]): 112 | emb_s = 2 if normalizer_type is AffineNormalizer else 30 113 | hidden = CIFAR10CNN(out_d=emb_s, fc_l=fc, size_img=img_sizes[i], k_size=k_sizes[i]) 114 | cond = DAGConditioner(in_size, hidden, emb_s, l1=l1, nb_epoch_update=nb_epoch_update) 115 | norm = normalizer_type(**normalizer_args) 116 | flow_step = NormalizingFlowStep(cond, norm) 117 | inner_steps.append(flow_step) 118 | flow = FCNormalizingFlow(inner_steps, None) 119 | flow.img_sizes = img_sizes[i] 120 | outter_steps.append(flow) 121 | 122 | return CNNormalizingFlow(outter_steps, NormalLogDensity(), dropping_factors) 123 | elif len(nb_inner_steps) == 1: 124 | inner_steps = [] 125 | for step in range(nb_inner_steps[0]): 126 | emb_s = 2 if normalizer_type is AffineNormalizer else 30 127 | hidden = CIFAR10CNN(fc_l=[400, 128, 84], size_img=[3, 32, 32], out_d=emb_s, k_size=5) 128 | cond = DAGConditioner(3*32*32, hidden, emb_s, l1=l1, nb_epoch_update=nb_epoch_update) 129 | norm = normalizer_type(**normalizer_args) 130 | 
class MaskedLinear(nn.Linear):
    """Linear layer whose weight matrix is multiplied element-wise by a fixed mask."""

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias)
        # All-ones mask: behaves like a plain Linear until set_mask is called.
        initial_mask = torch.ones(out_features, in_features)
        self.register_buffer('mask', initial_mask)

    def set_mask(self, mask):
        """Copy a NumPy mask of shape ``(in_features, out_features)`` into the buffer."""
        transposed = mask.astype(np.uint8).T
        self.mask.data.copy_(torch.from_numpy(transposed))

    def forward(self, input):
        """Apply the linear map using only the unmasked weights."""
        masked_weight = self.mask * self.weight
        return F.linear(input, masked_weight, self.bias)
37 | the output of running the tests for this file makes this a bit more clear with examples. 38 | num_masks: can be used to train ensemble over orderings/connections 39 | natural_ordering: force natural ordering of dimensions, don't use random permutations 40 | """ 41 | 42 | super().__init__() 43 | self.random = random 44 | self.nin = nin 45 | self.nout = nout 46 | self.hidden_sizes = hidden_sizes 47 | assert self.nout % self.nin == 0, "nout must be integer multiple of nin" 48 | 49 | # define a simple MLP neural net 50 | self.net = [] 51 | hs = [nin] + hidden_sizes + [nout] 52 | for h0, h1 in zip(hs, hs[1:]): 53 | self.net.extend([ 54 | MaskedLinear(h0, h1), 55 | nn.ReLU(), 56 | ]) 57 | self.net.pop() # pop the last ReLU for the output layer 58 | self.net = nn.Sequential(*self.net) 59 | 60 | # seeds for orders/connectivities of the model ensemble 61 | self.natural_ordering = natural_ordering 62 | self.num_masks = num_masks 63 | self.seed = 0 # for cycling through num_masks orderings 64 | 65 | self.m = {} 66 | self.update_masks() # builds the initial self.m connectivity 67 | # note, we could also precompute the masks and cache them, but this 68 | # could get memory expensive for large number of masks. 
69 | 70 | def update_masks(self): 71 | if self.m and self.num_masks == 1: return # only a single seed, skip for efficiency 72 | L = len(self.hidden_sizes) 73 | 74 | # fetch the next seed and construct a random stream 75 | rng = np.random.RandomState(self.seed) 76 | self.seed = (self.seed + 1) % self.num_masks 77 | 78 | # sample the order of the inputs and the connectivity of all neurons 79 | if self.random: 80 | self.m[-1] = np.arange(self.nin) if self.natural_ordering else rng.permutation(self.nin) 81 | for l in range(L): 82 | self.m[l] = rng.randint(self.m[l - 1].min(), self.nin - 1, size=self.hidden_sizes[l]) 83 | else: 84 | self.m[-1] = np.arange(self.nin) 85 | for l in range(L): 86 | self.m[l] = np.array([self.nin - 1 - (i % self.nin) for i in range(self.hidden_sizes[l])]) 87 | 88 | # construct the mask matrices 89 | masks = [self.m[l - 1][:, None] <= self.m[l][None, :] for l in range(L)] 90 | masks.append(self.m[L - 1][:, None] < self.m[-1][None, :]) 91 | 92 | # handle the case where nout = nin * k, for integer k > 1 93 | if self.nout > self.nin: 94 | k = int(self.nout / self.nin) 95 | # replicate the mask across the other outputs 96 | masks[-1] = np.concatenate([masks[-1]] * k, axis=1) 97 | 98 | # set the masks in all MaskedLinear layers 99 | layers = [l for l in self.net.modules() if isinstance(l, MaskedLinear)] 100 | for l, m in zip(layers, masks): 101 | l.set_mask(m) 102 | 103 | # map between in_d and order 104 | self.i_map = self.m[-1].copy() 105 | for k in range(len(self.m[-1])): 106 | self.i_map[self.m[-1][k]] = k 107 | 108 | def forward(self, x): 109 | return self.net(x).view(x.shape[0], -1, x.shape[1]).permute(0, 2, 1) 110 | 111 | 112 | # ------------------------------------------------------------------------------ 113 | 114 | 115 | class ConditionnalMADE(MADE): 116 | 117 | def __init__(self, nin, cond_in, hidden_sizes, nout, num_masks=1, natural_ordering=False, random=False, 118 | device="cpu"): 119 | """ 120 | nin: integer; number of inputs 121 | 
class AutoregressiveConditioner(Conditioner):
    """Conditioner backed by a (conditional) masked autoregressive network.

    Args:
        in_size: number of variables being modelled.
        hidden: hidden layer sizes of the masked network.
        out_size: embedding size produced for each variable.
        cond_in: number of context features prepended to the input (0 for none).
    """

    def __init__(self, in_size, hidden, out_size, cond_in=0):
        super(AutoregressiveConditioner, self).__init__()
        self.in_size = in_size
        # MADE reshapes its output to (batch, nout // n_inputs, n_inputs) with
        # n_inputs = in_size + cond_in, and ConditionnalMADE then drops the
        # first cond_in positions. nout therefore has to count the context
        # inputs too; with out_size * in_size each variable received fewer
        # than out_size values as soon as cond_in > 0. Unchanged for cond_in=0.
        self.masked_autoregressive_net = ConditionnalMADE(in_size, cond_in=cond_in, hidden_sizes=hidden,
                                                          nout=out_size * (in_size + cond_in))

    def forward(self, x, context=None):
        """Return the embeddings computed by the masked network for ``x`` and ``context``."""
        return self.masked_autoregressive_net(x, context)

    def depth(self):
        """Return ``in_size - 1``."""
        return self.in_size - 1
class NormalizingFlow(nn.Module):
    """Interface shared by the flows of this module.

    Every method here is a placeholder returning None; concrete flows override them.
    """

    def __init__(self):
        super(NormalizingFlow, self).__init__()

    def forward(self, x, context=None):
        """Should return the x transformed and the log determinant of the Jacobian of the transformation."""
        pass

    def constraintsLoss(self):
        """Should return a term relative to the loss."""
        pass

    def DAGness(self):
        """Should return the dagness of the associated graph."""
        pass

    def step(self, epoch_number, loss_avg):
        """Step in the optimization procedure."""
        pass

    def getConditioners(self):
        """Return a list containing the conditioners."""
        pass

    def isInvertible(self):
        """Return True if the architecture is invertible."""
        pass

    def getNormalizers(self):
        """Return a list containing the normalizers."""
        pass

    def invert(self, z, context=None):
        """Return the x that would generate z: [B, d] tensor."""
        pass


class NormalizingFlowStep(NormalizingFlow):
    """A single flow step: a conditioner producing embeddings and a normalizer transforming x with them."""

    # Annotations are quoted so the class does not need the names resolved at definition time.
    def __init__(self, conditioner: "Conditioner", normalizer: "Normalizer"):
        super(NormalizingFlowStep, self).__init__()
        self.conditioner = conditioner
        self.normalizer = normalizer

    def forward(self, x, context=None):
        """Return ``(z, log_det)`` where log_det sums the log of the normalizer's jacobian over dim 1."""
        h = self.conditioner(x, context)
        z, jac = self.normalizer(x, h, context)
        return z, torch.log(jac).sum(1)

    def constraintsLoss(self):
        """Return the conditioner's loss for a DAG conditioner, 0. otherwise."""
        if isinstance(self.conditioner, DAGConditioner):
            return self.conditioner.loss()
        return 0.

    def DAGness(self):
        """Return a one-element list: the DAG conditioner's power trace, or 0. for other conditioners."""
        if isinstance(self.conditioner, DAGConditioner):
            return [self.conditioner.get_power_trace()]
        return [0.]

    def step(self, epoch_number, loss_avg):
        """Forward the optimization step to the conditioner when it is a DAG conditioner."""
        if isinstance(self.conditioner, DAGConditioner):
            self.conditioner.step(epoch_number, loss_avg)

    def getConditioners(self):
        return [self.conditioner]

    def getNormalizers(self):
        return [self.normalizer]

    def isInvertible(self):
        """Return True when every conditioner reports ``is_invertible``."""
        return all(conditioner.is_invertible for conditioner in self.getConditioners())

    def invert(self, z, context=None):
        """Invert the step by fixed-point iteration starting from zeros.

        At most ``conditioner.depth() + 1`` iterations are run; the loop stops early as soon as an
        iteration leaves x unchanged.
        """
        x = torch.zeros_like(z)
        for _ in range(self.conditioner.depth() + 1):
            h = self.conditioner(x, context)
            x_prev = x
            x = self.normalizer.inverse_transform(z, h, context)
            if torch.norm(x - x_prev) == 0.:
                break
        return x


class FCNormalizingFlow(NormalizingFlow):
    """Composition of flow steps; the variable order is reversed between two consecutive steps."""

    def __init__(self, steps, z_log_density):
        """
        steps: iterable of flow steps (nn.Module instances), applied in order by forward
        z_log_density: callable returning the log-density of the latent z, used by loss
        """
        super(FCNormalizingFlow, self).__init__()
        self.steps = nn.ModuleList(steps)
        self.z_log_density = z_log_density

    def forward(self, x, context=None):
        """Return ``(z, log_det)``: the output of the last step and the accumulated log-determinants."""
        jac_tot = 0.
        z = x  # a flow without steps acts as the identity
        inv_idx = torch.arange(x.shape[1] - 1, -1, -1).long()
        for step in self.steps:
            z, jac = step(x, context)
            # The next step sees the variables in reversed order; the returned z is NOT reversed.
            x = z[:, inv_idx]
            jac_tot += jac

        return z, jac_tot

    def constraintsLoss(self):
        """Return the sum of the constraint losses of all steps."""
        loss = 0.
        for step in self.steps:
            loss += step.constraintsLoss()
        return loss

    def DAGness(self):
        """Return the concatenation of the DAGness lists of all steps."""
        dagness = []
        for step in self.steps:
            dagness += step.DAGness()
        return dagness

    def step(self, epoch_number, loss_avg):
        for step in self.steps:
            step.step(epoch_number, loss_avg)

    def loss(self, z, jac):
        """Return the constraint loss minus the mean log-likelihood ``jac + z_log_density(z)``."""
        log_p_x = jac + self.z_log_density(z)
        return self.constraintsLoss() - log_p_x.mean()

    def getNormalizers(self):
        normalizers = []
        for step in self.steps:
            normalizers += step.getNormalizers()
        return normalizers

    def getConditioners(self):
        conditioners = []
        for step in self.steps:
            conditioners += step.getConditioners()
        return conditioners

    def isInvertible(self):
        """Return True when every conditioner of every step reports ``is_invertible``."""
        return all(conditioner.is_invertible for conditioner in self.getConditioners())

    def invert(self, z, context=None):
        """Return the x that forward maps to z.

        Steps are undone from the last one to the first one, and the reversal of the variable order that
        forward applies between two consecutive steps is undone before inverting each earlier step.
        (The former implementation indexed ``self.steps[-step]`` from 0, i.e. it started with the FIRST
        step, and it never undid the reversal, so it was only correct for single-step flows.)
        """
        inv_idx = torch.arange(z.shape[1] - 1, -1, -1).long()
        last = len(self.steps) - 1
        for idx in range(last, -1, -1):
            if idx != last:
                # forward fed z_k[:, inv_idx] to step k+1; a reversal is its own inverse.
                z = z[:, inv_idx]
            z = self.steps[idx].invert(z, context)
        return z


class CNNormalizingFlow(FCNormalizingFlow):
    """Multi-scale flow: after each step only the first element of every patch is passed on."""

    def __init__(self, steps, z_log_density, dropping_factors):
        """
        dropping_factors: one ``(d_c, d_h, d_w)`` triple per step giving the patch size along the
            channel, height and width axes. Each step is expected to expose ``img_sizes = (C, H, W)``.
        """
        super(CNNormalizingFlow, self).__init__(steps, z_log_density)
        self.dropping_factors = dropping_factors

    def forward(self, x, context=None):
        """Return ``(z, log_det)`` where z concatenates the dropped elements of each step and the final x."""
        b_size = x.shape[0]
        jac_tot = 0.
        z_all = []
        for step, drop_factors in zip(self.steps, self.dropping_factors):
            z, jac = step(x, context)
            d_c, d_h, d_w = drop_factors
            C, H, W = step.img_sizes
            c, h, w = int(C/d_c), int(H/d_h), int(W/d_w)
            # Cut the image into non-overlapping (d_c, d_h, d_w) patches, flattened on the last axis.
            z_reshaped = z.view(-1, C, H, W).unfold(1, d_c, d_c).unfold(2, d_h, d_h) \
                .unfold(3, d_w, d_w).contiguous().view(b_size, c, h, w, -1)
            # Everything but the first element of each patch leaves the flow here ...
            z_all += [z_reshaped[:, :, :, :, 1:].contiguous().view(b_size, -1)]
            # ... and the first element of each patch is the input of the next step.
            x = z_reshaped[:, :, :, :, 0].contiguous().view(b_size, -1)
            jac_tot += jac
        z_all += [x]
        z = torch.cat(z_all, 1)
        return z, jac_tot

    def invert(self, z, context=None):
        """Split z into per-step chunks and invert the steps from the last one to the first one."""
        b_size = z.shape[0]
        z_all = []
        offset = 0
        for step, drop_factors in zip(self.steps, self.dropping_factors):
            d_c, d_h, d_w = drop_factors
            C, H, W = step.img_sizes
            c, h, w = int(C / d_c), int(H / d_h), int(W / d_w)
            # A step that drops nothing owns a full image worth of latents.
            nb_z = C*H*W - c*h*w if C*H*W != c*h*w else c*h*w
            z_all += [z[:, offset:offset+nb_z]]
            offset += nb_z

        # NOTE(review): x only becomes a tensor after a step that drops nothing has been inverted, so the
        # last step presumably has dropping factors (1, 1, 1) - confirm against the factories.
        x = 0.
        for i in range(1, len(self.steps) + 1):
            step = self.steps[-i]
            drop_factors = self.dropping_factors[-i]
            d_c, d_h, d_w = drop_factors
            C, H, W = step.img_sizes
            c, h, w = int(C / d_c), int(H / d_h), int(W / d_w)
            z = z_all[-i]
            if c*h*w != C*H*W:
                # Put the values recovered so far back as the first element of each patch, then fold
                # the patches back into a (C, H, W) image.
                z = z.view(b_size, c, h, w, -1)
                x = x.view(b_size, c, h, w, 1)
                z = torch.cat((x, z), 4)
                z = z.view(b_size, c, h, w, d_c, d_h, d_w)
                z = z.permute(0, 1, 2, 3, 6, 4, 5).contiguous().view(b_size, c, h, W, d_c, d_h)
                z = z.permute(0, 1, 2, 5, 3, 4).contiguous().view(b_size, c, H, W, d_c)
                z = z.permute(0, 1, 4, 2, 3).contiguous().view(b_size, C, H, W)
            x = step.invert(z.view(b_size, -1), context)
        return x
25 | gumble_T: .5 26 | normalizer: 'affine' 27 | 28 | power-mono-coupling: 29 | dataset: 'power' 30 | nb_flow: 1 31 | b_size: 2500 32 | nb_epoch: 10000 33 | conditioner: 'Coupling' 34 | emb_net: [60, 60, 60, 30] 35 | normalizer: 'monotonic' 36 | int_net: [100, 100, 100] 37 | nb_steps: 20 38 | solver: 'CC' 39 | 40 | power-affine-coupling: 41 | dataset: 'power' 42 | nb_flow: 1 43 | b_size: 2500 44 | nb_epoch: 10000 45 | conditioner: 'Coupling' 46 | emb_net: [60, 60, 60, 2] 47 | normalizer: 'affine' 48 | 49 | power-mono-autoregressive: 50 | dataset: 'power' 51 | nb_flow: 1 52 | b_size: 2500 53 | nb_epoch: 10000 54 | conditioner: 'Autoregressive' 55 | emb_net: [60, 60, 60, 30] 56 | normalizer: 'monotonic' 57 | int_net: [100, 100, 100] 58 | nb_steps: 20 59 | solver: 'CC' 60 | 61 | power-affine-autoregressive: 62 | dataset: 'power' 63 | nb_flow: 1 64 | b_size: 2500 65 | nb_epoch: 10000 66 | conditioner: 'Autoregressive' 67 | emb_net: [60, 60, 60, 2] 68 | normalizer: 'affine' 69 | 70 | gas-mono-DAG: 71 | dataset: 'gas' 72 | nb_flow: 1 73 | b_size: 10000 74 | nb_epoch: 10000 75 | conditioner: 'DAG' 76 | emb_net: [80, 80, 80, 30] 77 | nb_steps_dual: 100 78 | l1: 0. 79 | gumble_T: .5 80 | normalizer: 'monotonic' 81 | int_net: [200, 200, 200] 82 | nb_steps: 20 83 | solver: 'CC' 84 | weight_decay: 1e-3 85 | 86 | gas-affine-DAG: 87 | dataset: 'gas' 88 | nb_flow: 1 89 | b_size: 10000 90 | nb_epoch: 10000 91 | conditioner: 'DAG' 92 | emb_net: [80, 80, 80, 30] 93 | nb_steps_dual: 100 94 | l1: 0. 
95 | gumble_T: .5 96 | normalizer: 'affine' 97 | weight_decay: 1e-3 98 | 99 | gas-mono-coupling: 100 | dataset: 'gas' 101 | nb_flow: 1 102 | b_size: 10000 103 | nb_epoch: 10000 104 | conditioner: 'Coupling' 105 | emb_net: [80, 80, 80, 30] 106 | normalizer: 'monotonic' 107 | int_net: [200, 200, 200] 108 | nb_steps: 20 109 | solver: 'CC' 110 | weight_decay: 1e-3 111 | 112 | gas-affine-coupling: 113 | dataset: 'gas' 114 | nb_flow: 1 115 | b_size: 10000 116 | nb_epoch: 10000 117 | conditioner: 'Coupling' 118 | emb_net: [80, 80, 80, 30] 119 | normalizer: 'affine' 120 | weight_decay: 1e-3 121 | 122 | gas-mono-autoregressive: 123 | dataset: 'gas' 124 | nb_flow: 1 125 | b_size: 10000 126 | nb_epoch: 10000 127 | conditioner: 'Autoregressive' 128 | emb_net: [80, 80, 80, 30] 129 | normalizer: 'monotonic' 130 | int_net: [200, 200, 200] 131 | nb_steps: 20 132 | solver: 'CC' 133 | weight_decay: 1e-3 134 | 135 | gas-affine-autoregressive: 136 | dataset: 'gas' 137 | nb_flow: 1 138 | b_size: 10000 139 | nb_epoch: 10000 140 | conditioner: 'Autoregressive' 141 | emb_net: [80, 80, 80, 30] 142 | normalizer: 'affine' 143 | weight_decay: 1e-3 144 | 145 | hepmass-mono-DAG: 146 | dataset: 'hepmass' 147 | nb_flow: 1 148 | b_size: 100 149 | nb_epoch: 10000 150 | conditioner: 'DAG' 151 | emb_net: [210, 210, 210, 30] 152 | nb_steps_dual: 25 153 | l1: 0. 154 | gumble_T: .5 155 | normalizer: 'monotonic' 156 | int_net: [200, 200, 200] 157 | nb_steps: 20 158 | solver: 'CCParallel' 159 | weight_decay: 1e-4 160 | 161 | hepmass-affine-DAG: 162 | dataset: 'hepmass' 163 | nb_flow: 1 164 | b_size: 100 165 | nb_epoch: 10000 166 | conditioner: 'DAG' 167 | emb_net: [210, 210, 210, 30] 168 | nb_steps_dual: 25 169 | l1: 0. 
170 | gumble_T: .5 171 | normalizer: 'affine' 172 | weight_decay: 1e-4 173 | 174 | hepmass-mono-coupling: 175 | dataset: 'hepmass' 176 | nb_flow: 1 177 | b_size: 100 178 | nb_epoch: 10000 179 | conditioner: 'Coupling' 180 | emb_net: [210, 210, 210, 30] 181 | normalizer: 'monotonic' 182 | int_net: [200, 200, 200] 183 | nb_steps: 20 184 | solver: 'CCParallel' 185 | weight_decay: 1e-4 186 | 187 | hepmass-affine-coupling: 188 | dataset: 'hepmass' 189 | nb_flow: 1 190 | b_size: 100 191 | nb_epoch: 10000 192 | conditioner: 'Coupling' 193 | emb_net: [210, 210, 210, 30] 194 | normalizer: 'affine' 195 | weight_decay: 1e-4 196 | 197 | hepmass-mono-autoregressive: 198 | dataset: 'hepmass' 199 | nb_flow: 1 200 | b_size: 100 201 | nb_epoch: 10000 202 | conditioner: 'Autoregressive' 203 | emb_net: [210, 210, 210, 30] 204 | normalizer: 'monotonic' 205 | int_net: [200, 200, 200] 206 | nb_steps: 20 207 | solver: 'CCParallel' 208 | weight_decay: 1e-4 209 | 210 | hepmass-affine-autoregressive: 211 | dataset: 'hepmass' 212 | nb_flow: 1 213 | b_size: 100 214 | nb_epoch: 10000 215 | conditioner: 'Autoregressive' 216 | emb_net: [210, 210, 210, 30] 217 | normalizer: 'affine' 218 | weight_decay: 1e-4 219 | 220 | miniboone-mono-DAG: 221 | dataset: 'miniboone' 222 | nb_flow: 1 223 | b_size: 100 224 | nb_epoch: 10000 225 | conditioner: 'DAG' 226 | emb_net: [430, 430, 430, 30] 227 | nb_steps_dual: 200 228 | l1: 0. 229 | gumble_T: .5 230 | normalizer: 'monotonic' 231 | int_net: [40, 40, 40] 232 | nb_steps: 20 233 | solver: 'CCParallel' 234 | weight_decay: 1e-2 235 | 236 | miniboone-affine-DAG: 237 | dataset: 'miniboone' 238 | nb_flow: 1 239 | b_size: 100 240 | nb_epoch: 10000 241 | conditioner: 'DAG' 242 | emb_net: [430, 430, 430, 30] 243 | nb_steps_dual: 200 244 | l1: 0. 
245 | gumble_T: .5 246 | normalizer: 'affine' 247 | weight_decay: 1e-2 248 | 249 | miniboone-mono-coupling: 250 | dataset: 'miniboone' 251 | nb_flow: 1 252 | b_size: 100 253 | nb_epoch: 10000 254 | conditioner: 'Coupling' 255 | emb_net: [430, 430, 430, 30] 256 | normalizer: 'monotonic' 257 | int_net: [40, 40, 40] 258 | nb_steps: 20 259 | solver: 'CCParallel' 260 | weight_decay: 1e-2 261 | 262 | miniboone-affine-coupling: 263 | dataset: 'miniboone' 264 | nb_flow: 1 265 | b_size: 100 266 | nb_epoch: 10000 267 | conditioner: 'Coupling' 268 | emb_net: [430, 430, 430, 30] 269 | normalizer: 'affine' 270 | weight_decay: 1e-2 271 | 272 | miniboone-mono-autoregressive: 273 | dataset: 'miniboone' 274 | nb_flow: 1 275 | b_size: 100 276 | nb_epoch: 10000 277 | conditioner: 'Autoregressive' 278 | emb_net: [430, 430, 430, 30] 279 | normalizer: 'monotonic' 280 | int_net: [40, 40, 40] 281 | nb_steps: 20 282 | solver: 'CCParallel' 283 | weight_decay: 1e-2 284 | 285 | miniboone-affine-autoregressive: 286 | dataset: 'miniboone' 287 | nb_flow: 1 288 | b_size: 100 289 | nb_epoch: 10000 290 | conditioner: 'Autoregressive' 291 | emb_net: [430, 430, 430, 30] 292 | normalizer: 'affine' 293 | weight_decay: 1e-2 294 | 295 | bsds300-mono-DAG: 296 | dataset: 'bsds300' 297 | nb_flow: 1 298 | b_size: 100 299 | nb_epoch: 10000 300 | conditioner: 'DAG' 301 | emb_net: [630, 630, 630, 30] 302 | nb_steps_dual: 20 303 | l1: 0. 304 | gumble_T: .5 305 | normalizer: 'monotonic' 306 | int_net: [150, 150, 150] 307 | nb_steps: 20 308 | solver: 'CCParallel' 309 | weight_decay: 1e-4 310 | 311 | bsds300-affine-DAG: 312 | dataset: 'bsds300' 313 | nb_flow: 1 314 | b_size: 100 315 | nb_epoch: 10000 316 | conditioner: 'DAG' 317 | emb_net: [630, 630, 630, 30] 318 | nb_steps_dual: 20 319 | l1: 0. 
320 | gumble_T: .5 321 | normalizer: 'affine' 322 | weight_decay: 1e-4 323 | 324 | bsds300-mono-coupling: 325 | dataset: 'bsds300' 326 | nb_flow: 1 327 | b_size: 100 328 | nb_epoch: 10000 329 | conditioner: 'Coupling' 330 | emb_net: [630, 630, 630, 30] 331 | normalizer: 'monotonic' 332 | int_net: [150, 150, 150] 333 | nb_steps: 20 334 | solver: 'CCParallel' 335 | weight_decay: 1e-4 336 | 337 | bsds300-affine-coupling: 338 | dataset: 'bsds300' 339 | nb_flow: 1 340 | b_size: 100 341 | nb_epoch: 10000 342 | conditioner: 'Coupling' 343 | emb_net: [630, 630, 630, 30] 344 | normalizer: 'affine' 345 | weight_decay: 1e-4 346 | 347 | bsds300-mono-autoregressive: 348 | dataset: 'bsds300' 349 | nb_flow: 1 350 | b_size: 100 351 | nb_epoch: 10000 352 | conditioner: 'Autoregressive' 353 | emb_net: [630, 630, 630, 30] 354 | normalizer: 'monotonic' 355 | int_net: [150, 150, 150] 356 | nb_steps: 20 357 | solver: 'CCParallel' 358 | weight_decay: 1e-4 359 | 360 | bsds300-affine-autoregressive: 361 | dataset: 'bsds300' 362 | nb_flow: 1 363 | b_size: 100 364 | nb_epoch: 10000 365 | conditioner: 'Autoregressive' 366 | emb_net: [630, 630, 630, 30] 367 | normalizer: 'affine' 368 | weight_decay: 1e-4 -------------------------------------------------------------------------------- /ToyExperiments.py: -------------------------------------------------------------------------------- 1 | import lib.toy_data as toy_data 2 | from models import * 3 | import torch 4 | from timeit import default_timer as timer 5 | import lib.utils as utils 6 | import os 7 | import lib.visualize_flow as vf 8 | import matplotlib.pyplot as plt 9 | import networkx as nx 10 | import numpy as np 11 | import math 12 | import matplotlib 13 | import seaborn as sns 14 | sns.set() 15 | from matplotlib import gridspec 16 | flatui = ["#9b59b6", "#3498db", "#95a5a6", "#e74c3c", "#34495e", "#2ecc71"] 17 | sns.palplot(sns.color_palette(flatui)) 18 | 19 | cond_types = {"DAG": DAGConditioner, "Coupling": CouplingConditioner, 
def train_toy(toy, load=True, nb_step_dual=300, nb_steps=15, folder="", l1=1., nb_epoch=20000, pre_heating_epochs=10,
              nb_flow=3, cond_type = "Coupling", emb_net = [150, 150, 150]):
    """Train an affine normalizing flow on a 2D toy dataset, logging, plotting and checkpointing on the way.

    toy: name of the toy dataset, passed to toy_data.inf_train_gen
    load: if True, restore the model and optimizer states saved under ``folder + toy`` before training
    nb_step_dual: value of 'nb_epoch_update' given to a DAG conditioner
    nb_steps: value of 'nb_steps' given to a monotonic normalizer (unused while norm_type is "Affine")
    folder: prefix of the output directory; logs, figures and checkpoints go to ``folder + toy``
    l1: value of 'l1' given to a DAG conditioner
    nb_epoch: number of epochs
    pre_heating_epochs: unused
    nb_flow: number of flow steps
    cond_type: key of cond_types selecting the conditioner
    emb_net: layer sizes of the conditioner; the last entry is its output size (never mutated here)
    """
    logger = utils.get_logger(logpath=os.path.join(folder, toy, 'logs'), filepath=os.path.abspath(__file__))

    logger.info("Creating model...")

    device = "cpu" if not(torch.cuda.is_available()) else "cuda:0"

    nb_samp = 100
    batch_size = 100

    x_test = torch.tensor(toy_data.inf_train_gen(toy, batch_size=1000)).to(device)
    x = torch.tensor(toy_data.inf_train_gen(toy, batch_size=1000)).to(device)

    dim = x.shape[1]

    norm_type = "Affine"
    save_name = norm_type + str(emb_net) + str(nb_flow)
    solver = "CCParallel"
    int_net = [150, 150, 150]

    conditioner_type = cond_types[cond_type]
    conditioner_args = {"in_size": dim, "hidden": emb_net[:-1], "out_size": emb_net[-1]}
    if conditioner_type is DAGConditioner:
        conditioner_args['l1'] = l1
        conditioner_args['gumble_T'] = .5
        conditioner_args['nb_epoch_update'] = nb_step_dual
        conditioner_args["hot_encoding"] = True
    normalizer_type = norm_types[norm_type]
    if normalizer_type is MonotonicNormalizer:
        normalizer_args = {"integrand_net": int_net, "cond_size": emb_net[-1], "nb_steps": nb_steps,
                           "solver": solver}
    else:
        normalizer_args = {}

    model = buildFCNormalizingFlow(nb_flow, conditioner_type, conditioner_args, normalizer_type, normalizer_args)

    opt = torch.optim.Adam(model.parameters(), 1e-4, weight_decay=1e-5)

    if load:
        logger.info("Loading model...")
        model.load_state_dict(torch.load(folder + toy + '/' + save_name + 'model.pt'))
        model.train()
        opt.load_state_dict(torch.load(folder + toy + '/' + save_name + 'ADAM.pt'))
        logger.info("Model loaded.")

    def compute_ll(x):
        """Return the per-sample log-likelihood of x under the model and the matching latents."""
        z, jac = model(x)
        ll = (model.z_log_density(z) + jac)
        return ll, z

    for step in model.steps:
        step.conditioner.stoch_gate = True
        step.conditioner.noise_gate = False
        step.conditioner.gumble_T = .5
    torch.autograd.set_detect_anomaly(True)
    for epoch in range(nb_epoch):
        loss_tot = 0
        start = timer()
        for j in range(0, nb_samp, batch_size):
            cur_x = torch.tensor(toy_data.inf_train_gen(toy, batch_size=batch_size)).to(device)
            z, jac = model(cur_x)
            loss = model.loss(z, jac)
            loss_tot += loss.detach()
            if math.isnan(loss.item()):
                # Dump the offending batch before giving up. The model object has no compute_ll method,
                # so the local helper is used.
                ll, z = compute_ll(cur_x)
                print(ll)
                print(z)
                print(ll.max(), z.max())
                exit()
            opt.zero_grad()
            loss.backward(retain_graph=True)
            opt.step()
        model.step(epoch, loss_tot)

        end = timer()
        ll, _ = compute_ll(x_test)
        ll_test = -ll.mean()
        dagness = max(model.DAGness())
        logger.info("epoch: {:d} - Train loss: {:4f} - Test loss: {:4f} - <>: {:4f} - Elapsed time per epoch {:4f} (seconds)".
                    format(epoch, loss_tot.item(), ll_test.item(), dagness, end-start))

        if epoch % 500 == 0:
            font = {'family': 'normal',
                    'weight': 'normal',
                    'size': 25}

            matplotlib.rc('font', **font)
            with torch.no_grad():
                npts = 100
                plt.figure(figsize=(12, 12))
                gs = gridspec.GridSpec(2, 2, width_ratios=[3, 1], height_ratios=[3, 1])
                ax = plt.subplot(gs[0])
                qz_1, qz_2 = vf.plt_flow(compute_ll, ax, npts=npts, device=device)
                plt.subplot(gs[1])
                plt.plot(qz_1, np.linspace(-4, 4, npts))
                plt.ylabel('$x_2$', fontsize=25, rotation=-90, labelpad=20)

                plt.xticks([])
                plt.subplot(gs[2])
                plt.plot(np.linspace(-4, 4, npts), qz_2)
                plt.xlabel('$x_1$', fontsize=25)
                plt.yticks([])
                plt.savefig("%s%s/flow_%s_%d.pdf" % (folder, toy, save_name, epoch))
                # Release the figure: a new one is created every 500 epochs.
                plt.close()
            torch.save(model.state_dict(), folder + toy + '/' + save_name + 'model.pt')
            torch.save(opt.state_dict(), folder + toy + '/'+ save_name + 'ADAM.pt')
# Unit-circle centers of the eight-component mixture (scaled inside the sampler).
_EIGHT_CENTERS = [(1, 0), (-1, 0), (0, 1), (0, -1), (1. / np.sqrt(2), 1. / np.sqrt(2)),
                  (1. / np.sqrt(2), -1. / np.sqrt(2)), (-1. / np.sqrt(2),
                                                         1. / np.sqrt(2)), (-1. / np.sqrt(2), -1. / np.sqrt(2))]

# Per-column standard deviations used to normalise "8-MIX"; "7-MIX" uses the first 14 entries.
_MIX_STD = np.array([1.604934 , 1.584863 , 2.0310535, 2.0305095, 1.337718 , 1.4043778, 1.6944685, 1.6935346,
                     1.7434783, 1.0092416, 1.4860426, 1.485661 , 2.3067558, 2.311637 , 1.4430547, 1.4430547],
                    dtype=np.float32)


def _sample_gaussian_mixture(rng, centers, std, batch_size, scale=4.):
    """Draw batch_size 2D points from an equally weighted mixture of isotropic Gaussians.

    Returns the float32 points and the list of the component index chosen for each point. For every
    sample the noise is drawn first and the component second, so a seeded rng yields the same stream
    as the former per-dataset loops.
    """
    scaled_centers = [(scale * cx, scale * cy) for cx, cy in centers]
    points, components = [], []
    for _ in range(batch_size):
        point = rng.randn(2) * std
        idx = rng.randint(len(scaled_centers))
        center = scaled_centers[idx]
        point[0] += center[0]
        point[1] += center[1]
        points.append(point)
        components.append(idx)
    return np.array(points, dtype="float32"), components


# Dataset iterator
def inf_train_gen(data, rng=None, batch_size=200):
    """Sample one batch of the toy dataset called ``data``.

    data: dataset name; an unknown name falls back to "8gaussians"
    rng: numpy RandomState used for the sampling; a fresh unseeded one is created when None.
        "joint_gaussian" samples with torch and does not use it.
    batch_size: requested number of samples. "pinwheel" returns 5 * (batch_size // 5) rows and
        "2spirals" returns 2 * (batch_size // 2) rows.

    Returns a float32 numpy array with 2 columns for the base datasets and the column-wise
    concatenation of their parts for the composite ones ("*-2spirals-8gaussians", "8-MIX", "7-MIX").
    "conditionnal8gaussians" returns ``(points, context)`` where context one-hot encodes the mixture
    component, and "joint_gaussian" returns a torch tensor.
    """
    if rng is None:
        rng = np.random.RandomState()

    def sample(name):
        return inf_train_gen(name, rng=rng, batch_size=batch_size)

    # ---- composite datasets (parts are sampled in the listed order, then concatenated) ----
    if data == "2spirals-8gaussians":
        return np.concatenate([sample("2spirals"), sample("8gaussians")], axis=1)

    if data == "4-2spirals-8gaussians":
        return np.concatenate([sample("2spirals"), sample("8gaussians"),
                               sample("2spirals"), sample("8gaussians")], axis=1)

    if data == "8-2spirals-8gaussians":
        return np.concatenate([sample("4-2spirals-8gaussians"), sample("4-2spirals-8gaussians")], axis=1)

    if data == "8-MIX":
        spirals = sample("2spirals")
        gaussians = sample("8gaussians")
        swissroll = sample("swissroll")
        circles = sample("circles")
        moons = sample("moons")
        pinwheel = sample("pinwheel")
        checkerboard = sample("checkerboard")
        line = sample("line")
        # Column order differs from sampling order: the line comes fifth and the moons last.
        mix = np.concatenate([spirals, gaussians, swissroll, circles, line, pinwheel, checkerboard, moons], axis=1)
        return mix / _MIX_STD

    if data == "7-MIX":
        parts = [sample(name) for name in ("2spirals", "8gaussians", "swissroll", "circles", "moons",
                                           "pinwheel", "checkerboard")]
        return np.concatenate(parts, axis=1) / _MIX_STD[:14]

    # ---- base datasets ----
    if data == "swissroll":
        # random_state=rng makes the sklearn generators follow the supplied rng as well.
        points = sklearn.datasets.make_swiss_roll(n_samples=batch_size, noise=1.0, random_state=rng)[0]
        points = points.astype("float32")[:, [0, 2]]
        points /= 5
        return points

    if data == "circles":
        points = sklearn.datasets.make_circles(n_samples=batch_size, factor=.5, noise=0.08, random_state=rng)[0]
        points = points.astype("float32")
        points *= 3
        return points

    if data == "moons":
        points = sklearn.datasets.make_moons(n_samples=batch_size, noise=0.1, random_state=rng)[0]
        points = points.astype("float32")
        points = points * 2 + np.array([-1, -0.2])
        return points.astype("float32")

    if data == "8gaussians":
        dataset, _ = _sample_gaussian_mixture(rng, _EIGHT_CENTERS, 0.5, batch_size)
        dataset /= 1.414
        return dataset

    if data == "2gaussians":
        dataset, _ = _sample_gaussian_mixture(rng, [(.5, -.5), (-.5, .5)], .75, batch_size)
        return dataset

    if data == "4gaussians":
        dataset, _ = _sample_gaussian_mixture(rng, [(.5, -.5), (-.5, .5), (.5, .5), (-.5, -.5)], .75, batch_size)
        return dataset

    if data == "2igaussians":
        dataset, _ = _sample_gaussian_mixture(rng, [(.5, 0.), (-.5, .0)], .75, batch_size)
        return dataset

    if data == "conditionnal8gaussians":
        dataset, components = _sample_gaussian_mixture(rng, _EIGHT_CENTERS, 0.5, batch_size)
        context = np.zeros((batch_size, 8))
        for row, idx in enumerate(components):
            context[row, idx] = 1
        dataset /= 1.414
        return dataset, context

    if data == "pinwheel":
        radial_std = 0.3
        tangential_std = 0.1
        num_classes = 5
        num_per_class = batch_size // num_classes
        rate = 0.25
        rads = np.linspace(0, 2 * np.pi, num_classes, endpoint=False)

        features = rng.randn(num_classes*num_per_class, 2) \
            * np.array([radial_std, tangential_std])
        features[:, 0] += 1.
        labels = np.repeat(np.arange(num_classes), num_per_class)

        angles = rads[labels] + rate * np.exp(features[:, 0])
        rotations = np.stack([np.cos(angles), -np.sin(angles), np.sin(angles), np.cos(angles)])
        rotations = np.reshape(rotations.T, (-1, 2, 2))

        return 2 * rng.permutation(np.einsum("ti,tij->tj", features, rotations)).astype("float32")

    if data == "2spirals":
        n = np.sqrt(rng.rand(batch_size // 2, 1)) * 540 * (2 * np.pi) / 360
        d1x = -np.cos(n) * n + rng.rand(batch_size // 2, 1) * 0.5
        d1y = np.sin(n) * n + rng.rand(batch_size // 2, 1) * 0.5
        x = np.vstack((np.hstack((d1x, d1y)), np.hstack((-d1x, -d1y)))) / 3
        x += rng.randn(*x.shape) * 0.1
        return x.astype("float32")

    if data == "checkerboard":
        # Drawn from rng (not the global numpy state) so that a seeded rng is reproducible.
        x1 = rng.rand(batch_size) * 4 - 2
        x2_ = rng.rand(batch_size) - rng.randint(0, 2, batch_size) * 2
        x2 = x2_ + (np.floor(x1) % 2)
        return np.concatenate([x1[:, None], x2[:, None]], 1).astype("float32") * 2

    if data == "line":
        x = rng.rand(batch_size)
        x = x * 5 - 2.5
        y = x
        return np.stack((x, y), 1).astype("float32")

    if data == "line-noisy":
        x = rng.rand(batch_size)
        x = x * 5 - 2.5
        y = x + rng.randn(batch_size)
        return np.stack((x, y), 1).astype("float32")

    if data == "cos":
        x = rng.rand(batch_size) * 6 - 3
        # The noise comes from rng (not the global numpy state) for the same reproducibility reason.
        y = np.sin(x*5) * 2.5 + rng.randn(batch_size) * .3
        return np.stack((x, y), 1).astype("float32")

    if data == "joint_gaussian":
        # NOTE: sampled with torch, hence neither driven by rng nor returned as a numpy array.
        x2 = torch.distributions.Normal(0., 4.).sample((batch_size, 1))
        x1 = torch.distributions.Normal(0., 1.).sample((batch_size, 1)) + (x2**2)/4

        return torch.cat((x1, x2), 1)

    return inf_train_gen("8gaussians", rng, batch_size)
@param values: list of tuples (name, value_for_last_step). 45 | The progress bar will display averages for these values. 46 | ''' 47 | for k, v in values: 48 | if k not in self.sum_values: 49 | self.sum_values[k] = [v * (current - self.seen_so_far), current - self.seen_so_far] 50 | self.unique_values.append(k) 51 | else: 52 | self.sum_values[k][0] += v * (current - self.seen_so_far) 53 | self.sum_values[k][1] += (current - self.seen_so_far) 54 | self.seen_so_far = current 55 | 56 | now = time.time() 57 | if self.verbose == 1: 58 | prev_total_width = self.total_width 59 | sys.stdout.write("\b" * prev_total_width) 60 | sys.stdout.write("\r") 61 | 62 | numdigits = int(np.floor(np.log10(self.target))) + 1 63 | barstr = '%%%dd/%%%dd [' % (numdigits, numdigits) 64 | bar = barstr % (current, self.target) 65 | prog = float(current) / self.target 66 | prog_width = int(self.width * prog) 67 | if prog_width > 0: 68 | bar += ('=' * (prog_width - 1)) 69 | if current < self.target: 70 | bar += '>' 71 | else: 72 | bar += '=' 73 | bar += ('.' 
* (self.width - prog_width)) 74 | bar += ']' 75 | sys.stdout.write(bar) 76 | self.total_width = len(bar) 77 | 78 | if current: 79 | time_per_unit = (now - self.start) / current 80 | else: 81 | time_per_unit = 0 82 | eta = time_per_unit * (self.target - current) 83 | info = '' 84 | if current < self.target: 85 | info += ' - ETA: %ds' % eta 86 | else: 87 | info += ' - %ds' % (now - self.start) 88 | for k in self.unique_values: 89 | info += ' - %s:' % k 90 | if type(self.sum_values[k]) is list: 91 | avg = self.sum_values[k][0] / max(1, self.sum_values[k][1]) 92 | if abs(avg) > 1e-3: 93 | info += ' %.4f' % avg 94 | else: 95 | info += ' %.4e' % avg 96 | else: 97 | info += ' %s' % self.sum_values[k] 98 | 99 | self.total_width += len(info) 100 | if prev_total_width > self.total_width: 101 | info += ((prev_total_width - self.total_width) * " ") 102 | 103 | sys.stdout.write(info) 104 | sys.stdout.flush() 105 | 106 | if current >= self.target: 107 | sys.stdout.write("\n") 108 | 109 | if self.verbose == 2: 110 | if current >= self.target: 111 | info = '%ds' % (now - self.start) 112 | for k in self.unique_values: 113 | info += ' - %s:' % k 114 | avg = self.sum_values[k][0] / max(1, self.sum_values[k][1]) 115 | if avg > 1e-3: 116 | info += ' %.4f' % avg 117 | else: 118 | info += ' %.4e' % avg 119 | sys.stdout.write(info + "\n") 120 | 121 | def add(self, n, values=[]): 122 | self.update(self.seen_so_far + n, values) 123 | 124 | 125 | # mnist 126 | def load_mnist_images_np(imgs_filename): 127 | with open(imgs_filename, 'rb') as f: 128 | f.seek(4) 129 | nimages, rows, cols = struct.unpack('>iii', f.read(12)) 130 | dim = rows * cols 131 | images = np.fromfile(f, dtype=np.dtype(np.ubyte)) 132 | images = (images / 255.0).astype('float32').reshape((nimages, dim)) 133 | 134 | return images 135 | 136 | 137 | # cifar10 138 | from six.moves.urllib.request import FancyURLopener 139 | import tarfile 140 | import sys 141 | 142 | 143 | class ParanoidURLopener(FancyURLopener): 144 | def 
def load_batch(fpath, label_key='labels'):
    """Load one pickled CIFAR batch file.

    :param fpath: path of the pickled batch file.
    :param label_key: key of the label entry in the pickled dict.
    :return: ``(data, labels)`` where ``data`` is reshaped to
        ``(num_samples, 3, 32, 32)`` and ``labels`` is the stored entry.
    :raises KeyError: if ``"data"`` or ``label_key`` is missing.
    """
    # NOTE(review): pickle executes arbitrary code on load; only use this
    # on archives obtained from a trusted source.
    with open(fpath, 'rb') as f:
        if sys.version_info < (3,):
            d = pickle.load(f)
        else:
            d = pickle.load(f, encoding="bytes")

    # With encoding="bytes" the keys come back as bytes. Build a new dict
    # with text keys: the previous in-place loop deleted str(k), which is
    # "b'data'" on Python 3, and so raised KeyError.
    d = {(k.decode("utf8") if isinstance(k, bytes) else k): v
         for k, v in d.items()}

    data = d["data"]
    labels = d[label_key]

    data = data.reshape(data.shape[0], 3, 32, 32)
    return data, labels
nb_train_samples = 50000 212 | 213 | X_train = np.zeros((nb_train_samples, 3, 32, 32), dtype="uint8") 214 | y_train = np.zeros((nb_train_samples,), dtype="uint8") 215 | 216 | for i in range(1, 6): 217 | fpath = os.path.join(path, 'data_batch_' + str(i)) 218 | print(fpath) 219 | data, labels = load_batch(fpath) 220 | X_train[(i - 1) * 10000:i * 10000, :, :, :] = data 221 | y_train[(i - 1) * 10000:i * 10000] = labels 222 | 223 | fpath = os.path.join(path, 'test_batch') 224 | X_test, y_test = load_batch(fpath) 225 | 226 | y_train = np.reshape(y_train, (len(y_train), 1)) 227 | y_test = np.reshape(y_test, (len(y_test), 1)) 228 | 229 | return (X_train, y_train), (X_test, y_test) 230 | 231 | 232 | if __name__ == '__main__': 233 | 234 | if not os.path.exists(savedir): 235 | os.makedirs(savedir) 236 | 237 | if mnist: 238 | print('dynamically binarized mnist') 239 | mnist_filenames = ['train-images-idx3-ubyte', 't10k-images-idx3-ubyte'] 240 | 241 | for filename in mnist_filenames: 242 | local_filename = os.path.join(savedir, filename) 243 | urllib.request.urlretrieve("http://yann.lecun.com/exdb/mnist/{}.gz".format(filename), 244 | local_filename + '.gz') 245 | with gzip.open(local_filename + '.gz', 'rb') as f: 246 | file_content = f.read() 247 | with open(local_filename, 'wb') as f: 248 | f.write(file_content) 249 | np.savetxt(local_filename, load_mnist_images_np(local_filename)) 250 | os.remove(local_filename + '.gz') 251 | 252 | print('statically binarized mnist') 253 | subdatasets = ['train', 'valid', 'test'] 254 | for subdataset in subdatasets: 255 | filename = 'binarized_mnist_{}.amat'.format(subdataset) 256 | url = 'http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/binarized_mnist_{}.amat'.format( 257 | subdataset) 258 | local_filename = os.path.join(savedir, filename) 259 | urllib.request.urlretrieve(url, local_filename) 260 | 261 | if cifar10: 262 | (X_train, y_train), (X_test, y_test) = load_cifar10() 263 | pickle.dump((X_train, y_train, X_test, 
y_test), 264 | open('{}/cifar10.pkl'.format(savedir), 'w')) 265 | 266 | if omniglot: 267 | url = 'https://github.com/yburda/iwae/raw/master/datasets/OMNIGLOT/chardata.mat' 268 | filename = 'omniglot.amat' 269 | local_filename = os.path.join(savedir, filename) 270 | urllib.request.urlretrieve(url, local_filename) 271 | 272 | if maf: 273 | savedir = 'UCIdatasets/' 274 | url = 'https://zenodo.org/record/1161203/files/data.tar.gz' 275 | local_filename = os.path.join(savedir, 'data.tar.gz') 276 | urllib.request.urlretrieve(url, local_filename) 277 | 278 | tar = tarfile.open(local_filename, "r:gz") 279 | tar.extractall(savedir) 280 | tar.close() 281 | os.remove(local_filename) 282 | -------------------------------------------------------------------------------- /license: -------------------------------------------------------------------------------- 1 | Creative Commons Attribution 4.0 International Public License 2 | 3 | By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions. 4 | 5 | Section 1 – Definitions. 6 | 7 | Adapted Material means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. 
For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image. 8 | Adapter's License means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License. 9 | Copyright and Similar Rights means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights. 10 | Effective Technological Measures means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements. 11 | Exceptions and Limitations means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material. 12 | Licensed Material means the artistic or literary work, database, or other material to which the Licensor applied this Public License. 13 | Licensed Rights means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license. 14 | Licensor means the individual(s) or entity(ies) granting rights under this Public License. 
15 | Share means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them. 16 | Sui Generis Database Rights means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world. 17 | You means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning. 18 | Section 2 – Scope. 19 | 20 | License grant. 21 | Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to: 22 | reproduce and Share the Licensed Material, in whole or in part; and 23 | produce, reproduce, and Share Adapted Material. 24 | Exceptions and Limitations. For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions. 25 | Term. The term of this Public License is specified in Section 6(a). 26 | Media and formats; technical modifications allowed. The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. 
The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material. 27 | Downstream recipients. 28 | Offer from the Licensor – Licensed Material. Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License. 29 | No downstream restrictions. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material. 30 | No endorsement. Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i). 31 | Other rights. 32 | 33 | Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise. 34 | Patent and trademark rights are not licensed under this Public License. 
35 | To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties. 36 | Section 3 – License Conditions. 37 | 38 | Your exercise of the Licensed Rights is expressly made subject to the following conditions. 39 | 40 | Attribution. 41 | 42 | If You Share the Licensed Material (including in modified form), You must: 43 | 44 | retain the following if it is supplied by the Licensor with the Licensed Material: 45 | identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated); 46 | a copyright notice; 47 | a notice that refers to this Public License; 48 | a notice that refers to the disclaimer of warranties; 49 | a URI or hyperlink to the Licensed Material to the extent reasonably practicable; 50 | indicate if You modified the Licensed Material and retain an indication of any previous modifications; and 51 | indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License. 52 | You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information. 53 | If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable. 54 | If You Share Adapted Material You produce, the Adapter's License You apply must not prevent recipients of the Adapted Material from complying with this Public License. 
55 | Section 4 – Sui Generis Database Rights. 56 | 57 | Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material: 58 | 59 | for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database; 60 | if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and 61 | You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database. 62 | For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights. 63 | Section 5 – Disclaimer of Warranties and Limitation of Liability. 64 | 65 | Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You. 
66 | To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You. 67 | The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability. 68 | Section 6 – Term and Termination. 69 | 70 | This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically. 71 | Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates: 72 | 73 | automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or 74 | upon express reinstatement by the Licensor. 75 | For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License. 76 | For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License. 77 | Sections 1, 5, 6, 7, and 8 survive termination of this Public License. 78 | Section 7 – Other Terms and Conditions. 79 | 80 | The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed. 
81 | Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License. 82 | Section 8 – Interpretation. 83 | 84 | For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License. 85 | To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions. 86 | No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor. 87 | Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority. 88 | Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” The text of the Creative Commons public licenses is dedicated to the public domain under the CC0 Public Domain Dedication. 
def batch_iter(X, batch_size, shuffle=False):
    """
    Yield consecutive mini-batches of rows of X.

    X: feature tensor (shape: num_instances x num_features)
    batch_size: maximum number of rows per batch; the last batch may be smaller
    shuffle: if True, rows are visited in a random permutation
    """
    num_rows = X.shape[0]
    idxs = torch.randperm(num_rows) if shuffle else torch.arange(num_rows)
    # Move the indices to the device that holds X. `.cuda()` would pick the
    # default CUDA device, which may differ from X's device.
    idxs = idxs.to(X.device)
    for batch_idxs in idxs.split(batch_size):
        yield X[batch_idxs]
def train(dataset="POWER", load=True, nb_step_dual=100, nb_steps=20, path="", l1=.1, nb_epoch=10000,
          int_net=[200, 200, 200], emb_net=[200, 200, 200], b_size=100, all_args=None, file_number=None, train=True,
          solver="CC", nb_flow=1, weight_decay=1e-5, learning_rate=1e-3, cond_type='DAG', norm_type='affine'):
    """Train and/or evaluate a normalizing flow on one of the UCI datasets.

    dataset: name handed to ``load_data``.
    load: if True, restore ``model<suffix>.pt`` (and ``ADAM<suffix>.pt`` when
        present) from ``path`` before the epoch loop.
    nb_step_dual: passed to the DAG conditioner as ``nb_epoch_update``.
    nb_steps: base number of integration steps for the monotonic normalizer.
    path: directory used for logs and checkpoints.
    l1: passed to the DAG conditioner as ``l1``.
    nb_epoch: number of epochs.
    int_net: passed to the monotonic normalizer as ``integrand_net``.
    emb_net: conditioner layer sizes; all but the last are hidden sizes, the
        last is the output (embedding) size.
    b_size: mini-batch size.
    all_args: only logged.
    file_number: checkpoint suffix used when ``load`` is True.
    train: if False the optimisation loop is skipped and only the evaluation
        part of each epoch runs.
    solver, nb_flow, weight_decay, learning_rate: forwarded to the normalizer,
        the flow factory and the Adam optimizer respectively.
    cond_type / norm_type: keys of ``cond_types`` / ``norm_types``.

    NOTE(review): this file was only available with its indentation
    stripped, so the nesting below (test evaluation inside the "new best"
    branch, numbered checkpoints inside the every-10-epochs DAG branch,
    ``model.pt``/``ADAM.pt`` saved once per epoch) is a reconstruction -
    confirm against the repository before merging.
    """
    logger = utils.get_logger(logpath=os.path.join(path, 'logs'), filepath=os.path.abspath(__file__))
    logger.info(str(all_args))

    logger.info("Creating model...")

    device = "cpu" if not(torch.cuda.is_available()) else "cuda:0"

    if load:
        # Checkpoints are named model.pt or model_<number>.pt.
        file_number = "_" + file_number if file_number is not None else ""

    batch_size = b_size

    logger.info("Loading data...")
    data = load_data(dataset)
    data.trn.x = torch.from_numpy(data.trn.x).to(device)
    data.val.x = torch.from_numpy(data.val.x).to(device)
    data.tst.x = torch.from_numpy(data.tst.x).to(device)
    logger.info("Data loaded.")

    dim = data.trn.x.shape[1]
    conditioner_type = cond_types[cond_type]
    conditioner_args = {"in_size": dim, "hidden": emb_net[:-1], "out_size": emb_net[-1]}
    if conditioner_type is DAGConditioner:
        conditioner_args['l1'] = l1
        conditioner_args['gumble_T'] = .5
        conditioner_args['nb_epoch_update'] = nb_step_dual
        conditioner_args["hot_encoding"] = True
    normalizer_type = norm_types[norm_type]
    if normalizer_type is MonotonicNormalizer:
        normalizer_args = {"integrand_net": int_net, "cond_size": emb_net[-1], "nb_steps": nb_steps,
                           "solver": solver}
    else:
        normalizer_args = {}

    model = buildFCNormalizingFlow(nb_flow, conditioner_type, conditioner_args, normalizer_type, normalizer_args)
    best_valid_loss = np.inf

    opt = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    if load:
        logger.info("Loading model...")
        model.load_state_dict(torch.load(path + '/model%s.pt' % file_number, map_location={"cuda:0": device}))
        model.train()
        # Fix: the existence check used the unformatted 'ADAM%s.pt'
        # template, so the optimizer state was never restored.
        adam_path = path + '/ADAM%s.pt' % file_number
        if os.path.isfile(adam_path):
            opt.load_state_dict(torch.load(adam_path, map_location={"cuda:0": device}))
            if device != "cpu":
                # Move the restored optimizer tensors onto the GPU.
                for state in opt.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor):
                            state[k] = v.cuda()

    for epoch in range(nb_epoch):
        ll_tot = 0
        start = timer()

        # Update constraints
        if conditioner_type is DAGConditioner:
            with torch.no_grad():
                for conditioner in model.getConditioners():
                    conditioner.constrainA(zero_threshold=0.)

        # Training loop
        model.to(device)
        if train:
            for i, cur_x in enumerate(batch_iter(data.trn.x, shuffle=True, batch_size=batch_size)):
                if normalizer_type is MonotonicNormalizer:
                    # Randomise the number of integration steps per batch.
                    for normalizer in model.getNormalizers():
                        normalizer.nb_steps = nb_steps + torch.randint(0, 10, [1])[0].item()
                z, jac = model(cur_x)
                loss = model.loss(z, jac)
                if math.isnan(loss.item()) or math.isinf(loss.abs().item()):
                    # Keep the diverged weights for inspection, then stop.
                    torch.save(model.state_dict(), path + '/NANmodel.pt')
                    print("Error NAN in loss")
                    exit()
                ll_tot += loss.detach()
                opt.zero_grad()
                loss.backward(retain_graph=True)
                opt.step()

            ll_tot /= i + 1
            model.step(epoch, ll_tot)
        else:
            ll_tot = 0.

        # Valid loop
        ll_test = 0.
        with torch.no_grad():
            if normalizer_type is MonotonicNormalizer:
                # Use more integration steps for evaluation than for training.
                for normalizer in model.getNormalizers():
                    normalizer.nb_steps = nb_steps + 20
            for i, cur_x in enumerate(batch_iter(data.val.x, shuffle=True, batch_size=batch_size)):
                z, jac = model(cur_x)
                ll = (model.z_log_density(z) + jac)
                ll_test += ll.mean().item()
            ll_test /= i + 1

            end = timer()
            dagness = max(model.DAGness())
            # Fix: ll_tot is a plain float when train is False, so use
            # float() rather than .item(), which only tensors provide.
            logger.info("epoch: {:d} - Train loss: {:4f} - Valid log-likelihood: {:4f} - <>: {:4f} - Elapsed time per epoch {:4f} (seconds)".
                        format(epoch, float(ll_tot), ll_test, dagness, end-start))

            if dagness < 1e-20 and -ll_test < best_valid_loss:
                logger.info("------- New best validation loss --------")
                torch.save(model.state_dict(), path + '/best_model.pt')
                best_valid_loss = -ll_test
                # Test loop
                ll_test = 0.
                for i, cur_x in enumerate(batch_iter(data.tst.x, shuffle=True, batch_size=batch_size)):
                    z, jac = model(cur_x)
                    ll = (model.z_log_density(z) + jac)
                    ll_test += ll.mean().item()
                ll_test /= i + 1

                logger.info("epoch: {:d} - Test log-likelihood: {:4f} - <>: {:4f}".format(epoch, ll_test,
                                                                                         dagness))
            if epoch % 10 == 0 and conditioner_type is DAGConditioner:
                stoch_gate, noise_gate, s_thresh = [], [], []

                # Remember the gate settings, then switch to a deterministic
                # soft-thresholded configuration for the threshold sweep.
                for conditioner in model.getConditioners():
                    stoch_gate.append(conditioner.stoch_gate)
                    noise_gate.append(conditioner.noise_gate)
                    s_thresh.append(conditioner.s_thresh)
                    conditioner.stoch_gate = False
                    conditioner.noise_gate = False
                    conditioner.s_thresh = True
                for threshold in [.95, .5, .1, .01, .0001]:
                    for conditioner in model.getConditioners():
                        conditioner.h_thresh = threshold
                    # Valid loop
                    ll_test = 0.
                    for i, cur_x in enumerate(batch_iter(data.val.x, shuffle=True, batch_size=batch_size)):
                        z, jac = model(cur_x)
                        ll = (model.z_log_density(z) + jac)
                        ll_test += ll.mean().item()
                    # Fix: average over the i + 1 batches seen (was `/= i`,
                    # biased, and a division by zero for a single batch).
                    ll_test /= i + 1
                    dagness = max(model.DAGness())
                    logger.info("epoch: {:d} - Threshold: {:4f} - Valid log-likelihood: {:4f} - <>: {:4f}".
                                format(epoch, threshold, ll_test, dagness))

                for i, conditioner in enumerate(model.getConditioners()):
                    # NOTE(review): this leaves h_thresh at the last swept
                    # value (.0001) rather than its pre-sweep value; kept
                    # as in the original - confirm whether it is intended.
                    conditioner.h_thresh = threshold
                    conditioner.stoch_gate = stoch_gate[i]
                    conditioner.noise_gate = noise_gate[i]
                    conditioner.s_thresh = s_thresh[i]

                torch.save(model.state_dict(), path + '/model_%d.pt' % epoch)
                torch.save(opt.state_dict(), path + '/ADAM_%d.pt' % epoch)
                if dataset == "proteins" and conditioner_type is DAGConditioner:
                    # Fix: getConditioners is a method and must be called
                    # before indexing.
                    torch.save(model.getConditioners()[0].soft_thresholded_A().detach().cpu(), path + '/A_%d.pt' % epoch)

            torch.save(model.state_dict(), path + '/model.pt')
            torch.save(opt.state_dict(), path + '/ADAM.pt')
default='DAG', choices=['DAG', 'Coupling', 'Autoregressive'], type=str) 243 | parser.add_argument("-emb_net", default=[100, 100, 100, 10], nargs="+", type=int, help="NN layers of embedding") 244 | # Specific for DAG: 245 | parser.add_argument("-nb_steps_dual", default=100, type=int, help="number of step between updating Acyclicity constraint and sparsity constraint") 246 | parser.add_argument("-l1", default=.2, type=float, help="Maximum weight for l1 regularization") 247 | parser.add_argument("-gumble_T", default=1., type=float, help="Temperature of the gumble distribution.") 248 | 249 | # Normalizer Parameters 250 | parser.add_argument("-normalizer", default='affine', choices=['affine', 'monotonic'], type=str) 251 | parser.add_argument("-int_net", default=[100, 100, 100, 100], nargs="+", type=int, help="NN hidden layers of UMNN") 252 | parser.add_argument("-nb_steps", default=20, type=int, help="Number of integration steps.") 253 | parser.add_argument("-solver", default="CC", type=str, help="Which integral solver to use.", 254 | choices=["CC", "CCParallel"]) 255 | 256 | args = parser.parse_args() 257 | 258 | now = datetime.now() 259 | loader = yaml.SafeLoader 260 | loader.add_implicit_resolver( 261 | u'tag:yaml.org,2002:float', 262 | re.compile(u'''^(?: 263 | [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)? 264 | |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+) 265 | |\\.[0-9_]+(?:[eE][-+][0-9]+)? 
class DAGMLP(nn.Module):
    """Fully connected network: Linear layers with a ReLU between each pair.

    The input width is ``in_size + cond_in``, the hidden widths are given by
    ``hidden`` and the output width is ``out_size``. No activation follows
    the last Linear layer.
    """

    def __init__(self, in_size, hidden, out_size, cond_in=0):
        super(DAGMLP, self).__init__()
        # Accept any sequence of hidden sizes (list, tuple, ...); plain list
        # concatenation would fail for a tuple.
        widths = [in_size + cond_in] + list(hidden) + [out_size]
        layers = []
        for h1, h2 in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(h1, h2), nn.ReLU()]
        # Drop the trailing ReLU so the output layer stays linear.
        layers.pop()
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to ``x``, whose last dimension must be the input width."""
        return self.net(x)
class DAGConditioner(Conditioner):
    """Conditioner whose dependency structure is a learned adjacency matrix ``A``.

    ``A`` is an ``in_size x in_size`` parameter.  Row ``i`` weights (or gates)
    the input vector before it is fed to the embedding network that produces
    the conditioning vector of variable ``i``.  Acyclicity of ``A`` is
    encouraged through a penalty on ``trace((I + alpha * A**2)**k) - d``
    (see ``get_power_trace`` and ``loss``) whose multipliers are updated in
    ``update_dual_param``.
    """

    def __init__(self, in_size, hidden, out_size, cond_in=0, soft_thresholding=True, h_thresh=0., gumble_T=1.,
                 hot_encoding=False, l1=0., nb_epoch_update=1, A_prior=None):
        """Create the adjacency parameter, the embedding net and the dual buffers.

        Args:
            in_size: number of variables (``A`` is ``in_size x in_size``).
            hidden: either an ``nn.Module`` used directly as embedding net, or
                the hidden layer widths of a ``DAGMLP``.
            out_size: output width of the ``DAGMLP`` (ignored if ``hidden`` is a module).
            cond_in: extra input width forwarded to ``DAGMLP``.
            soft_thresholding: use ``soft_thresholded_A`` as edge weights.
            h_thresh: hard threshold below which weights are zeroed.
            gumble_T: temperature of the Gumbel soft-max gate.
            hot_encoding: concatenate a one-hot variable index to the masked input.
            l1: weight of the L1 penalty on ``A``.
            nb_epoch_update: period (in epochs) of the dual-parameter update check.
            A_prior: optional initial value for ``A``.
        """
        super(DAGConditioner, self).__init__()
        if A_prior is None:
            self.A = nn.Parameter(torch.ones(in_size, in_size) * 1.5 + torch.randn((in_size, in_size)) * .02)
        else:
            self.A = nn.Parameter(A_prior)
        self.in_size = in_size
        self.exponent = self.in_size % 50
        self.s_thresh = soft_thresholding
        self.h_thresh = h_thresh
        self.stoch_gate = True
        self.noise_gate = False
        in_net = in_size*2 if hot_encoding else in_size
        if isinstance(hidden, nn.Module):
            self.embedding_net = hidden
        else:
            self.embedding_net = DAGMLP(in_net, hidden, out_size, cond_in)
        self.gumble = True
        self.hutchinson = False
        self.gumble_T = gumble_T
        self.hot_encoding = hot_encoding
        with torch.no_grad():
            self.constrainA(h_thresh)
        # Buffers related to the optimization of the constraints on A
        self.register_buffer("lambd", torch.tensor(.0))
        self.register_buffer("c", torch.tensor(1e-3))
        self.register_buffer("eta", torch.tensor(10.))
        self.register_buffer("gamma", torch.tensor(.9))
        self.register_buffer("l1_weight", torch.tensor(l1))
        self.register_buffer("dag_const", torch.tensor(1.))
        self.alpha_factor = 1.
        self.d = in_size
        self.tol = 1e-30
        # getAlpha/get_power_trace read the attributes set above, keep this order.
        self.register_buffer("alpha", self.getAlpha())
        self.register_buffer("prev_trace", self.get_power_trace())
        self.nb_epoch_update = nb_epoch_update
        self.no_update = 0
        self.is_invertible = False

    def getAlpha(self):
        """Return the scaling factor ``1 / in_size`` as a 0-dim tensor.

        The previous implementation also ran an SVD of ``A**2`` but discarded
        its result; only the constant was ever returned.
        """
        return torch.tensor(1./self.in_size)

    def get_dag(self):
        """Return the object holding the DAG, i.e. this conditioner."""
        return self

    def post_process(self, zero_threshold=None):
        """Binarise ``A`` into an acyclic 0/1 adjacency matrix and freeze it.

        The threshold (default ``.1``) on ``|soft_thresholded_A()|`` is raised
        by ``.05`` until the resulting directed graph is acyclic.  All gates
        and thresholds are then disabled, the diagonal is zeroed and ``A`` no
        longer requires gradients.
        """
        if zero_threshold is None:
            zero_threshold = .1
        # no_grad: A is a leaf parameter, in-place edits are illegal otherwise.
        with torch.no_grad():
            strength = self.soft_thresholded_A().abs()

            def adjacency(threshold):
                return (strength > threshold).float()

            def is_acyclic(threshold):
                graph = nx.from_numpy_matrix(adjacency(threshold).cpu().numpy(), create_using=nx.DiGraph)
                return nx.is_directed_acyclic_graph(graph)

            while not is_acyclic(zero_threshold):
                zero_threshold += .05
            self.stoch_gate = False
            self.noise_gate = False
            self.s_thresh = False
            self.h_thresh = 0.
            self.A.data = adjacency(zero_threshold)
            self.A *= 1. - torch.eye(self.in_size, device=self.A.device)
        self.A.requires_grad = False
        self.A.grad = None

    def stochastic_gate(self, importance):
        """Sample a random gate in [0, 1] whose distribution depends on ``importance``."""
        if self.gumble:
            # Gumble soft-max gate
            temp = self.gumble_T
            epsilon = 1e-6
            g1 = -torch.log(-torch.log(torch.rand(importance.shape, device=self.A.device)))
            g2 = -torch.log(-torch.log(torch.rand(importance.shape, device=self.A.device)))
            z1 = torch.exp((torch.log(importance + epsilon) + g1)/temp)
            z2 = torch.exp((torch.log(1 - importance + epsilon) + g2)/temp)
            return z1 / (z1 + z2)

        # Gaussian gate: noise is widest around importance == .5, result clipped to [0, 1].
        beta_1, beta_2 = 3., 10.
        sigma = beta_1/(1. + beta_2*torch.sqrt((importance - .5)**2.))
        mu = importance
        z = torch.randn(importance.shape, device=self.A.device) * sigma + mu + .25
        return torch.relu(z.clamp_max(1.))

    def noiser_gate(self, x, importance):
        """Scale ``x`` by ``importance`` after adding noise of std ``|1 - importance|``."""
        noise = torch.randn(importance.shape, device=self.A.device) * torch.sqrt((1 - importance)**2)
        return importance*(x + noise)

    def soft_thresholded_A(self):
        """Return ``2 * (sigmoid(2 * A**2) - .5)``, a smooth map of ``A`` into [0, 1)."""
        return 2*(torch.sigmoid(2*(self.A**2)) -.5)

    def hard_thresholded_A(self):
        """Return the edge weights with every entry not above ``h_thresh`` zeroed."""
        weights = self.soft_thresholded_A() if self.s_thresh else self.A**2
        return weights * (weights > self.h_thresh).float()

    def forward(self, x, context=None):
        """Compute one conditioning vector per variable.

        ``x`` is expected as ``(batch, in_size)``; it is replicated once per
        variable, masked/gated by the corresponding row of the edge weights
        and passed through the embedding net.  Returns a tensor of shape
        ``(batch, in_size, -1)``.  ``context`` is currently unused.
        """
        batch = x.shape[0]
        # Pick the edge weights; the raw matrix A is never gated.
        if self.h_thresh > 0:
            weights, gated = self.hard_thresholded_A(), True
        elif self.s_thresh:
            weights, gated = self.soft_thresholded_A(), True
        else:
            weights, gated = self.A, False

        x_rep = x.unsqueeze(1).expand(-1, self.in_size, -1)
        w_rep = weights.unsqueeze(0).expand(batch, -1, -1)
        if gated and self.stoch_gate:
            e = x_rep * self.stochastic_gate(w_rep)
        elif gated and self.noise_gate:
            e = self.noiser_gate(x_rep, w_rep)
        else:
            e = x_rep * w_rep
        e = e.view(batch * self.in_size, -1)

        if self.hot_encoding:
            hot_encoding = torch.eye(self.in_size, device=self.A.device).unsqueeze(0).expand(batch, -1, -1)\
                .contiguous().view(-1, self.in_size)
            # The one-hot index is an *input* of the embedding net (which is built with
            # 2 * in_size inputs in that case), not appended to its output.
            return self.embedding_net(torch.cat((e, hot_encoding), 1)).view(batch, self.in_size, -1)

        return self.embedding_net(e).view(batch, self.in_size, -1)

    def constrainA(self, zero_threshold=.0001):
        """Zero, in place, the diagonal of ``A`` and entries with magnitude <= ``zero_threshold``.

        Must be called under ``torch.no_grad()`` while ``A`` requires gradients.
        """
        self.A *= (self.A.clone().abs() > zero_threshold).float()
        self.A *= 1. - torch.eye(self.in_size, device=self.A.device)
        return

    def get_power_trace(self):
        """Return ``trace((I + alpha * A**2) ** k) - in_size``.

        ``k`` is ``self.exponent`` in the exact branch and ``in_size`` in the
        stochastic (Hutchinson) branch, which is used when ``self.hutchinson``
        is a non-zero number of probe vectors.
        """
        # Out-of-place product: min() may return the ``alpha`` buffer itself,
        # which must not be modified here.
        alpha = min(1., self.alpha) * self.alpha_factor
        B = torch.eye(self.in_size, device=self.A.device) + alpha * self.A ** 2
        if self.hutchinson != 0:
            h_iter = self.hutchinson
            trace = 0.
            for j in range(h_iter):
                e0 = torch.randn(self.in_size, 1).to(self.A.device)
                e = e0
                for i in range(self.in_size):
                    e = B @ e

                trace += (e0 * e).sum()
            return trace / h_iter - self.in_size

        M = torch.matrix_power(B, self.exponent)
        return torch.diag(M).sum() - self.in_size

    def update_dual_param(self):
        """Update the multipliers of the acyclicity penalty and return the constraint value.

        Three regimes:
        * constraint active and above ``tol``: ``lambd += c * constraint`` and
          ``c`` is multiplied by ``eta`` when the constraint did not shrink
          below ``gamma`` times its previous value;
        * constraint active but below ``tol``: try ``post_process``; on failure
          restore ``A`` and resume the optimisation, on success disable the
          acyclicity and L1 penalties;
        * constraint already disabled: verify with networkx that the graph has
          no cycle, re-enable the optimisation otherwise.
        """
        with torch.no_grad():
            lag_const = self.get_power_trace()
            while self.dag_const > 0. and lag_const < self.tol and self.exponent < self.in_size:
                print("Update exponent", self.exponent)
                self.exponent += 50
                lag_const = self.get_power_trace()

            if self.dag_const > 0. and lag_const > self.tol:
                self.lambd = self.lambd + self.c * lag_const
                # Absolute does not make sense (but copied from DAG-GNN)
                if lag_const.abs() > self.gamma*self.prev_trace.abs():
                    self.c *= self.eta
                self.prev_trace = lag_const
            elif self.dag_const > 0.:
                print("DAGness is very low: %f -> Post processing" % torch.log(lag_const), flush=True)
                A_before = self.A.clone()
                self.post_process()
                self.alpha = self.getAlpha()
                lag_const = self.get_power_trace()
                print("DAGness is now: %f" % torch.log(lag_const), flush=True)
                if lag_const > 0.:
                    print("Error in post-processing.", flush=True)
                    self.stoch_gate = True
                    self.noise_gate = False
                    self.s_thresh = True
                    self.h_thresh = 0.
                    # NOTE(review): this creates a new Parameter object; an optimizer built
                    # earlier presumably still references the old one - confirm with callers.
                    self.A = nn.Parameter(A_before)
                    self.A.requires_grad = True
                    self.A.grad = self.A.clone()
                    self.alpha = self.getAlpha()
                    self.prev_trace = self.get_power_trace()
                    self.c *= 1/self.eta
                    self.lambd = self.lambd + self.c * lag_const
                    self.dag_const = torch.tensor(1.)
                else:
                    self.dag_const = torch.tensor(0.)
                    self.l1_weight = torch.tensor(0.)
                    print("Post processing successful.")
                    print("Number of edges is %d VS number max is %d" %
                          (int(self.A.sum().item()), ((self.d - 1)*self.d)/2), flush=True)

            else:
                G = nx.from_numpy_matrix(self.A.detach().cpu().numpy() ** 2, create_using=nx.DiGraph)
                try:
                    nx.find_cycle(G)
                    print("Bad news there is still cycles in this graph.", flush=True)
                    self.A.requires_grad = True
                    self.A.grad = self.A.clone()
                    self.stoch_gate = True
                    self.noise_gate = False
                    self.s_thresh = True
                    self.h_thresh = 0.
                    self.alpha = self.getAlpha()
                    self.prev_trace = self.get_power_trace()
                    self.dag_const = torch.tensor(1.)
                    print(self.in_size, self.prev_trace)
                except nx.NetworkXNoCycle:
                    print("Good news there is no cycle in this graph.", flush=True)
                    print("Depth of the graph is: %d" % self.depth())
                    self.is_invertible = True

                print("DAGness is still very low: %f" % torch.log(self.get_power_trace()), flush=True)
        return lag_const

    def depth(self):
        """Return the longest path length of the graph of positive entries of ``A`` (0 if cyclic)."""
        G = nx.from_numpy_matrix((self.A.detach() > 0).float().cpu().numpy(), create_using=nx.DiGraph)
        if self.is_invertible or nx.is_directed_acyclic_graph(G):
            return int(nx.dag_longest_path_length(G))
        return 0

    def loss(self):
        """Return the augmented-Lagrangian acyclicity penalty plus the L1 penalty on ``A``."""
        lag_const = self.get_power_trace()
        loss = self.dag_const*(self.lambd*lag_const + self.c/2*lag_const**2) + self.l1_weight*self.A.abs().mean()
        return loss

    def step(self, epoch_number, loss_avg=0.):
        """End-of-epoch hook: log statistics and possibly update the dual parameters.

        Args:
            epoch_number: index of the epoch that just finished.
            loss_avg: average training loss, as a float or a tensor.  The dual
                update is triggered when the constraint loss is smaller than
                half of its magnitude, or after more than 10 skipped checks.
        """
        if self.A.requires_grad:
            print(self.alpha, self.A.max(), self.A.min(), self.A.mean(), self.A.requires_grad, self.A.grad.mean(),
                  self.A.grad.max(), self.A.grad.min(), self.A.grad.std(), self.getAlpha(), self.dag_const, flush=True)
        else:
            print(self.A.requires_grad, self.getAlpha(), self.dag_const, flush=True)
        with torch.no_grad():
            lag_const = self.get_power_trace()
            if lag_const > 50:
                # The power trace is blowing up: use a smaller matrix power (at least 3).
                self.exponent -= 5
                self.exponent = self.exponent if self.exponent > 3 else 3
            if epoch_number % self.nb_epoch_update == 0 and epoch_number > 0:
                if self.in_size < 30:
                    print(self.soft_thresholded_A(), flush=True)
                # abs() works for both tensors and the plain-float default of loss_avg.
                if self.loss().abs() < abs(loss_avg)/2 or self.no_update > 10:
                    print("Update param", flush=True)
                    self.update_dual_param()
                    self.no_update = 0
                else:
                    print("No Update param", flush=True)
                    self.no_update += 1
def add_noise(x):
    """
    [0, 1] -> [0, 255] -> add noise -> [0, 1]

    Dequantise ``x``: rescale to pixel levels, add uniform noise in [0, 1)
    and map the result back to the unit interval.
    """
    noise = torch.rand_like(x)
    return (x * 255 + noise) / 256


def compute_bpp(ll, x, alpha=1e-6):
    """Convert log-likelihoods of logit-transformed images into bits per pixel.

    Args:
        ll: per-sample log-likelihood, shape ``(batch,)``.
        x: the samples in logit space, shape ``(batch, d)``.
        alpha: squeezing constant of the logit pre-processing.

    Returns:
        Tensor of shape ``(batch,)``.
    """
    d = x.shape[1]
    log_sigmoid = torch.nn.functional.logsigmoid
    # log(sigmoid(x)) + log(1 - sigmoid(x)) computed without ever forming
    # 1 - sigmoid(x), which underflows to 0 (log -> -inf) for large |x|.
    log_jacobian = (log_sigmoid(x) + log_sigmoid(-x)).sum(1) / (d * np.log(2))
    return -ll / (d * np.log(2)) - np.log2(1 - 2 * alpha) + 8 + log_jacobian


def load_data(dataset="MNIST", batch_size=100, cuda=-1):
    """Build the train / validation / test loaders of an image dataset.

    Args:
        dataset: ``"MNIST"``, ``"MNIST<d>"`` (only digit ``d``) or ``"CIFAR10"``.
        batch_size: batch size shared by the three loaders.
        cuda: when greater than -1 the loaders use pinned memory.

    Returns:
        ``(train_loader, valid_loader, test_loader)``; every loader shuffles
        and drops the last incomplete batch.  For CIFAR10 the validation
        loader iterates over the test set.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    loader_kwargs = {'num_workers': 0, 'pin_memory': True} if cuda > -1 else {}

    def make_loader(data):
        return torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, drop_last=True,
                                           **loader_kwargs)

    def mnist(train):
        return datasets.MNIST('./MNIST', train=train, download=True,
                              transform=transforms.Compose([
                                  AddUniformNoise(),
                                  ToTensor()
                              ]))

    def keep_label(data, label):
        # Restrict an MNIST dataset, in place, to a single digit; return the number kept.
        selected = data.targets == label
        data.targets = data.targets[selected]
        data.data = data.data[selected]
        return int(selected.sum())

    if dataset == "MNIST":
        train_data, valid_data = torch.utils.data.random_split(mnist(True), [50000, 10000])
        test_data = mnist(False)
    elif len(dataset) == 6 and dataset[:5] == 'MNIST':
        label = int(dataset[5])
        data = mnist(True)
        nb_kept = keep_label(data, label)
        train_data, valid_data = torch.utils.data.random_split(data, [5000, nb_kept - 5000])
        test_data = mnist(False)
        keep_label(test_data, label)
    elif dataset == "CIFAR10":
        im_size = 32
        train_data = dset.CIFAR10(
            root="./data", train=True, transform=tforms.Compose([
                tforms.Resize(im_size),
                tforms.RandomHorizontalFlip(),
                tforms.ToTensor(),
                add_noise,
            ]), download=True
        )
        test_data = dset.CIFAR10(
            root="./data", train=False, download=True,
            transform=tforms.Compose([tforms.Resize(im_size), tforms.ToTensor(), add_noise])
        )
        # WARNING VALID = TEST
        valid_data = test_data
    else:
        raise ValueError("Unknown dataset: %r" % (dataset,))

    return make_loader(train_data), make_loader(valid_data), make_loader(test_data)
# Maps the -conditioner command-line choice to the conditioner class.
cond_types = {"DAG": DAGConditioner, "Coupling": CouplingConditioner, "Autoregressive": AutoregressiveConditioner}

def test(dataset="MNIST", load=True, nb_step_dual=100, nb_steps=20, path="", l1=.1, nb_epoch=10000, b_size=100,
         int_net=[50, 50, 50], all_args=None, file_number=None, train=True, solver="CC", weight_decay=1e-5,
         learning_rate=1e-3, batch_per_optim_step=1, n_gpu=1, norm_type='Affine', nb_flow=[1], hot_encoding=True,
         prior_A_kernel=None, conditioner="DAG", emb_net=None):
    """Rebuild a flow, optionally load a checkpoint from ``path`` and write diagnostic figures.

    The function builds the same model as the training script, then (when
    ``load`` is set) restores ``model<file_number>.pt`` / ``ADAM<file_number>.pt``
    from ``path`` and post-processes every DAG conditioner.  It finally saves
    plots of the adjacency matrices, of the in/out degrees and a grid of
    sampled images into ``path``.

    NOTE(review): ``nb_steps``, ``nb_epoch``, ``train`` and
    ``batch_per_optim_step`` are accepted but not read in this function.
    """
    logger = utils.get_logger(logpath=os.path.join(path, 'logs'), filepath=os.path.abspath(__file__))
    logger.info(str(all_args))


    if load:
        # Checkpoints are stored as model_<file_number>.pt, or model.pt without a number.
        file_number = "_" + file_number if file_number is not None else ""

    batch_size = b_size
    best_valid_loss = np.inf

    logger.info("Loading data...")
    train_loader, valid_loader, test_loader = load_data(dataset, batch_size)
    # "MNIST<d>" (single digit subset) is handled like MNIST from here on.
    if len(dataset) == 6 and dataset[:5] == 'MNIST':
        dataset = "MNIST"
    alpha = 1e-6 if dataset == "MNIST" else .05

    logger.info("Data loaded.")

    master_device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # ----------------------- Model Definition ------------------- #
    logger.info("Creating model...")
    if norm_type == 'Affine':
        normalizer_type = AffineNormalizer
        normalizer_args = {}
    else:
        normalizer_type = MonotonicNormalizer
        normalizer_args = {"integrand_net": int_net, "nb_steps": 15, "solver": solver}

    if conditioner == "DAG":
        if dataset == "MNIST":
            inner_model = buildMNISTNormalizingFlow(nb_flow, normalizer_type, normalizer_args, l1,
                                                    nb_epoch_update=nb_step_dual, hot_encoding=hot_encoding,
                                                    prior_kernel=prior_A_kernel)
        elif dataset == "CIFAR10":
            inner_model = buildCIFAR10NormalizingFlow(nb_flow, normalizer_type, normalizer_args, l1,
                                                      nb_epoch_update=nb_step_dual, hot_encoding=hot_encoding)
        else:
            logger.info("Wrong dataset name. Training aborted.")
            exit()
    else:
        # Fully connected flow: the last entry of emb_net is the embedding size.
        dim = 28 ** 2 if dataset == "MNIST" else 32 * 32 * 3
        conditioner_type = cond_types[conditioner]
        conditioner_args = {"in_size": dim, "hidden": emb_net[:-1], "out_size": emb_net[-1]}

        inner_model = buildFCNormalizingFlow(nb_flow[0], conditioner_type, conditioner_args, normalizer_type,
                                             normalizer_args)
    model = nn.DataParallel(inner_model, device_ids=list(range(n_gpu))).to(master_device)
    logger.info(str(model))
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    logger.info("Number of parameters: %d" % pytorch_total_params)

    opt = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    if load:
        logger.info("Loading model...")
        model.load_state_dict(torch.load(path + '/model%s.pt' % file_number, map_location={"cuda:0": master_device}))
        model.train()
        opt.load_state_dict(torch.load(path + '/ADAM%s.pt' % file_number, map_location={"cuda:0": master_device}))
        if master_device != "cpu":
            # Move the restored optimizer state to the GPU.
            for state in opt.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.cuda()
    logger.info("...Model built.")
    logger.info("Training starts:")

    if load:
        with torch.no_grad():
            # NOTE: the loop variable reuses (and overwrites) the ``conditioner`` argument.
            for conditioner in model.module.getConditioners():
                if type(conditioner) is DAGConditioner:
                    print(model.module.DAGness())
                    # NOTE(review): .numpy() without .cpu() presumably fails when A is on a GPU - confirm.
                    plt.matshow(conditioner.A.detach().numpy())
                    plt.savefig(path + "/test.pdf")
                    conditioner.post_process()

            for normalizer in model.module.getNormalizers():
                if type(normalizer) is MonotonicNormalizer:
                    print(normalizer.nb_steps)
                    normalizer.nb_steps = 250

    # ----------------------- Valid Loop ------------------------- #
    # Disabled: evaluation of log-likelihood / bits-per-pixel on the validation set.
    if False:
        ll_test = 0.
        bpp_test = 0.
        model.to(master_device)
        with torch.no_grad():
            for batch_idx, (cur_x, target) in enumerate(valid_loader):
                cur_x = cur_x.view(batch_size, -1).float().to(master_device)
                z, jac = model(cur_x)
                x_inv = model.module.invert(z)
                fig, ax = plt.subplots(1, 2)
                #ax[0].matshow(cur_x[0].view(28, 28))
                #ax[1].matshow(x_inv[0].view(28, 28))
                #plt.show()
                ll = (model.module.z_log_density(z) + jac)
                ll_test += ll.mean().item()
                bpp_test += compute_bpp(ll, cur_x.view(batch_size, -1).float().to(master_device), alpha).mean().item()
                print(bpp_test/(batch_idx + 1))
        ll_test /= batch_idx + 1
        bpp_test /= batch_idx + 1

        dagness = max(model.module.DAGness())
        logger.info("Valid log-likelihood: {:4f} - Valid BPP {:4f} - <<DAGness>>: {:4f} ".format(ll_test, bpp_test, dagness))


    # Disabled: same evaluation on the test set.
    if False:
        logger.info("------- Test loss with threshold -------")
        torch.save(model.state_dict(), path + '/best_model.pt')
        # Valid loop
        ll_test = 0.
        bpp_test = 0.
        with torch.no_grad():
            for batch_idx, (cur_x, target) in enumerate(test_loader):
                z, jac = model(cur_x.view(batch_size, -1).float().to(master_device))
                ll = (model.module.z_log_density(z) + jac)
                ll_test += ll.mean().item()
                bpp_test += compute_bpp(ll, cur_x.view(batch_size, -1).float().to(master_device), alpha).mean().item()
                print(bpp_test / (batch_idx + 1))

        ll_test /= batch_idx + 1
        bpp_test /= batch_idx + 1
        logger.info("Test log-likelihood: {:4f} - Test BPP {:4f}".format(ll_test, bpp_test))

    # Some plots and videos



    # Plot of the adjacency Matrix
    if True:
        for i_cond, conditioner in enumerate(model.module.getConditioners()):
            # Video of the conditioning Matrix
            in_s = conditioner.in_size if dataset == "MNIST" else 3 * 32 * 32
            a_tmp = conditioner.soft_thresholded_A()[0, :]
            a_tmp = a_tmp.view(int(in_s**.5), int(in_s**.5)).cpu().numpy() if dataset == "MNIST" else a_tmp.view(3, 32, 32).cpu().numpy()
            fig, ax = plt.subplots()
            mat = ax.matshow(a_tmp)
            plt.colorbar(mat)
            current_cmap = matplotlib.cm.get_cmap()
            # NaN entries (the variable itself, see update) are drawn in red.
            current_cmap.set_bad(color='red')
            mat.set_clim(0, 1.)

            def update(i):
                # Animation frame i: row i of the soft-thresholded adjacency as an image.
                A = conditioner.soft_thresholded_A()[i, :].cpu().numpy()
                A[i] = np.nan
                if dataset == "MNIST":
                    A = A.reshape(int(in_s**.5), int(in_s**.5))
                elif dataset == "CIFAR10":
                    A = A.reshape(3, 32, 32)
                mat.set_data(A)
                return mat

            # Set up formatting for the movie files
            # (the writer is created but the animation export below is commented out)
            Writer = animation.writers['ffmpeg']
            writer = Writer(fps=15, metadata=dict(artist='Me'), bitrate=1800)
            #ani = animation.FuncAnimation(fig, update, range(in_s), interval=100, save_count=0)
            #ani.save(path + '/A_test%d.mp4' % i_cond, writer=writer)
            #plt.close(fig)

            # Binary adjacency, the topological order, and the adjacency re-indexed by that order.
            A = (conditioner.soft_thresholded_A() > 0.).float()
            fig, ax = plt.subplots(1, 3)
            # NOTE(review): A is passed to matplotlib as a tensor; presumably needs .cpu() on GPU - confirm.
            ax[0].matshow(A)
            G = nx.from_numpy_matrix(A.detach().cpu().numpy(), create_using=nx.DiGraph)
            # topological_sort requires G to be acyclic (post_process is only run when load is set).
            top_order = list(nx.topological_sort(G))
            A_top = A.clone()
            for i in range(in_s):
                A_top[:, i] = A_top[:, i][top_order]
            for j in range(in_s):
                A_top[j, :] = A_top[j, :][top_order]
            ax[1].matshow(np.array(top_order).reshape(int(in_s**.5), int(in_s**.5)))
            ax[2].matshow(A_top)

            plt.savefig(path + '/A_matrices%d.pdf' % i_cond)
            plt.close(fig)


            # Column sums / row sums of the binary adjacency, shown as images.
            deg_out = (conditioner.soft_thresholded_A() > 0.).sum(0).cpu().numpy()
            deg_in = (conditioner.soft_thresholded_A() > 0.).sum(1).cpu().numpy()
            import matplotlib as mpl
            label_size = 20
            mpl.rcParams['xtick.labelsize'] = label_size
            mpl.rcParams['ytick.labelsize'] = label_size
            fig, ax = plt.subplots(1, 2, figsize=(12, 6))
            if dataset == "MNIST":
                shape = (int(in_s**.5), int(in_s**.5))
            elif dataset == "CIFAR10":
                shape = (3, 32, 32)
            res0 = ax[0].matshow(deg_in.reshape(shape))
            ax[0].set_xlabel("\n(a)", fontsize=20)
            #fig.colorbar(res0, ax=ax[0])
            res1 = ax[1].matshow(deg_out.reshape(shape))
            ax[1].set_xlabel("\n(b)", fontsize=20)
            fig.colorbar(res1, ax=ax[:], shrink=0.75)
            plt.savefig(path + '/A_degrees_test%d.pdf' % i_cond)

    # Sample images at several temperatures T (scale of the latent noise).
    with torch.no_grad():
        n_images = 5
        # NOTE(review): 784 and the 28x28 view below are MNIST sizes, also used when dataset is CIFAR10.
        in_s = 784
        images = []
        for T in [0.9, 1., 1.05, 1.1, 1.15]:
            z = torch.randn(n_images, in_s).to(device=master_device) * T
            x = model.module.invert(z)
            images += [x.view(n_images, 1, 28, 28)]
            # Reconstruction error of the inversion.
            print((z - model(x)[0]).abs().mean())
        grid_img = torchvision.utils.make_grid(torch.cat(images, 0), nrow=n_images)
        # T still holds the last temperature of the loop here.
        torchvision.utils.save_image(grid_img, path + '/images_test_%f.png' % T)
import argparse

# Command-line entry point: parse the options, create the output folder and run test().
parser = argparse.ArgumentParser(description='')
parser.add_argument("-load", default=False, action="store_true", help="Load a model ?")
parser.add_argument("-folder", default="", help="Folder")
parser.add_argument("-nb_steps_dual", default=100, type=int,
                    help="number of step between updating Acyclicity constraint and sparsity constraint")
parser.add_argument("-l1", default=10., type=float, help="Maximum weight for l1 regularization")
parser.add_argument("-nb_epoch", default=10000, type=int, help="Number of epochs")
parser.add_argument("-b_size", default=1, type=int, help="Batch size")
parser.add_argument("-int_net", default=[50, 50, 50], nargs="+", type=int, help="NN hidden layers of UMNN")
parser.add_argument("-nb_steps", default=20, type=int, help="Number of integration steps.")
# The value is appended (after an underscore) to the checkpoint file names that are loaded.
parser.add_argument("-f_number", default=None, type=str,
                    help="Suffix of the checkpoint files to load (model_<f_number>.pt).")
parser.add_argument("-solver", default="CC", type=str, help="Which integral solver to use.",
                    choices=["CC", "CCParallel"])
parser.add_argument("-nb_flow", default=[1], nargs="+", type=int, help="Number of steps in the flow.")
parser.add_argument("-test", default=False, action="store_true")
parser.add_argument("-weight_decay", default=1e-5, type=float, help="Weight decay value")
parser.add_argument("-learning_rate", default=1e-3, type=float, help="Learning rate value")
parser.add_argument("-batch_per_optim_step", default=1, type=int, help="Number of batch to accumulate")
parser.add_argument("-nb_gpus", default=1, type=int, help="Number of gpus to train on")
parser.add_argument("-dataset", default="MNIST", type=str, choices=["MNIST", "CIFAR10", "MNIST1"])
parser.add_argument("-normalizer", default="Affine", type=str, choices=["Affine", "Monotonic"])
parser.add_argument("-no_hot_encoding", default=False, action="store_true")
parser.add_argument("-prior_A_kernel", default=None, type=int)

parser.add_argument("-conditioner", default='DAG', choices=['DAG', 'Coupling', 'Autoregressive'], type=str)
parser.add_argument("-emb_net", default=[100, 100, 100, 10], nargs="+", type=int, help="NN layers of embedding")

args = parser.parse_args()
from datetime import datetime
now = datetime.now()

# Default output folder: <dataset>/<timestamp>, unless -folder is given.
path = args.dataset + "/" + now.strftime("%m_%d_%Y_%H_%M_%S") if args.folder == "" else args.folder
# exist_ok avoids the check-then-create race of isdir() + makedirs().
os.makedirs(path, exist_ok=True)
test(dataset=args.dataset, load=args.load, path=path, nb_step_dual=args.nb_steps_dual, l1=args.l1, nb_epoch=args.nb_epoch,
     int_net=args.int_net, b_size=args.b_size, all_args=args, nb_flow=args.nb_flow,
     nb_steps=args.nb_steps, file_number=args.f_number, norm_type=args.normalizer,
     solver=args.solver, train=not args.test, weight_decay=args.weight_decay, learning_rate=args.learning_rate,
     batch_per_optim_step=args.batch_per_optim_step, n_gpu=args.nb_gpus, hot_encoding=not args.no_hot_encoding,
     prior_A_kernel=args.prior_A_kernel, conditioner=args.conditioner, emb_net=args.emb_net)
timeit import default_timer as timer 3 | import lib.utils as utils 4 | import os 5 | import matplotlib 6 | import matplotlib.pyplot as plt 7 | import networkx as nx 8 | from torchvision import datasets, transforms 9 | from lib.transform import AddUniformNoise, ToTensor, HorizontalFlip, Transpose, Resize 10 | import numpy as np 11 | import math 12 | import torch.nn as nn 13 | from UMNN import UMNNMAFFlow 14 | from models.NormalizingFlowFactories import buildMNISTNormalizingFlow, buildCIFAR10NormalizingFlow, buildFCNormalizingFlow 15 | from models.Normalizers import AffineNormalizer, MonotonicNormalizer 16 | from models.Conditionners import * 17 | import torchvision.datasets as dset 18 | import torchvision.transforms as tforms 19 | import matplotlib.animation as animation 20 | import matplotlib 21 | import torchvision 22 | 23 | def add_noise(x): 24 | """ 25 | [0, 1] -> [0, 255] -> add noise -> [0, 1] 26 | """ 27 | noise = x.new().resize_as_(x).uniform_() 28 | x = x * 255 + noise 29 | x = x / 256 30 | return x 31 | 32 | 33 | def compute_bpp(ll, x, alpha=1e-6): 34 | d = x.shape[1] 35 | bpp = -ll / (d * np.log(2)) - np.log2(1 - 2 * alpha) + 8 \ 36 | + 1 / d * (torch.log2(torch.sigmoid(x)) + torch.log2(1. 
- torch.sigmoid(x))).sum(1) 37 | return bpp 38 | 39 | 40 | def load_data(dataset="MNIST", batch_size=100, cuda=-1): 41 | if dataset == "MNIST": 42 | data = datasets.MNIST('./MNIST', train=True, download=True, 43 | transform=transforms.Compose([ 44 | AddUniformNoise(), 45 | ToTensor() 46 | ])) 47 | 48 | train_data, valid_data = torch.utils.data.random_split(data, [50000, 10000]) 49 | 50 | test_data = datasets.MNIST('./MNIST', train=False, download=True, 51 | transform=transforms.Compose([ 52 | AddUniformNoise(), 53 | ToTensor() 54 | ])) 55 | kwargs = {'num_workers': 0, 'pin_memory': True} if cuda > -1 else {} 56 | 57 | train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True, **kwargs) 58 | valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size, shuffle=True, drop_last=True, **kwargs) 59 | test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=True, drop_last=True, **kwargs) 60 | elif len(dataset) == 6 and dataset[:5] == 'MNIST': 61 | data = datasets.MNIST('./MNIST', train=True, download=True, 62 | transform=transforms.Compose([ 63 | AddUniformNoise(), 64 | ToTensor() 65 | ])) 66 | label = int(dataset[5]) 67 | idx = data.train_labels == label 68 | data.targets = data.train_labels[idx] 69 | data.data = data.train_data[idx] 70 | 71 | train_data, valid_data = torch.utils.data.random_split(data, [5000, idx.sum() - 5000]) 72 | 73 | test_data = datasets.MNIST('./MNIST', train=False, download=True, 74 | transform=transforms.Compose([ 75 | AddUniformNoise(), 76 | ToTensor() 77 | ])) 78 | idx = test_data.test_labels == label 79 | test_data.targets = test_data.test_labels[idx] 80 | test_data.data = test_data.test_data[idx] 81 | 82 | kwargs = {'num_workers': 0, 'pin_memory': True} if cuda > -1 else {} 83 | 84 | train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True, 85 | **kwargs) 86 | valid_loader = 
torch.utils.data.DataLoader(valid_data, batch_size=batch_size, shuffle=True, drop_last=True, 87 | **kwargs) 88 | test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=True, drop_last=True, 89 | **kwargs) 90 | elif dataset == "CIFAR10": 91 | im_dim = 3 92 | im_size = 32 # if args.imagesize is None else args.imagesize 93 | trans = lambda im_size: tforms.Compose([tforms.Resize(im_size), tforms.ToTensor(), add_noise]) 94 | train_data = dset.CIFAR10( 95 | root="./data", train=True, transform=tforms.Compose([ 96 | tforms.Resize(im_size), 97 | tforms.RandomHorizontalFlip(), 98 | tforms.ToTensor(), 99 | add_noise, 100 | ]), download=True 101 | ) 102 | test_data = dset.CIFAR10(root="./data", train=False, transform=trans(im_size), download=True) 103 | kwargs = {'num_workers': 0, 'pin_memory': True} if cuda > -1 else {} 104 | 105 | train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, drop_last=True, shuffle=True, **kwargs) 106 | # WARNING VALID = TEST 107 | valid_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, drop_last=True, shuffle=True, **kwargs) 108 | test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, drop_last=True, shuffle=True, **kwargs) 109 | return train_loader, valid_loader, test_loader 110 | 111 | 112 | cond_types = {"DAG": DAGConditioner, "Coupling": CouplingConditioner, "Autoregressive": AutoregressiveConditioner} 113 | 114 | 115 | def train(dataset="MNIST", load=True, nb_step_dual=100, nb_steps=20, path="", l1=.1, nb_epoch=10000, b_size=100, 116 | int_net=[50, 50, 50], all_args=None, file_number=None, train=True, solver="CC", weight_decay=1e-5, 117 | learning_rate=1e-3, batch_per_optim_step=1, n_gpu=1, norm_type='Affine', nb_flow=[1], hot_encoding=True, 118 | prior_A_kernel=None, conditioner="DAG", emb_net=None): 119 | logger = utils.get_logger(logpath=os.path.join(path, 'logs'), filepath=os.path.abspath(__file__)) 120 | logger.info(str(all_args)) 121 | 122 | 
123 | if load: 124 | file_number = "_" + file_number if file_number is not None else "" 125 | 126 | batch_size = b_size 127 | best_valid_loss = np.inf 128 | 129 | logger.info("Loading data...") 130 | train_loader, valid_loader, test_loader = load_data(dataset, batch_size) 131 | if len(dataset) == 6 and dataset[:5] == 'MNIST': 132 | dataset = "MNIST" 133 | alpha = 1e-6 if dataset == "MNIST" else .05 134 | 135 | logger.info("Data loaded.") 136 | 137 | master_device = "cuda:0" if torch.cuda.is_available() else "cpu" 138 | 139 | # ----------------------- Model Definition ------------------- # 140 | logger.info("Creating model...") 141 | if norm_type == 'Affine': 142 | normalizer_type = AffineNormalizer 143 | normalizer_args = {} 144 | else: 145 | normalizer_type = MonotonicNormalizer 146 | normalizer_args = {"integrand_net": int_net, "nb_steps": 15, "solver": solver} 147 | 148 | if conditioner == "DAG": 149 | conditioner_type = DAGConditioner 150 | if dataset == "MNIST": 151 | inner_model = buildMNISTNormalizingFlow(nb_flow, normalizer_type, normalizer_args, l1, 152 | nb_epoch_update=nb_step_dual, hot_encoding=hot_encoding, 153 | prior_kernel=prior_A_kernel) 154 | elif dataset == "CIFAR10": 155 | inner_model = buildCIFAR10NormalizingFlow(nb_flow, normalizer_type, normalizer_args, l1, 156 | nb_epoch_update=nb_step_dual, hot_encoding=hot_encoding) 157 | else: 158 | logger.info("Wrong dataset name. 
Training aborted.") 159 | exit() 160 | else: 161 | dim = 28**2 if dataset == "MNIST" else 32*32*3 162 | conditioner_type = cond_types[conditioner] 163 | conditioner_args = {"in_size": dim, "hidden": emb_net[:-1], "out_size": emb_net[-1]} 164 | if norm_type == 'Monotonic': 165 | normalizer_args["cond_size"] = emb_net[-1] 166 | 167 | inner_model = buildFCNormalizingFlow(nb_flow[0], conditioner_type, conditioner_args, normalizer_type, normalizer_args) 168 | model = nn.DataParallel(inner_model, device_ids=list(range(n_gpu))).to(master_device) 169 | logger.info(str(model)) 170 | pytorch_total_params = sum(p.numel() for p in model.parameters()) 171 | logger.info("Number of parameters: %d" % pytorch_total_params) 172 | 173 | opt = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay) 174 | 175 | if load: 176 | logger.info("Loading model...") 177 | model.load_state_dict(torch.load(path + '/model%s.pt' % file_number, map_location={"cuda:0": master_device})) 178 | model.train() 179 | opt.load_state_dict(torch.load(path + '/ADAM%s.pt' % file_number, map_location={"cuda:0": master_device})) 180 | if master_device != "cpu": 181 | for state in opt.state.values(): 182 | for k, v in state.items(): 183 | if isinstance(v, torch.Tensor): 184 | state[k] = v.cuda() 185 | logger.info("...Model built.") 186 | logger.info("Training starts:") 187 | 188 | if load: 189 | for conditioner in model.module.getConditioners(): 190 | conditioner.alpha = conditioner.getAlpha() 191 | 192 | # ----------------------- Main Loop ------------------------- # 193 | for epoch in range(nb_epoch): 194 | ll_tot = 0 195 | start = timer() 196 | if train: 197 | model.to(master_device) 198 | # ----------------------- Training Loop ------------------------- # 199 | for batch_idx, (cur_x, target) in enumerate(train_loader): 200 | cur_x = cur_x.view(batch_size, -1).float().to(master_device) 201 | for normalizer in model.module.getNormalizers(): 202 | if type(normalizer) is 
MonotonicNormalizer: 203 | normalizer.nb_steps = nb_steps + torch.randint(0, 10, [1])[0].item() 204 | z, jac = model(cur_x) 205 | loss = model.module.loss(z, jac)/(batch_per_optim_step * n_gpu) 206 | if math.isnan(loss.item()): 207 | print("Error Nan in loss") 208 | print("Dagness:", model.module.DAGness()) 209 | exit() 210 | ll_tot += loss.detach() 211 | if batch_idx % batch_per_optim_step == 0: 212 | opt.zero_grad() 213 | 214 | loss.backward(retain_graph=True) 215 | if (batch_idx + 1) % batch_per_optim_step == 0: 216 | opt.step() 217 | 218 | with torch.no_grad(): 219 | print("Dagness:", model.module.DAGness()) 220 | 221 | ll_tot /= (batch_idx + 1) 222 | torch.cuda.empty_cache() 223 | model.module.step(epoch, ll_tot) 224 | 225 | else: 226 | ll_tot = 0. 227 | 228 | # ----------------------- Valid Loop ------------------------- # 229 | ll_test = 0. 230 | bpp_test = 0. 231 | model.to(master_device) 232 | with torch.no_grad(): 233 | for normalizer in model.module.getNormalizers(): 234 | if type(normalizer) is MonotonicNormalizer: 235 | normalizer.nb_steps = 150 236 | for batch_idx, (cur_x, target) in enumerate(valid_loader): 237 | cur_x = cur_x.view(batch_size, -1).float().to(master_device) 238 | z, jac = model(cur_x) 239 | ll = (model.module.z_log_density(z) + jac) 240 | ll_test += ll.mean().item() 241 | bpp_test += compute_bpp(ll, cur_x.view(batch_size, -1).float().to(master_device), alpha).mean().item() 242 | ll_test /= batch_idx + 1 243 | bpp_test /= batch_idx + 1 244 | end = timer() 245 | 246 | dagness = max(model.module.DAGness()) 247 | logger.info( 248 | "epoch: {:d} - Train loss: {:4f} - Valid log-likelihood: {:4f} - Valid BPP {:4f} - <>: {:4f} " 249 | "- Elapsed time per epoch {:4f} (seconds)".format(epoch, ll_tot, ll_test, bpp_test, dagness, end - start)) 250 | if model.module.isInvertible() and -ll_test < best_valid_loss: 251 | logger.info("------- New best validation loss --------") 252 | torch.save(model.state_dict(), path + '/best_model.pt') 253 | 
best_valid_loss = -ll_test 254 | # Valid loop 255 | ll_test = 0. 256 | for batch_idx, (cur_x, target) in enumerate(test_loader): 257 | z, jac = model(cur_x.view(batch_size, -1).float().to(master_device)) 258 | ll = (model.module.z_log_density(z) + jac) 259 | ll_test += ll.mean().item() 260 | bpp_test += compute_bpp(ll, cur_x.view(batch_size, -1).float().to(master_device), alpha).mean().item() 261 | 262 | ll_test /= batch_idx + 1 263 | bpp_test /= batch_idx + 1 264 | logger.info("epoch: {:d} - Test log-likelihood: {:4f} - Test BPP {:4f} - <>: {:4f}". 265 | format(epoch, ll_test, bpp_test, dagness)) 266 | if epoch % 10 == 0 and conditioner_type is DAGConditioner: 267 | stoch_gate, noise_gate, s_thresh = [], [], [] 268 | for conditioner in model.module.getConditioners(): 269 | stoch_gate.append(conditioner.stoch_gate) 270 | noise_gate.append(conditioner.noise_gate) 271 | s_thresh.append(conditioner.s_thresh) 272 | conditioner.stoch_gate = False 273 | conditioner.noise_gate = False 274 | conditioner.s_thresh = True 275 | for threshold in [.95, .5, .1, .01, .0001]: 276 | for conditioner in model.module.getConditioners(): 277 | conditioner.h_thresh = threshold 278 | # Valid loop 279 | ll_test = 0. 280 | bpp_test = 0. 281 | for batch_idx, (cur_x, target) in enumerate(valid_loader): 282 | cur_x = cur_x.view(batch_size, -1).float().to(master_device) 283 | z, jac = model(cur_x) 284 | ll = (model.module.z_log_density(z) + jac) 285 | ll_test += ll.mean().item() 286 | bpp_test += compute_bpp(ll, cur_x.view(batch_size, -1).float().to(master_device), alpha).mean().item() 287 | ll_test /= batch_idx + 1 288 | bpp_test /= batch_idx + 1 289 | dagness = max(model.module.DAGness()) 290 | logger.info("epoch: {:d} - Threshold: {:4f} - Valid log-likelihood: {:4f} - Valid BPP {:4f} - <>: {:4f}". 291 | format(epoch, threshold, ll_test, bpp_test, dagness)) 292 | for i, conditioner in enumerate(model.module.getConditioners()): 293 | conditioner.h_thresh = 0. 
294 | conditioner.stoch_gate = stoch_gate[i] 295 | conditioner.noise_gate = noise_gate[i] 296 | conditioner.s_thresh = s_thresh[i] 297 | 298 | 299 | 300 | 301 | in_s = 784 if dataset == "MNIST" else 3*32*32 302 | a_tmp = model.module.getConditioners()[0].soft_thresholded_A()[0, :] 303 | a_tmp = a_tmp.view(28, 28).cpu().numpy() if dataset == "MNIST" else a_tmp.view(3, 32, 32).cpu().numpy() 304 | fig, ax = plt.subplots() 305 | mat = ax.matshow(a_tmp) 306 | plt.colorbar(mat) 307 | current_cmap = matplotlib.cm.get_cmap() 308 | current_cmap.set_bad(color='red') 309 | mat.set_clim(0, 1.) 310 | def update(i): 311 | A = model.module.getConditioners()[0].soft_thresholded_A()[i, :].cpu().numpy() 312 | A[i] = np.nan 313 | if dataset == "MNIST": 314 | A = A.reshape(28, 28) 315 | elif dataset == "CIFAR10": 316 | A = A.reshape(3, 32, 32) 317 | mat.set_data(A) 318 | return mat 319 | 320 | # Set up formatting for the movie files 321 | Writer = animation.writers['ffmpeg'] 322 | writer = Writer(fps=15, metadata=dict(artist='Me'), bitrate=1800) 323 | ani = animation.FuncAnimation(fig, update, range(in_s), interval=100, save_count=0) 324 | ani.save(path + '/A_epoch_%d.mp4' % epoch, writer=writer) 325 | 326 | deg_out = (model.module.getConditioners()[0].soft_thresholded_A() > 0.).sum(0).cpu().numpy() 327 | deg_in = (model.module.getConditioners()[0].soft_thresholded_A() > 0.).sum(1).cpu().numpy() 328 | fig, ax = plt.subplots(1, 2, figsize=(12, 6)) 329 | if dataset == "MNIST": 330 | shape = (28, 28) 331 | elif dataset == "CIFAR10": 332 | shape = (3, 32, 32) 333 | res0 = ax[0].matshow(np.log(deg_in).reshape(shape)) 334 | ax[0].set(title="In degrees") 335 | fig.colorbar(res0, ax=ax[0]) 336 | res1 = ax[1].matshow(np.log(deg_out.reshape(shape))) 337 | ax[1].set(title="Out degrees") 338 | fig.colorbar(res1, ax=ax[1]) 339 | plt.savefig(path + '/A_degrees_epoch_%d.png' % epoch) 340 | 341 | if model.module.isInvertible(): 342 | with torch.no_grad(): 343 | n_images = 16 344 | in_s = 28**2 345 | 
for T in [.1, .25, .5, .75, 1.]: 346 | z = torch.randn(n_images, in_s).to(device=master_device) * T 347 | x = model.module.invert(z) 348 | print((z - model(x)[0]).abs().mean()) 349 | grid_img = torchvision.utils.make_grid(x.view(n_images, 1, 28, 28), nrow=4) 350 | torchvision.utils.save_image(grid_img, path + '/images_%d_%f.png' % (epoch, T)) 351 | 352 | if epoch % nb_step_dual == 0: 353 | logger.info("Saving model N°%d" % epoch) 354 | torch.save(model.state_dict(), path + '/model_%d.pt' % epoch) 355 | torch.save(opt.state_dict(), path + '/ADAM_%d.pt' % epoch) 356 | 357 | torch.save(model.state_dict(), path + '/model.pt') 358 | torch.save(opt.state_dict(), path + '/ADAM.pt') 359 | torch.cuda.empty_cache() 360 | 361 | import argparse 362 | 363 | parser = argparse.ArgumentParser(description='') 364 | parser.add_argument("-load", default=False, action="store_true", help="Load a model ?") 365 | parser.add_argument("-folder", default="", help="Folder") 366 | parser.add_argument("-nb_steps_dual", default=100, type=int, 367 | help="number of step between updating Acyclicity constraint and sparsity constraint") 368 | parser.add_argument("-l1", default=10., type=float, help="Maximum weight for l1 regularization") 369 | parser.add_argument("-nb_epoch", default=10000, type=int, help="Number of epochs") 370 | parser.add_argument("-b_size", default=1, type=int, help="Batch size") 371 | parser.add_argument("-int_net", default=[50, 50, 50], nargs="+", type=int, help="NN hidden layers of UMNN") 372 | parser.add_argument("-nb_steps", default=20, type=int, help="Number of integration steps.") 373 | parser.add_argument("-f_number", default=None, type=str, help="Number of heating steps.") 374 | parser.add_argument("-solver", default="CC", type=str, help="Which integral solver to use.", 375 | choices=["CC", "CCParallel"]) 376 | parser.add_argument("-nb_flow", default=[1], nargs="+", type=int, help="Number of steps in the flow.") 377 | parser.add_argument("-test", default=False, 
action="store_true") 378 | parser.add_argument("-weight_decay", default=1e-5, type=float, help="Weight decay value") 379 | parser.add_argument("-learning_rate", default=1e-3, type=float, help="Weight decay value") 380 | parser.add_argument("-batch_per_optim_step", default=1, type=int, help="Number of batch to accumulate") 381 | parser.add_argument("-nb_gpus", default=1, type=int, help="Number of gpus to train on") 382 | parser.add_argument("-dataset", default="MNIST", type=str, choices=["MNIST", "CIFAR10", "MNIST1"]) 383 | parser.add_argument("-normalizer", default="Affine", type=str, choices=["Affine", "Monotonic"]) 384 | parser.add_argument("-no_hot_encoding", default=False, action="store_true") 385 | parser.add_argument("-prior_A_kernel", default=None, type=int) 386 | 387 | parser.add_argument("-conditioner", default='DAG', choices=['DAG', 'Coupling', 'Autoregressive'], type=str) 388 | parser.add_argument("-emb_net", default=[100, 100, 100, 10], nargs="+", type=int, help="NN layers of embedding") 389 | 390 | args = parser.parse_args() 391 | from datetime import datetime 392 | now = datetime.now() 393 | 394 | path = args.dataset + "/" + now.strftime("%m_%d_%Y_%H_%M_%S") if args.folder == "" else args.folder 395 | if not (os.path.isdir(path)): 396 | os.makedirs(path) 397 | train(dataset=args.dataset, load=args.load, path=path, nb_step_dual=args.nb_steps_dual, l1=args.l1, nb_epoch=args.nb_epoch, 398 | int_net=args.int_net, b_size=args.b_size, all_args=args, nb_flow=args.nb_flow, 399 | nb_steps=args.nb_steps, file_number=args.f_number, norm_type=args.normalizer, 400 | solver=args.solver, train=not args.test, weight_decay=args.weight_decay, learning_rate=args.learning_rate, 401 | batch_per_optim_step=args.batch_per_optim_step, n_gpu=args.nb_gpus, hot_encoding=not args.no_hot_encoding, 402 | prior_A_kernel=args.prior_A_kernel, conditioner=args.conditioner, emb_net=args.emb_net) 403 | --------------------------------------------------------------------------------