├── __init__.py ├── src ├── __init__.py ├── flows │ ├── models │ │ ├── __init__.py │ │ ├── glow │ │ │ ├── __init__.py │ │ │ ├── actnorm.py │ │ │ ├── invconv.py │ │ │ ├── coupling.py │ │ │ ├── layers.py │ │ │ └── glow.py │ │ ├── flowplusplus │ │ │ ├── __init__.py │ │ │ ├── inv_conv.py │ │ │ ├── coupling.py │ │ │ ├── log_dist.py │ │ │ └── act_norm.py │ │ ├── maf │ │ │ ├── __init__.py │ │ │ ├── realnvp.py │ │ │ ├── util.py │ │ │ └── layers.py │ │ └── ema.py │ ├── trainers │ │ ├── __init__.py │ │ ├── glow1 │ │ │ ├── __init__.py │ │ │ ├── learning_rate_schedule.py │ │ │ ├── thops.py │ │ │ ├── config.py │ │ │ └── builder.py │ │ ├── vision │ │ │ ├── datasets │ │ │ │ ├── __init__.py │ │ │ │ └── celeba.py │ │ │ └── __init__.py │ │ ├── hparams │ │ │ └── celeba.json │ │ └── glow1.py │ ├── results │ │ ├── uci_wine_quality_maf │ │ │ ├── model_checkpoint.pt │ │ │ ├── best_model_checkpoint.pt │ │ │ ├── logs │ │ │ │ └── stdout.txt │ │ │ ├── config.yaml │ │ │ └── results.txt │ │ ├── uci_breast_cancer_maf │ │ │ ├── model_checkpoint.pt │ │ │ ├── best_model_checkpoint.pt │ │ │ ├── logs │ │ │ │ └── stdout.txt │ │ │ ├── config.yaml │ │ │ └── results.txt │ │ ├── kmm_kliep_synthetic_maf │ │ │ ├── model_checkpoint.pt │ │ │ ├── best_model_checkpoint.pt │ │ │ ├── logs │ │ │ │ └── stdout.txt │ │ │ ├── config.yaml │ │ │ └── results.txt │ │ └── uci_blood_transfusion_maf │ │ │ ├── model_checkpoint.pt │ │ │ ├── best_model_checkpoint.pt │ │ │ ├── logs │ │ │ └── stdout.txt │ │ │ ├── config.yaml │ │ │ └── results.txt │ ├── functions │ │ ├── __init__.py │ │ ├── utils.py │ │ └── ckpt_util.py │ └── utils.py ├── classification │ ├── __init__.py │ ├── models │ │ ├── __init__.py │ │ ├── tre.py │ │ ├── flow_mlp.py │ │ ├── omni_classifier.py │ │ ├── mlp.py │ │ ├── cnn.py │ │ └── networks.py │ ├── trainers │ │ ├── __init__.py │ │ └── base.py │ ├── utils.py │ └── main.py ├── datasets │ ├── __init__.py │ ├── looping.py │ ├── kmm.py │ ├── util.py │ ├── vision.py │ ├── mnist.py │ ├── dataset_utils.py │ └── omniglot.py ├── util │ ├── __init__.py │ ├── shell_util.py │ ├── optim_util.py │ ├── norm_util.py │ └── array_util.py ├── requirements.txt ├── configs │ ├── classification │ │ ├── mnist │ │ │ ├── diff_bkgd │ │ │ │ ├── attr_bkgd.yaml │ │ │ │ ├── encodings0.1.yaml │ │ │ │ ├── encodings0.25.yaml │ │ │ │ ├── encodings0.5.yaml │ │ │ │ └── encodings1.0.yaml │ │ │ └── diff_digits │ │ │ │ ├── encodings0.1.yaml │ │ │ │ ├── encodings0.5.yaml │ │ │ │ ├── encodings1.0.yaml │ │ │ │ ├── encodings0.25.yaml │ │ │ │ ├── attr_digits.yaml │ │ │ │ ├── encodings0.1_mlp.yaml │ │ │ │ ├── encodings0.5_mlp.yaml │ │ │ │ ├── encodings1.0_mlp.yaml │ │ │ │ └── encodings0.25_mlp.yaml │ │ ├── gmm │ │ │ ├── mlp_z.yaml │ │ │ ├── flow_mlp_z.yaml │ │ │ ├── mlp_x.yaml │ │ │ └── joint_flow_mlp_z.yaml │ │ ├── mi │ │ │ ├── mlp_z.yaml │ │ │ ├── mlp_x.yaml │ │ │ ├── disc_flow_z.yaml │ │ │ └── joint_flow_z.yaml │ │ └── omniglot │ │ │ ├── x_dre.yaml │ │ │ ├── z_dre.yaml │ │ │ ├── baseline.yaml │ │ │ ├── x_method.yaml │ │ │ ├── z_method.yaml │ │ │ ├── gen_baseline.yaml │ │ │ └── mix_baseline.yaml │ └── flows │ │ ├── gmm │ │ └── maf.yaml │ │ ├── mnist │ │ ├── diff_bkgd │ │ │ ├── perc0.1.yaml │ │ │ ├── perc0.5.yaml │ │ │ ├── perc0.25.yaml │ │ │ └── perc1.0.yaml │ │ └── diff_digits │ │ │ ├── perc0.1.yaml │ │ │ ├── perc0.25.yaml │ │ │ ├── perc0.5.yaml │ │ │ └── perc1.0.yaml │ │ ├── uci_breast_cancer │ │ └── maf.yaml │ │ ├── uci_blood_transfusion │ │ └── maf.yaml │ │ ├── uci_wine_quality │ │ └── maf.yaml │ │ ├── mi │ │ └── mi_maf.yaml │ │ └── omniglot │ │ └── omniglot_maf.yaml └── 
losses.py ├── init_env.sh ├── .gitignore ├── data ├── 2d_gaussians │ ├── X.npy │ ├── Z.npy │ ├── u.npy │ ├── y.npy │ ├── Z_test.npy │ └── u_test.npy ├── uci_wine_quality │ ├── test.npz │ └── train.npz ├── uci_breast_cancer │ ├── test.npz │ └── train.npz └── uci_blood_transfusion │ ├── test.npz │ └── train.npz ├── scripts └── create_yaml_files.py └── README.md /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/flows/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/classification/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/flows/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/classification/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/flows/trainers/glow1/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /init_env.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=${PYTHONPATH}:$(pwd) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | wandb/ 2 | data/encodings 3 | data/mnist 4 | -------------------------------------------------------------------------------- /src/flows/models/glow/__init__.py: -------------------------------------------------------------------------------- 1 | from models.glow.glow import Glow 2 | -------------------------------------------------------------------------------- /src/flows/trainers/vision/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .celeba import CelebADataset 2 | 3 | -------------------------------------------------------------------------------- /data/2d_gaussians/X.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ermongroup/f-dre/HEAD/data/2d_gaussians/X.npy -------------------------------------------------------------------------------- /data/2d_gaussians/Z.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ermongroup/f-dre/HEAD/data/2d_gaussians/Z.npy -------------------------------------------------------------------------------- /data/2d_gaussians/u.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ermongroup/f-dre/HEAD/data/2d_gaussians/u.npy -------------------------------------------------------------------------------- /data/2d_gaussians/y.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ermongroup/f-dre/HEAD/data/2d_gaussians/y.npy -------------------------------------------------------------------------------- /src/flows/models/flowplusplus/__init__.py: -------------------------------------------------------------------------------- 1 | from models.flowplusplus.flowplusplus import FlowPlusPlus 2 | -------------------------------------------------------------------------------- /src/flows/models/maf/__init__.py: -------------------------------------------------------------------------------- 1 | from .layers import * 2 | from .maf import * 3 | from .realnvp import * -------------------------------------------------------------------------------- /data/2d_gaussians/Z_test.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ermongroup/f-dre/HEAD/data/2d_gaussians/Z_test.npy -------------------------------------------------------------------------------- /data/2d_gaussians/u_test.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ermongroup/f-dre/HEAD/data/2d_gaussians/u_test.npy -------------------------------------------------------------------------------- /data/uci_wine_quality/test.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ermongroup/f-dre/HEAD/data/uci_wine_quality/test.npz -------------------------------------------------------------------------------- /data/uci_breast_cancer/test.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ermongroup/f-dre/HEAD/data/uci_breast_cancer/test.npz -------------------------------------------------------------------------------- /data/uci_breast_cancer/train.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ermongroup/f-dre/HEAD/data/uci_breast_cancer/train.npz -------------------------------------------------------------------------------- /data/uci_wine_quality/train.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ermongroup/f-dre/HEAD/data/uci_wine_quality/train.npz -------------------------------------------------------------------------------- /data/uci_blood_transfusion/test.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ermongroup/f-dre/HEAD/data/uci_blood_transfusion/test.npz -------------------------------------------------------------------------------- /data/uci_blood_transfusion/train.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ermongroup/f-dre/HEAD/data/uci_blood_transfusion/train.npz -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import getpass 2 | 3 | root = '/atlas/u/{}/multi-fairgen/data'.format(getpass.getuser()) 4 | 5 | 6 | -------------------------------------------------------------------------------- /src/flows/trainers/vision/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import CelebADataset 2 | 3 | Datasets = { 4 | "celeba": 
CelebADataset 5 | } 6 | -------------------------------------------------------------------------------- /src/util/__init__.py: -------------------------------------------------------------------------------- 1 | from util.array_util import * 2 | from util.norm_util import * 3 | from util.optim_util import * 4 | from util.shell_util import * 5 | -------------------------------------------------------------------------------- /src/flows/results/uci_wine_quality_maf/model_checkpoint.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ermongroup/f-dre/HEAD/src/flows/results/uci_wine_quality_maf/model_checkpoint.pt -------------------------------------------------------------------------------- /src/flows/results/uci_breast_cancer_maf/model_checkpoint.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ermongroup/f-dre/HEAD/src/flows/results/uci_breast_cancer_maf/model_checkpoint.pt -------------------------------------------------------------------------------- /src/flows/results/kmm_kliep_synthetic_maf/model_checkpoint.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ermongroup/f-dre/HEAD/src/flows/results/kmm_kliep_synthetic_maf/model_checkpoint.pt -------------------------------------------------------------------------------- /src/flows/results/uci_blood_transfusion_maf/model_checkpoint.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ermongroup/f-dre/HEAD/src/flows/results/uci_blood_transfusion_maf/model_checkpoint.pt -------------------------------------------------------------------------------- /src/flows/results/uci_wine_quality_maf/best_model_checkpoint.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ermongroup/f-dre/HEAD/src/flows/results/uci_wine_quality_maf/best_model_checkpoint.pt -------------------------------------------------------------------------------- /src/flows/results/uci_breast_cancer_maf/best_model_checkpoint.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ermongroup/f-dre/HEAD/src/flows/results/uci_breast_cancer_maf/best_model_checkpoint.pt -------------------------------------------------------------------------------- /src/flows/results/kmm_kliep_synthetic_maf/best_model_checkpoint.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ermongroup/f-dre/HEAD/src/flows/results/kmm_kliep_synthetic_maf/best_model_checkpoint.pt -------------------------------------------------------------------------------- /src/flows/results/uci_blood_transfusion_maf/best_model_checkpoint.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ermongroup/f-dre/HEAD/src/flows/results/uci_blood_transfusion_maf/best_model_checkpoint.pt -------------------------------------------------------------------------------- /src/classification/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from classification.trainers.base import * 2 | from classification.trainers.attr_classifier import AttrClassifier 3 | from classification.trainers.classifier import Classifier 
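A small usage sketch for the Datasets registry in src/flows/trainers/vision/__init__.py above. Only "celeba" is registered; the lookup-and-error pattern and variable names below are illustrative assumptions, not code from this repo, and the import assumes src/flows is on PYTHONPATH (matching the repo's own models.glow-style absolute imports):

from trainers.vision import Datasets  # resolves to {"celeba": CelebADataset}

name = "celeba"  # e.g. taken from a config such as trainers/hparams/celeba.json
try:
    dataset_cls = Datasets[name]  # -> CelebADataset
except KeyError:
    raise ValueError(f"unknown dataset '{name}'; registered: {list(Datasets)}")
# constructor arguments are defined by CelebADataset in trainers/vision/datasets/celeba.py (not shown here)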
-------------------------------------------------------------------------------- /src/requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm==4.49.0 2 | scipy==1.4.1 3 | seaborn==0.8.1 4 | torchvision==0.8.2 5 | matplotlib==3.3.2 6 | numpy==1.18.1 7 | opencv_python==3.4.1.15 8 | requests==2.20.0 9 | torch==1.7.1 10 | docopt==0.6.2 11 | six==1.15.0 12 | dsets==0.0.1 13 | Pillow==8.3.0 14 | PyYAML==5.4.1 15 | scikit_learn==0.24.2 16 | tensorboardX==2.4 17 | -------------------------------------------------------------------------------- /src/flows/results/uci_breast_cancer_maf/logs/stdout.txt: -------------------------------------------------------------------------------- 1 | INFO - main.py - 2021-02-19 16:27:30,660 - Saving output to /atlas/u/madeline/multi-fairgen/src/flows/results/biased_3_redo_trial0 2 | INFO - main.py - 2021-02-19 16:27:30,661 - Writing log file to /atlas/u/madeline/multi-fairgen/src/flows/results/biased_3_redo_trial0/logs 3 | INFO - main.py - 2021-02-19 16:27:30,661 - Exp instance id = 398426 4 | -------------------------------------------------------------------------------- /src/flows/results/kmm_kliep_synthetic_maf/logs/stdout.txt: -------------------------------------------------------------------------------- 1 | INFO - main.py - 2021-02-14 14:13:06,842 - Saving output to /atlas/u/madeline/multi-fairgen/src/flows/results/kmm_synthetic_maf_trial_1 2 | INFO - main.py - 2021-02-14 14:13:06,843 - Writing log file to /atlas/u/madeline/multi-fairgen/src/flows/results/kmm_synthetic_maf_trial_1/logs 3 | INFO - main.py - 2021-02-14 14:13:06,843 - Exp instance id = 1256886 4 | -------------------------------------------------------------------------------- /src/flows/results/uci_blood_transfusion_maf/logs/stdout.txt: -------------------------------------------------------------------------------- 1 | INFO - main.py - 2021-04-20 10:40:10,156 - Saving output to /atlas/u/madeline/multi-fairgen/src/flows/results/biased_0.1_0.9_4-19-21_trial0 2 | INFO - main.py - 2021-04-20 10:40:10,156 - Writing log file to /atlas/u/madeline/multi-fairgen/src/flows/results/biased_0.1_0.9_4-19-21_trial0/logs 3 | INFO - main.py - 2021-04-20 10:40:10,156 - Exp instance id = 3936251 4 | -------------------------------------------------------------------------------- /src/flows/results/uci_wine_quality_maf/logs/stdout.txt: -------------------------------------------------------------------------------- 1 | INFO - main.py - 2021-04-20 10:51:43,331 - Saving output to /atlas/u/madeline/multi-fairgen/src/flows/results/wine_cov_shift_test_perc0.7_sigma0.05_trial0 2 | INFO - main.py - 2021-04-20 10:51:43,332 - Writing log file to /atlas/u/madeline/multi-fairgen/src/flows/results/wine_cov_shift_test_perc0.7_sigma0.05_trial0/logs 3 | INFO - main.py - 2021-04-20 10:51:43,332 - Exp instance id = 2834582 4 | -------------------------------------------------------------------------------- /src/datasets/looping.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import Dataset 3 | 4 | class LoopingDataset(Dataset): 5 | """ 6 | Dataset class to handle indices going out of bounds when training 7 | """ 8 | def __init__(self, dataset): 9 | self.dataset = dataset 10 | 11 | def __len__(self): 12 | return len(self.dataset) 13 | 14 | def __getitem__(self, index): 15 | if index >= len(self.dataset): 16 | index = np.random.choice(len(self.dataset)) 17 | item, label = self.dataset[index] 
18 | return item, label -------------------------------------------------------------------------------- /src/util/shell_util.py: -------------------------------------------------------------------------------- 1 | class AverageMeter(object): 2 | """Computes and stores the average and current value. 3 | 4 | Adapted from: https://github.com/pytorch/examples/blob/master/imagenet/train.py 5 | """ 6 | def __init__(self): 7 | self.val = 0. 8 | self.avg = 0. 9 | self.sum = 0. 10 | self.count = 0. 11 | 12 | def reset(self): 13 | self.val = 0. 14 | self.avg = 0. 15 | self.sum = 0. 16 | self.count = 0. 17 | 18 | def update(self, val, n=1): 19 | self.val = val 20 | self.sum += val * n 21 | self.count += n 22 | self.avg = self.sum / self.count 23 | -------------------------------------------------------------------------------- /src/configs/classification/mnist/diff_bkgd/attr_bkgd.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 128 3 | n_epochs: 10 4 | ngpu: 1 5 | iter_log: 1000 6 | iter_save: 100 7 | out_dir: "classification/results" 8 | data_dir: "../data" 9 | 10 | data: 11 | dataset: "BackgroundMNIST" 12 | perc: 1.0 13 | image_size: 28 14 | channels: 1 15 | random_flip: false 16 | num_workers: 4 17 | 18 | model: 19 | name: "mlp" 20 | spectral_norm: true 21 | batch_norm: true 22 | in_dim: 784 23 | h_dim: 100 24 | dropout: 0.1 25 | n_classes: 2 26 | 27 | optim: 28 | weight_decay: 0.000 29 | optimizer: "Adam" 30 | lr: 0.0002 31 | beta1: 0.9 32 | amsgrad: false 33 | 34 | loss: 35 | name: "bce" 36 | -------------------------------------------------------------------------------- /src/configs/classification/gmm/mlp_z.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 128 3 | n_epochs: 15 4 | ngpu: 1 5 | iter_log: 1000 6 | iter_save: 100 7 | exp_id: "gmm_clf_z" 8 | out_dir: "/path/to/f-dre/src/classification/results/" 9 | data_dir: "/path/to/f-dre/data/" 10 | 11 | data: 12 | dataset: "GMM" 13 | subset: false 14 | x_space: false 15 | input_size: 2 16 | perc: 1.0 17 | mus: [0, 3] 18 | class_idx: 20 19 | num_workers: 4 20 | 21 | model: 22 | name: "mlp" 23 | spectral_norm: true 24 | batch_norm: true 25 | in_dim: 2 26 | h_dim: 200 27 | dropout: 0.1 28 | n_classes: 2 29 | 30 | optim: 31 | weight_decay: 0.0005 32 | optimizer: "Adam" 33 | lr: 0.0002 34 | beta1: 0.9 35 | amsgrad: false 36 | 37 | loss: 38 | name: "bce" -------------------------------------------------------------------------------- /src/configs/classification/mi/mlp_z.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 128 3 | n_epochs: 200 4 | ngpu: 1 5 | iter_log: 1000 6 | iter_save: 100 7 | exp_id: "mi_z" 8 | out_dir: "/path/to/f-dre/src/classification/results/" 9 | data_dir: "/path/to/f-dre/data/" 10 | 11 | data: 12 | dataset: "MI" 13 | subset: false 14 | x_space: false 15 | input_size: 40 16 | perc: 1.0 17 | mus: [0, 3] 18 | class_idx: 20 19 | num_workers: 4 20 | rho: 0.9 21 | 22 | model: 23 | name: "mlp" 24 | spectral_norm: true 25 | batch_norm: true 26 | in_dim: 40 27 | h_dim: 200 28 | dropout: 0.1 29 | n_classes: 1 30 | 31 | optim: 32 | weight_decay: 0.0005 33 | optimizer: "Adam" 34 | lr: 0.0002 35 | beta1: 0.9 36 | amsgrad: false 37 | 38 | loss: 39 | name: "bce" -------------------------------------------------------------------------------- /src/configs/classification/gmm/flow_mlp_z.yaml: 
-------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 128 3 | n_epochs: 20 4 | ngpu: 1 5 | iter_log: 1000 6 | iter_save: 100 7 | exp_id: "gmm_flow_mlp_z" 8 | out_dir: "/path/to/f-dre/src/classification/results/" 9 | data_dir: "/path/to/f-dre/data/" 10 | 11 | data: 12 | dataset: "GMM" 13 | subset: false 14 | x_space: true 15 | input_size: 2 16 | perc: 1.0 17 | mus: [0, 3] 18 | class_idx: 20 19 | num_workers: 4 20 | 21 | model: 22 | name: "flow_mlp" 23 | spectral_norm: true 24 | batch_norm: true 25 | in_dim: 2 26 | h_dim: 200 27 | dropout: 0.1 28 | n_classes: 2 29 | 30 | optim: 31 | weight_decay: 0.0005 32 | optimizer: "Adam" 33 | lr: 0.0002 34 | beta1: 0.9 35 | amsgrad: false 36 | 37 | loss: 38 | name: "bce" -------------------------------------------------------------------------------- /src/configs/classification/gmm/mlp_x.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 128 3 | n_epochs: 15 4 | ngpu: 1 5 | iter_log: 1000 6 | iter_save: 100 7 | exp_id: "gmm_clf_x" 8 | out_dir: "/path/to/f-dre/src/classification/results/" 9 | data_dir: "/path/to/f-dre/data/" 10 | 11 | data: 12 | dataset: "GMM" 13 | subset: false 14 | x_space: true 15 | input_size: 2 16 | perc: 1.0 17 | mus: [0, 3] 18 | class_idx: 20 19 | num_workers: 4 20 | 21 | model: 22 | name: "mlp" 23 | spectral_norm: true 24 | batch_norm: true 25 | in_dim: 2 26 | h_dim: 200 27 | dropout: 0.1 28 | n_classes: 2 29 | 30 | optim: 31 | weight_decay: 0.000 32 | optimizer: "Adam" 33 | lr: 0.0002 34 | beta1: 0.9 35 | amsgrad: false 36 | 37 | loss: 38 | name: "bce" 39 | alpha: 100 -------------------------------------------------------------------------------- /src/flows/functions/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | 3 | 4 | def get_optimizer(config, parameters): 5 | if config.optim.optimizer == 'Adam': 6 | return optim.Adam(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay, 7 | betas=(config.optim.beta1, 0.999), amsgrad=config.optim.amsgrad, 8 | eps=config.optim.eps) 9 | elif config.optim.optimizer == 'RMSProp': 10 | return optim.RMSprop(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay) 11 | elif config.optim.optimizer == 'SGD': 12 | return optim.SGD(parameters, lr=config.optim.lr, momentum=0.9) 13 | else: 14 | raise NotImplementedError( 15 | 'Optimizer {} not understood.'.format(config.optim.optimizer)) 16 | -------------------------------------------------------------------------------- /src/configs/classification/mi/mlp_x.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 128 3 | n_epochs: 200 4 | ngpu: 1 5 | iter_log: 1000 6 | iter_save: 100 7 | exp_id: "mi_x" 8 | out_dir: "/path/to/f-dre/src/classification/results/" 9 | data_dir: "/path/to/f-dre/data/" 10 | 11 | data: 12 | dataset: "MI" 13 | subset: false 14 | x_space: true 15 | input_size: 40 16 | perc: 1.0 17 | mus: [0, 3] 18 | class_idx: 20 19 | num_workers: 4 20 | rho: 0.9 21 | 22 | model: 23 | name: "mlp" 24 | spectral_norm: true 25 | batch_norm: true 26 | in_dim: 40 27 | h_dim: 200 28 | dropout: 0.1 29 | n_classes: 1 30 | 31 | optim: 32 | weight_decay: 0.0005 33 | optimizer: "Adam" 34 | lr: 0.0002 35 | beta1: 0.9 36 | amsgrad: false 37 | 38 | loss: 39 | name: "bce" 40 | alpha: 0.01 -------------------------------------------------------------------------------- 
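A sketch of how the optim: blocks in this section's YAML configs map onto get_optimizer from src/flows/functions/__init__.py above. The SimpleNamespace objects are a hypothetical stand-in for however the YAML is actually parsed into attribute-style access; note the Adam branch also reads config.optim.eps, so an eps value must be supplied (the flows configs include one; this example adds it explicitly). Assumes src/flows is on PYTHONPATH:

import types

import torch.nn as nn

from functions import get_optimizer

# mirrors the optim: section of configs/classification/gmm/mlp_x.yaml, plus the eps the Adam branch requires
optim_cfg = types.SimpleNamespace(
    optimizer='Adam', lr=0.0002, weight_decay=0.000, beta1=0.9, amsgrad=False, eps=1e-8)
config = types.SimpleNamespace(optim=optim_cfg)

model = nn.Linear(2, 2)  # toy stand-in for an MLP classifier
optimizer = get_optimizer(config, model.parameters())  # -> torch.optim.Adam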
/src/configs/classification/gmm/joint_flow_mlp_z.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 128 3 | n_epochs: 200 4 | ngpu: 1 5 | iter_log: 1000 6 | iter_save: 100 7 | exp_id: "joint_flow_z" 8 | out_dir: "/path/to/f-dre/src/classification/results/" 9 | data_dir: "/path/to/f-dre/data/" 10 | 11 | data: 12 | dataset: "GMM" 13 | subset: false 14 | x_space: true 15 | input_size: 2 16 | perc: 1.0 17 | mus: [0, 3] 18 | class_idx: 20 19 | num_workers: 4 20 | 21 | model: 22 | name: "flow_mlp" 23 | spectral_norm: true 24 | batch_norm: true 25 | in_dim: 2 26 | h_dim: 200 27 | dropout: 0.1 28 | n_classes: 2 29 | 30 | optim: 31 | weight_decay: 0.000001 32 | optimizer: "Adam" 33 | lr: 0.0001 34 | beta1: 0.9 35 | amsgrad: false 36 | 37 | loss: 38 | name: "joint" 39 | alpha: 0.2 -------------------------------------------------------------------------------- /src/configs/classification/mnist/diff_bkgd/encodings0.1.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 100 3 | n_epochs: 5 4 | ngpu: 1 5 | iter_log: 1000 6 | iter_save: 100 7 | out_dir: "classification/results" 8 | data_dir: "../data" 9 | 10 | data: 11 | dataset: "SplitEncodedMNIST" 12 | encoding_model: "maf" 13 | encoded_dataset: "BackgroundMNIST" 14 | perc: 0.1 15 | image_size: 28 16 | channels: 1 17 | random_flip: false 18 | num_workers: 4 19 | flip_toy_var_order: true 20 | model: 21 | name: "mlp" 22 | spectral_norm: true 23 | batch_norm: true 24 | in_dim: 784 25 | h_dim: 100 26 | dropout: 0.1 27 | n_classes: 2 28 | 29 | optim: 30 | weight_decay: 0.000 31 | optimizer: "Adam" 32 | lr: 0.0002 33 | beta1: 0.9 34 | amsgrad: false 35 | 36 | loss: 37 | name: "bce" 38 | -------------------------------------------------------------------------------- /src/configs/classification/mnist/diff_bkgd/encodings0.25.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 100 3 | n_epochs: 5 4 | ngpu: 1 5 | iter_log: 1000 6 | iter_save: 100 7 | out_dir: "classification/results" 8 | data_dir: "../data" 9 | 10 | data: 11 | dataset: "SplitEncodedMNIST" 12 | encoding_model: "maf" 13 | encoded_dataset: "BackgroundMNIST" 14 | perc: 0.25 15 | image_size: 28 16 | channels: 1 17 | random_flip: false 18 | num_workers: 4 19 | flip_toy_var_order: true 20 | 21 | model: 22 | name: "mlp" 23 | spectral_norm: true 24 | batch_norm: true 25 | in_dim: 784 26 | h_dim: 100 27 | dropout: 0.1 28 | n_classes: 2 29 | 30 | optim: 31 | weight_decay: 0.000 32 | optimizer: "Adam" 33 | lr: 0.0002 34 | beta1: 0.9 35 | amsgrad: false 36 | 37 | loss: 38 | name: "bce" 39 | -------------------------------------------------------------------------------- /src/configs/classification/mnist/diff_bkgd/encodings0.5.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 100 3 | n_epochs: 5 4 | ngpu: 1 5 | iter_log: 1000 6 | iter_save: 100 7 | out_dir: "classification/results" 8 | data_dir: "../data" 9 | 10 | data: 11 | dataset: "SplitEncodedMNIST" 12 | encoding_model: "maf" 13 | encoded_dataset: "BackgroundMNIST" 14 | perc: 0.5 15 | image_size: 28 16 | channels: 1 17 | random_flip: false 18 | num_workers: 4 19 | flip_toy_var_order: true 20 | 21 | model: 22 | name: "mlp" 23 | spectral_norm: true 24 | batch_norm: true 25 | in_dim: 784 26 | h_dim: 100 27 | dropout: 0.1 28 | n_classes: 2 29 | 30 | optim: 31 | weight_decay: 0.000 32 | optimizer: 
"Adam" 33 | lr: 0.0002 34 | beta1: 0.9 35 | amsgrad: false 36 | 37 | loss: 38 | name: "bce" 39 | -------------------------------------------------------------------------------- /src/configs/classification/mnist/diff_bkgd/encodings1.0.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 100 3 | n_epochs: 5 4 | ngpu: 1 5 | iter_log: 1000 6 | iter_save: 100 7 | out_dir: "classification/results" 8 | data_dir: "../data" 9 | 10 | data: 11 | dataset: "SplitEncodedMNIST" 12 | encoding_model: "maf" 13 | encoded_dataset: "BackgroundMNIST" 14 | perc: 1.0 15 | image_size: 28 16 | channels: 1 17 | random_flip: false 18 | num_workers: 4 19 | flip_toy_var_order: true 20 | 21 | model: 22 | name: "mlp" 23 | spectral_norm: true 24 | batch_norm: true 25 | in_dim: 784 26 | h_dim: 100 27 | dropout: 0.1 28 | n_classes: 2 29 | 30 | optim: 31 | weight_decay: 0.000 32 | optimizer: "Adam" 33 | lr: 0.0002 34 | beta1: 0.9 35 | amsgrad: false 36 | 37 | loss: 38 | name: "bce" 39 | -------------------------------------------------------------------------------- /src/configs/classification/mi/disc_flow_z.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 128 3 | n_epochs: 200 4 | ngpu: 1 5 | iter_log: 1000 6 | iter_save: 100 7 | exp_id: "mi_disc_flow_z" 8 | out_dir: "/path/to/f-dre/src/classification/results/" 9 | data_dir: "/path/to/f-dre/data/" 10 | 11 | data: 12 | dataset: "MI" 13 | subset: false 14 | x_space: true 15 | input_size: 40 16 | perc: 1.0 17 | mus: [0, 3] 18 | class_idx: 20 19 | num_workers: 4 20 | rho: 0.9 21 | 22 | model: 23 | name: "flow_mlp" 24 | spectral_norm: true 25 | batch_norm: true 26 | in_dim: 40 27 | h_dim: 200 28 | dropout: 0.1 29 | n_classes: 1 30 | 31 | optim: 32 | weight_decay: 0.0005 33 | optimizer: "Adam" 34 | lr: 0.0002 35 | beta1: 0.9 36 | amsgrad: false 37 | 38 | loss: 39 | name: "bce" 40 | alpha: 0.01 -------------------------------------------------------------------------------- /src/configs/classification/mi/joint_flow_z.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 128 3 | n_epochs: 200 4 | ngpu: 1 5 | iter_log: 1000 6 | iter_save: 100 7 | exp_id: "mi_flow_mlp_z" 8 | out_dir: "/path/to/f-dre/src/classification/results/" 9 | data_dir: "/path/to/f-dre/data/" 10 | 11 | data: 12 | dataset: "MI" 13 | subset: false 14 | x_space: true 15 | input_size: 40 16 | perc: 1.0 17 | mus: [0, 3] 18 | class_idx: 20 19 | num_workers: 4 20 | rho: 0.9 21 | 22 | model: 23 | name: "flow_mlp" 24 | spectral_norm: true 25 | batch_norm: true 26 | in_dim: 40 27 | h_dim: 200 28 | dropout: 0.1 29 | n_classes: 1 30 | 31 | optim: 32 | weight_decay: 0.0005 33 | optimizer: "Adam" 34 | lr: 0.0002 35 | beta1: 0.9 36 | amsgrad: false 37 | 38 | loss: 39 | name: "joint" 40 | alpha: 0.9 -------------------------------------------------------------------------------- /src/configs/classification/mnist/diff_digits/encodings0.1.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 100 3 | n_epochs: 5 4 | ngpu: 1 5 | iter_log: 1000 6 | iter_save: 100 7 | out_dir: "classification/results" 8 | data_dir: "../data" 9 | 10 | data: 11 | dataset: "SplitEncodedMNIST" 12 | encoding_model: "maf" 13 | encoded_dataset: "DigitMNISTSubset" 14 | perc: 0.1 15 | image_size: 28 16 | channels: 1 17 | random_flip: false 18 | num_workers: 4 19 | flip_toy_var_order: true 20 | model: 21 | 
name: "resnet" 22 | spectral_norm: true 23 | batch_norm: true 24 | in_dim: 784 25 | h_dim: 100 26 | dropout: 0.1 27 | n_classes: 2 28 | 29 | optim: 30 | weight_decay: 0.000 31 | optimizer: "Adam" 32 | lr: 0.0002 33 | beta1: 0.9 34 | amsgrad: false 35 | 36 | loss: 37 | name: "cross_entropy" 38 | -------------------------------------------------------------------------------- /src/configs/classification/mnist/diff_digits/encodings0.5.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 100 3 | n_epochs: 5 4 | ngpu: 1 5 | iter_log: 1000 6 | iter_save: 100 7 | out_dir: "classification/results" 8 | data_dir: "../data" 9 | 10 | data: 11 | dataset: "SplitEncodedMNIST" 12 | encoding_model: "maf" 13 | encoded_dataset: "DigitMNISTSubset" 14 | perc: 0.5 15 | image_size: 28 16 | channels: 1 17 | random_flip: false 18 | num_workers: 4 19 | flip_toy_var_order: true 20 | 21 | model: 22 | name: "resnet" 23 | spectral_norm: true 24 | batch_norm: true 25 | in_dim: 784 26 | h_dim: 100 27 | dropout: 0.1 28 | n_classes: 2 29 | 30 | optim: 31 | weight_decay: 0.000 32 | optimizer: "Adam" 33 | lr: 0.0002 34 | beta1: 0.9 35 | amsgrad: false 36 | 37 | loss: 38 | name: "cross_entropy" 39 | -------------------------------------------------------------------------------- /src/configs/classification/mnist/diff_digits/encodings1.0.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 100 3 | n_epochs: 5 4 | ngpu: 1 5 | iter_log: 1000 6 | iter_save: 100 7 | out_dir: "classification/results" 8 | data_dir: "../data" 9 | 10 | data: 11 | dataset: "SplitEncodedMNIST" 12 | encoding_model: "maf" 13 | encoded_dataset: "DigitMNISTSubset" 14 | perc: 1.0 15 | image_size: 28 16 | channels: 1 17 | random_flip: false 18 | num_workers: 4 19 | flip_toy_var_order: true 20 | 21 | model: 22 | name: "resnet" 23 | spectral_norm: true 24 | batch_norm: true 25 | in_dim: 784 26 | h_dim: 100 27 | dropout: 0.1 28 | n_classes: 2 29 | 30 | optim: 31 | weight_decay: 0.000 32 | optimizer: "Adam" 33 | lr: 0.0002 34 | beta1: 0.9 35 | amsgrad: false 36 | 37 | loss: 38 | name: "cross_entropy" 39 | -------------------------------------------------------------------------------- /src/configs/classification/mnist/diff_digits/encodings0.25.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 100 3 | n_epochs: 5 4 | ngpu: 1 5 | iter_log: 1000 6 | iter_save: 100 7 | out_dir: "classification/results" 8 | data_dir: "../data" 9 | 10 | data: 11 | dataset: "SplitEncodedMNIST" 12 | encoding_model: "maf" 13 | encoded_dataset: "DigitMNISTSubset" 14 | perc: 0.25 15 | image_size: 28 16 | channels: 1 17 | random_flip: false 18 | num_workers: 4 19 | flip_toy_var_order: true 20 | 21 | model: 22 | name: "resnet" 23 | spectral_norm: true 24 | batch_norm: true 25 | in_dim: 784 26 | h_dim: 100 27 | dropout: 0.1 28 | n_classes: 2 29 | 30 | optim: 31 | weight_decay: 0.000 32 | optimizer: "Adam" 33 | lr: 0.0002 34 | beta1: 0.9 35 | amsgrad: false 36 | 37 | loss: 38 | name: "cross_entropy" 39 | -------------------------------------------------------------------------------- /src/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | def joint_gen_disc_loss(model, x, logits, y, alpha): 8 | """ 9 | joint generative/discriminative approach 
10 | """ 11 | clf_loss = F.binary_cross_entropy_with_logits(logits.squeeze(), y.float()) 12 | flow_loss = -model.flow.log_prob(x).mean(0) 13 | # TODO: double check shapes 14 | loss = (alpha * clf_loss) + ((1.-alpha) * flow_loss) 15 | 16 | return loss 17 | 18 | 19 | def grad_penalty(model, alpha, device, norm_type='fro'): 20 | total_norm = torch.norm( 21 | torch.stack([ 22 | torch.norm(p.grad.detach(), norm_type) for (n,p) in model.named_parameters() if 'fc' in n 23 | ]), norm_type 24 | ) 25 | return (alpha * total_norm) 26 | -------------------------------------------------------------------------------- /src/configs/classification/mnist/diff_digits/attr_digits.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 128 3 | n_epochs: 10 4 | ngpu: 1 5 | iter_log: 1000 6 | iter_save: 100 7 | out_dir: "classification/results" 8 | data_dir: "../data" 9 | 10 | data: 11 | dataset: "DigitMNISTSubset" 12 | perc: 1.0 13 | biased_digits: [0, 7] 14 | biased_digit_percs: [0.5, 0.5] 15 | ref_digits: [1, 2] 16 | ref_digit_percs: [0.5, 0.5] 17 | image_size: 28 18 | channels: 1 19 | random_flip: false 20 | num_workers: 4 21 | 22 | model: 23 | name: "mlp" 24 | spectral_norm: true 25 | batch_norm: true 26 | in_dim: 784 27 | h_dim: 100 28 | dropout: 0.1 29 | n_classes: 2 30 | 31 | optim: 32 | weight_decay: 0.000 33 | optimizer: "Adam" 34 | lr: 0.0002 35 | beta1: 0.9 36 | amsgrad: false 37 | 38 | loss: 39 | name: "bce" 40 | -------------------------------------------------------------------------------- /src/configs/classification/mnist/diff_digits/encodings0.1_mlp.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 100 3 | n_epochs: 5 4 | ngpu: 1 5 | iter_log: 1000 6 | iter_save: 100 7 | out_dir: "/atlas/u/madeline/multi-fairgen/src/classification/results" 8 | data_dir: "/atlas/u/madeline/multi-fairgen/data/" 9 | 10 | data: 11 | dataset: "SplitEncodedMNIST" 12 | encoding_model: "maf" 13 | encoded_dataset: "DigitMNISTSubset" 14 | perc: 0.1 15 | image_size: 28 16 | channels: 1 17 | random_flip: false 18 | num_workers: 4 19 | flip_toy_var_order: true 20 | model: 21 | name: "mlp" 22 | spectral_norm: true 23 | batch_norm: true 24 | in_dim: 784 25 | h_dim: 100 26 | dropout: 0.1 27 | n_classes: 2 28 | 29 | optim: 30 | weight_decay: 0.000 31 | optimizer: "Adam" 32 | lr: 0.0002 33 | beta1: 0.9 34 | amsgrad: false 35 | 36 | loss: 37 | name: "bce" -------------------------------------------------------------------------------- /src/configs/classification/mnist/diff_digits/encodings0.5_mlp.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 100 3 | n_epochs: 5 4 | ngpu: 1 5 | iter_log: 1000 6 | iter_save: 100 7 | out_dir: "/atlas/u/madeline/multi-fairgen/src/classification/results" 8 | data_dir: "/atlas/u/madeline/multi-fairgen/data/" 9 | 10 | data: 11 | dataset: "SplitEncodedMNIST" 12 | encoding_model: "maf" 13 | encoded_dataset: "DigitMNISTSubset" 14 | perc: 0.5 15 | image_size: 28 16 | channels: 1 17 | random_flip: false 18 | num_workers: 4 19 | flip_toy_var_order: true 20 | 21 | model: 22 | name: "mlp" 23 | spectral_norm: true 24 | batch_norm: true 25 | in_dim: 784 26 | h_dim: 100 27 | dropout: 0.1 28 | n_classes: 2 29 | 30 | optim: 31 | weight_decay: 0.000 32 | optimizer: "Adam" 33 | lr: 0.0002 34 | beta1: 0.9 35 | amsgrad: false 36 | 37 | loss: 38 | name: "bce" 
-------------------------------------------------------------------------------- /src/configs/classification/mnist/diff_digits/encodings1.0_mlp.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 100 3 | n_epochs: 5 4 | ngpu: 1 5 | iter_log: 1000 6 | iter_save: 100 7 | out_dir: "/atlas/u/madeline/multi-fairgen/src/classification/results" 8 | data_dir: "/atlas/u/madeline/multi-fairgen/data/" 9 | 10 | data: 11 | dataset: "SplitEncodedMNIST" 12 | encoding_model: "maf" 13 | encoded_dataset: "DigitMNISTSubset" 14 | perc: 1.0 15 | image_size: 28 16 | channels: 1 17 | random_flip: false 18 | num_workers: 4 19 | flip_toy_var_order: true 20 | 21 | model: 22 | name: "mlp" 23 | spectral_norm: true 24 | batch_norm: true 25 | in_dim: 784 26 | h_dim: 100 27 | dropout: 0.1 28 | n_classes: 2 29 | 30 | optim: 31 | weight_decay: 0.000 32 | optimizer: "Adam" 33 | lr: 0.0002 34 | beta1: 0.9 35 | amsgrad: false 36 | 37 | loss: 38 | name: "bce" -------------------------------------------------------------------------------- /src/configs/classification/mnist/diff_digits/encodings0.25_mlp.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 100 3 | n_epochs: 5 4 | ngpu: 1 5 | iter_log: 1000 6 | iter_save: 100 7 | out_dir: "/atlas/u/madeline/multi-fairgen/src/classification/results" 8 | data_dir: "/atlas/u/madeline/multi-fairgen/data/" 9 | 10 | data: 11 | dataset: "SplitEncodedMNIST" 12 | encoding_model: "maf" 13 | encoded_dataset: "DigitMNISTSubset" 14 | perc: 0.25 15 | image_size: 28 16 | channels: 1 17 | random_flip: false 18 | num_workers: 4 19 | flip_toy_var_order: true 20 | 21 | model: 22 | name: "mlp" 23 | spectral_norm: true 24 | batch_norm: true 25 | in_dim: 784 26 | h_dim: 100 27 | dropout: 0.1 28 | n_classes: 2 29 | 30 | optim: 31 | weight_decay: 0.000 32 | optimizer: "Adam" 33 | lr: 0.0002 34 | beta1: 0.9 35 | amsgrad: false 36 | 37 | loss: 38 | name: "bce" -------------------------------------------------------------------------------- /src/configs/classification/omniglot/x_dre.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 32 3 | n_epochs: 10 4 | ngpu: 1 5 | iter_log: 1000 6 | iter_save: 100 7 | exp_id: "omniglot_x_dre_clf" 8 | out_dir: "/path/to/f-dre/src/classification/results/" 9 | data_dir: "/path/to/f-dre/data/" 10 | 11 | data: 12 | dataset: "Omniglot_Mixture" 13 | subset: false 14 | x_space: true 15 | synthetic: true 16 | input_size: 784 17 | image_size: 28 18 | channels: 1 19 | perc: 1.0 20 | mus: [0, 3] 21 | class_idx: 20 22 | num_workers: 4 23 | 24 | model: 25 | name: "cnn_bce" 26 | baseline: true 27 | spectral_norm: true 28 | batch_norm: true 29 | in_dim: 784 30 | h_dim: 200 31 | dropout: 0.1 32 | n_classes: 1 33 | 34 | optim: 35 | weight_decay: 0.000 36 | optimizer: "Adam" 37 | lr: 0.001 38 | beta1: 0.9 39 | amsgrad: false 40 | 41 | loss: 42 | name: "bce" 43 | alpha: 0.01 -------------------------------------------------------------------------------- /src/configs/classification/omniglot/z_dre.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 32 3 | n_epochs: 20 4 | ngpu: 1 5 | iter_log: 1000 6 | iter_save: 100 7 | exp_id: "omniglot_z_dre_clf" 8 | out_dir: "/path/to/f-dre/src/classification/results/" 9 | data_dir: "/path/to/f-dre/data/" 10 | 11 | data: 12 | dataset: "Omniglot_Mixture" 13 | subset: false 14 | x_space: false 15 | 
synthetic: true 16 | input_size: 784 17 | image_size: 28 18 | channels: 1 19 | perc: 1.0 20 | mus: [0, 3] 21 | class_idx: 20 22 | num_workers: 4 23 | 24 | model: 25 | name: "cnn_bce" 26 | baseline: true 27 | spectral_norm: true 28 | batch_norm: true 29 | in_dim: 784 30 | h_dim: 200 31 | dropout: 0.1 32 | n_classes: 1 33 | 34 | optim: 35 | weight_decay: 0.000 36 | optimizer: "Adam" 37 | lr: 0.001 38 | beta1: 0.9 39 | amsgrad: false 40 | 41 | loss: 42 | name: "bce" 43 | alpha: 0.01 -------------------------------------------------------------------------------- /src/configs/classification/omniglot/baseline.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 32 3 | n_epochs: 100 4 | ngpu: 1 5 | iter_log: 1000 6 | iter_save: 100 7 | exp_id: "omniglot_baseline" 8 | out_dir: "/path/to/f-dre/src/classification/results/" 9 | data_dir: "/path/to/f-dre/data/" 10 | 11 | data: 12 | dataset: "Omniglot" 13 | subset: false 14 | x_space: true 15 | synthetic: false 16 | augment: false 17 | input_size: 784 18 | image_size: 28 19 | channels: 1 20 | perc: 1.0 21 | mus: [0, 3] 22 | class_idx: 20 23 | num_workers: 4 24 | 25 | model: 26 | name: "cnn" 27 | baseline: true 28 | spectral_norm: true 29 | batch_norm: true 30 | in_dim: 784 31 | h_dim: 200 32 | dropout: 0.1 33 | n_classes: 1 34 | 35 | optim: 36 | weight_decay: 0.000 37 | optimizer: "Adam" 38 | lr: 0.001 39 | beta1: 0.9 40 | amsgrad: false 41 | 42 | loss: 43 | name: "cross_entropy" 44 | alpha: 0.01 -------------------------------------------------------------------------------- /src/configs/classification/omniglot/x_method.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 32 3 | n_epochs: 100 4 | ngpu: 1 5 | iter_log: 1000 6 | iter_save: 100 7 | exp_id: "omniglot_method_x" 8 | out_dir: "/path/to/f-dre/src/classification/results/" 9 | data_dir: "/path/to/f-dre/data/" 10 | 11 | data: 12 | dataset: "Omniglot" 13 | subset: false 14 | x_space: true 15 | synthetic: true 16 | augment: true 17 | input_size: 784 18 | image_size: 28 19 | channels: 1 20 | perc: 1.0 21 | mus: [0, 3] 22 | class_idx: 20 23 | num_workers: 4 24 | 25 | model: 26 | name: "cnn" 27 | baseline: false 28 | spectral_norm: true 29 | batch_norm: true 30 | in_dim: 784 31 | h_dim: 200 32 | dropout: 0.1 33 | n_classes: 1 34 | 35 | optim: 36 | weight_decay: 0.000 37 | optimizer: "Adam" 38 | lr: 0.001 39 | beta1: 0.9 40 | amsgrad: false 41 | 42 | loss: 43 | name: "cross_entropy" 44 | alpha: 0.01 -------------------------------------------------------------------------------- /src/configs/classification/omniglot/z_method.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 32 3 | n_epochs: 100 4 | ngpu: 1 5 | iter_log: 1000 6 | iter_save: 100 7 | exp_id: "omniglot_method_z" 8 | out_dir: "/path/to/f-dre/src/classification/results/" 9 | data_dir: "/path/to/f-dre/data/" 10 | 11 | data: 12 | dataset: "Omniglot" 13 | subset: false 14 | x_space: false 15 | synthetic: true 16 | augment: true 17 | input_size: 784 18 | image_size: 28 19 | channels: 1 20 | perc: 1.0 21 | mus: [0, 3] 22 | class_idx: 20 23 | num_workers: 4 24 | 25 | model: 26 | name: "cnn" 27 | baseline: false 28 | spectral_norm: true 29 | batch_norm: true 30 | in_dim: 784 31 | h_dim: 200 32 | dropout: 0.1 33 | n_classes: 1 34 | 35 | optim: 36 | weight_decay: 0.000 37 | optimizer: "Adam" 38 | lr: 0.001 39 | beta1: 0.9 40 | amsgrad: false 41 | 42 | loss: 
43 | name: "cross_entropy" 44 | alpha: 0.01 -------------------------------------------------------------------------------- /src/configs/classification/omniglot/gen_baseline.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 32 3 | n_epochs: 100 4 | ngpu: 1 5 | iter_log: 1000 6 | iter_save: 100 7 | exp_id: "omniglot_gen_baseline" 8 | out_dir: "/path/to/f-dre/src/classification/results/" 9 | data_dir: "/path/to/f-dre/data/" 10 | 11 | data: 12 | dataset: "Omniglot" 13 | subset: false 14 | x_space: true 15 | synthetic: true 16 | augment: false 17 | input_size: 784 18 | image_size: 28 19 | channels: 1 20 | perc: 1.0 21 | mus: [0, 3] 22 | class_idx: 20 23 | num_workers: 4 24 | 25 | model: 26 | name: "cnn" 27 | baseline: true 28 | spectral_norm: true 29 | batch_norm: true 30 | in_dim: 784 31 | h_dim: 200 32 | dropout: 0.1 33 | n_classes: 1 34 | 35 | optim: 36 | weight_decay: 0.000 37 | optimizer: "Adam" 38 | lr: 0.001 39 | beta1: 0.9 40 | amsgrad: false 41 | 42 | loss: 43 | name: "cross_entropy" 44 | alpha: 0.01 -------------------------------------------------------------------------------- /src/configs/classification/omniglot/mix_baseline.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 32 3 | n_epochs: 100 4 | ngpu: 1 5 | iter_log: 1000 6 | iter_save: 100 7 | exp_id: "omniglot_mix_baseline" 8 | out_dir: "/path/to/f-dre/src/classification/results/" 9 | data_dir: "/path/to/f-dre/data/" 10 | 11 | data: 12 | dataset: "Omniglot" 13 | subset: false 14 | x_space: true 15 | synthetic: true 16 | augment: true 17 | input_size: 784 18 | image_size: 28 19 | channels: 1 20 | perc: 1.0 21 | mus: [0, 3] 22 | class_idx: 20 23 | num_workers: 4 24 | 25 | model: 26 | name: "cnn" 27 | baseline: true 28 | spectral_norm: true 29 | batch_norm: true 30 | in_dim: 784 31 | h_dim: 200 32 | dropout: 0.1 33 | n_classes: 1 34 | 35 | optim: 36 | weight_decay: 0.000 37 | optimizer: "Adam" 38 | lr: 0.001 39 | beta1: 0.9 40 | amsgrad: false 41 | 42 | loss: 43 | name: "cross_entropy" 44 | alpha: 0.01 -------------------------------------------------------------------------------- /src/flows/results/uci_breast_cancer_maf/config.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: UCI_trials 3 | flip_toy_var_order: true 4 | input_size: 9 5 | num_workers: 4 6 | perc: 1.0 7 | dre: 8 | alpha: 0.06 9 | loss: 10 | name: cross_entropy 11 | model: 12 | activation_fn: relu 13 | cond_label_size: 1 14 | conditional: false 15 | dropout: 0.1 16 | ema: false 17 | hidden_size: 100 18 | input_order: sequential 19 | input_size: 9 20 | n_blocks: 5 21 | n_classes: 2 22 | n_components: 1 23 | n_hidden: 1 24 | name: maf 25 | no_batch_norm: false 26 | optim: 27 | amsgrad: false 28 | beta1: 0.9 29 | eps: 1.0e-08 30 | lr: 0.0001 31 | optimizer: Adam 32 | weight_decay: 1.0e-06 33 | sampling: 34 | n_samples: 50000 35 | training: 36 | batch_size: 100 37 | data_dir: /atlas/u/madeline/multi-fairgen/data/ 38 | iter_save: 100 39 | log_interval: 500 40 | n_epochs: 100 41 | ngpu: 1 42 | out_dir: /atlas/u/madeline/multi-fairgen/src/flows/results 43 | -------------------------------------------------------------------------------- /src/flows/results/uci_wine_quality_maf/config.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: UCI_trials 3 | flip_toy_var_order: true 4 | input_size: 11 5 | num_workers: 4 6 | perc: 
1.0 7 | dre: 8 | alpha: 0.06 9 | loss: 10 | name: cross_entropy 11 | model: 12 | activation_fn: relu 13 | cond_label_size: 1 14 | conditional: false 15 | dropout: 0.1 16 | ema: false 17 | hidden_size: 100 18 | input_order: sequential 19 | input_size: 11 20 | n_blocks: 5 21 | n_classes: 2 22 | n_components: 1 23 | n_hidden: 1 24 | name: maf 25 | no_batch_norm: false 26 | optim: 27 | amsgrad: false 28 | beta1: 0.9 29 | eps: 1.0e-08 30 | lr: 0.0001 31 | optimizer: Adam 32 | weight_decay: 1.0e-06 33 | sampling: 34 | n_samples: 50000 35 | training: 36 | batch_size: 100 37 | data_dir: /atlas/u/madeline/multi-fairgen/data/ 38 | iter_save: 100 39 | log_interval: 500 40 | n_epochs: 100 41 | ngpu: 1 42 | out_dir: /atlas/u/madeline/multi-fairgen/src/flows/results 43 | -------------------------------------------------------------------------------- /src/flows/results/uci_blood_transfusion_maf/config.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: UCITransfusion 3 | flip_toy_var_order: true 4 | input_size: 4 5 | num_workers: 4 6 | perc: 1.0 7 | dre: 8 | alpha: 0.06 9 | loss: 10 | name: cross_entropy 11 | model: 12 | activation_fn: relu 13 | cond_label_size: 1 14 | conditional: false 15 | dropout: 0.1 16 | ema: false 17 | hidden_size: 100 18 | input_order: sequential 19 | input_size: 4 20 | n_blocks: 5 21 | n_classes: 2 22 | n_components: 1 23 | n_hidden: 1 24 | name: maf 25 | no_batch_norm: false 26 | optim: 27 | amsgrad: false 28 | beta1: 0.9 29 | eps: 1.0e-08 30 | lr: 0.0001 31 | optimizer: Adam 32 | weight_decay: 1.0e-06 33 | sampling: 34 | n_samples: 50000 35 | training: 36 | batch_size: 100 37 | data_dir: /atlas/u/madeline/multi-fairgen/data/ 38 | iter_save: 100 39 | log_interval: 500 40 | n_epochs: 100 41 | ngpu: 1 42 | out_dir: /atlas/u/madeline/multi-fairgen/src/flows/results 43 | -------------------------------------------------------------------------------- /src/configs/flows/gmm/maf.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 100 3 | n_epochs: 100 4 | ngpu: 1 5 | log_interval: 500 6 | iter_save: 100 7 | exp_id: "gmm_flow" 8 | out_dir: "/path/to/f-dre/src/flows/results" 9 | data_dir: "/path/to/f-dre/data/" 10 | 11 | data: 12 | dataset: "GMM_flow" 13 | perc: 1.0 14 | mus: [0, 3] 15 | input_size: 2 16 | num_workers: 4 17 | flip_toy_var_order: true 18 | 19 | dre: 20 | alpha: 0.06 21 | 22 | sampling: 23 | n_samples: 50000 24 | # sir: 1000 25 | 26 | model: 27 | name: "maf" 28 | dropout: 0.1 29 | n_classes: 2 30 | n_blocks: 5 31 | n_components: 1 32 | input_size: 2 33 | hidden_size: 100 34 | n_hidden: 1 35 | activation_fn: 'relu' 36 | input_order: 'sequential' 37 | conditional: false 38 | no_batch_norm: false 39 | cond_label_size: 1 40 | ema: false 41 | 42 | optim: 43 | weight_decay: 0.000001 44 | optimizer: "Adam" 45 | lr: 0.0001 46 | beta1: 0.9 47 | amsgrad: false 48 | eps: 0.00000001 49 | 50 | loss: 51 | name: "cross_entropy" -------------------------------------------------------------------------------- /src/configs/flows/mnist/diff_bkgd/perc0.1.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 100 3 | n_epochs: 200 4 | ngpu: 1 5 | log_interval: 100 6 | iter_save: 100 7 | out_dir: "flows/results" 8 | data_dir: "../data/" 9 | 10 | data: 11 | dataset: "BackgroundMNIST" 12 | perc: 0.1 13 | image_size: 28 14 | channels: 1 15 | random_flip: false 16 | num_workers: 4 17 | flip_toy_var_order: true 18 
| 19 | dre: 20 | alpha: 0 21 | 22 | sampling: 23 | n_samples: 1000 24 | n_sir: 1000 25 | 26 | model: 27 | name: "maf" 28 | dropout: 0.1 29 | n_classes: 2 30 | n_blocks: 5 31 | n_components: 1 32 | input_size: 784 33 | hidden_size: 1024 34 | n_hidden: 1 35 | activation_fn: 'relu' 36 | input_order: 'sequential' 37 | conditional: false 38 | no_batch_norm: false 39 | cond_label_size: 10 40 | ema: false 41 | 42 | optim: 43 | weight_decay: 0.000001 44 | optimizer: "Adam" 45 | lr: 0.0001 46 | beta1: 0.9 47 | amsgrad: false 48 | eps: 0.00000001 49 | 50 | loss: 51 | name: "cross_entropy" 52 | -------------------------------------------------------------------------------- /src/flows/results/kmm_kliep_synthetic_maf/config.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: KMM_trials 3 | flip_toy_var_order: true 4 | input_size: 2 5 | mus: 6 | - 0 7 | - 3 8 | num_workers: 4 9 | perc: 1.0 10 | dre: 11 | alpha: 0.06 12 | loss: 13 | name: cross_entropy 14 | model: 15 | activation_fn: relu 16 | cond_label_size: 1 17 | conditional: false 18 | dropout: 0.1 19 | ema: false 20 | hidden_size: 100 21 | input_order: sequential 22 | input_size: 2 23 | n_blocks: 5 24 | n_classes: 2 25 | n_components: 1 26 | n_hidden: 1 27 | name: maf 28 | no_batch_norm: false 29 | optim: 30 | amsgrad: false 31 | beta1: 0.9 32 | eps: 1.0e-08 33 | lr: 0.0001 34 | optimizer: Adam 35 | weight_decay: 1.0e-06 36 | sampling: 37 | n_samples: 50000 38 | training: 39 | batch_size: 100 40 | data_dir: /atlas/u/madeline/multi-fairgen/data/ 41 | exp_id: kmm_flow 42 | iter_save: 100 43 | log_interval: 500 44 | n_epochs: 100 45 | ngpu: 1 46 | out_dir: /atlas/u/madeline/multi-fairgen/src/flows/results 47 | -------------------------------------------------------------------------------- /src/configs/flows/uci_breast_cancer/maf.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 100 3 | n_epochs: 100 4 | ngpu: 1 5 | log_interval: 500 6 | iter_save: 100 7 | out_dir: "/atlas/u/madeline/multi-fairgen/src/flows/results" 8 | data_dir: "/atlas/u/madeline/multi-fairgen/data/" 9 | 10 | data: 11 | dataset: "UCI_trials" 12 | perc: 1.0 13 | input_size: 9 14 | num_workers: 4 15 | flip_toy_var_order: true 16 | 17 | dre: 18 | alpha: 0.06 19 | 20 | sampling: 21 | n_samples: 50000 22 | # sir: 1000 23 | 24 | model: 25 | name: "maf" 26 | dropout: 0.1 27 | n_classes: 2 28 | n_blocks: 5 29 | n_components: 1 30 | input_size: 9 31 | hidden_size: 100 32 | n_hidden: 1 33 | activation_fn: 'relu' 34 | input_order: 'sequential' 35 | conditional: false 36 | no_batch_norm: false 37 | cond_label_size: 1 38 | ema: false 39 | 40 | optim: 41 | weight_decay: 0.000001 42 | optimizer: "Adam" 43 | lr: 0.0001 44 | beta1: 0.9 45 | amsgrad: false 46 | eps: 0.00000001 47 | 48 | loss: 49 | name: "cross_entropy" -------------------------------------------------------------------------------- /src/configs/flows/uci_blood_transfusion/maf.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 100 3 | n_epochs: 100 4 | ngpu: 1 5 | log_interval: 500 6 | iter_save: 100 7 | out_dir: "/atlas/u/madeline/multi-fairgen/src/flows/results" 8 | data_dir: "/atlas/u/madeline/multi-fairgen/data/" 9 | 10 | data: 11 | dataset: "UCITransfusion" 12 | perc: 1.0 13 | input_size: 4 14 | num_workers: 4 15 | flip_toy_var_order: true 16 | 17 | dre: 18 | alpha: 0.06 19 | 20 | sampling: 21 | n_samples: 50000 22 | # sir: 1000 23 | 24 | 
model: 25 | name: "maf" 26 | dropout: 0.1 27 | n_classes: 2 28 | n_blocks: 5 29 | n_components: 1 30 | input_size: 4 31 | hidden_size: 100 32 | n_hidden: 1 33 | activation_fn: 'relu' 34 | input_order: 'sequential' 35 | conditional: false 36 | no_batch_norm: false 37 | cond_label_size: 1 38 | ema: false 39 | 40 | optim: 41 | weight_decay: 0.000001 42 | optimizer: "Adam" 43 | lr: 0.0001 44 | beta1: 0.9 45 | amsgrad: false 46 | eps: 0.00000001 47 | 48 | loss: 49 | name: "cross_entropy" -------------------------------------------------------------------------------- /src/configs/flows/uci_wine_quality/maf.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 100 3 | n_epochs: 100 4 | ngpu: 1 5 | log_interval: 500 6 | iter_save: 100 7 | out_dir: "/atlas/u/madeline/multi-fairgen/src/flows/results" 8 | data_dir: "/atlas/u/madeline/multi-fairgen/data/" 9 | 10 | data: 11 | dataset: "UCI_trials" 12 | perc: 1.0 13 | input_size: 11 14 | num_workers: 4 15 | flip_toy_var_order: true 16 | 17 | dre: 18 | alpha: 0.06 19 | 20 | sampling: 21 | n_samples: 50000 22 | # sir: 1000 23 | 24 | model: 25 | name: "maf" 26 | dropout: 0.1 27 | n_classes: 2 28 | n_blocks: 5 29 | n_components: 1 30 | input_size: 11 31 | hidden_size: 100 32 | n_hidden: 1 33 | activation_fn: 'relu' 34 | input_order: 'sequential' 35 | conditional: false 36 | no_batch_norm: false 37 | cond_label_size: 1 38 | ema: false 39 | 40 | optim: 41 | weight_decay: 0.000001 42 | optimizer: "Adam" 43 | lr: 0.0001 44 | beta1: 0.9 45 | amsgrad: false 46 | eps: 0.00000001 47 | 48 | loss: 49 | name: "cross_entropy" 50 | -------------------------------------------------------------------------------- /src/configs/flows/mnist/diff_bkgd/perc0.5.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 100 3 | n_epochs: 200 4 | ngpu: 1 5 | log_interval: 100 6 | iter_save: 100 7 | out_dir: "flows/results" 8 | data_dir: "../data/" 9 | 10 | data: 11 | dataset: "BackgroundMNIST" 12 | perc: 0.5 13 | image_size: 28 14 | channels: 1 15 | random_flip: false 16 | num_workers: 4 17 | flip_toy_var_order: true 18 | 19 | dre: 20 | alpha: 0 21 | 22 | sampling: 23 | n_samples: 1000 24 | n_sir: 1000 25 | # sir: 1000 26 | 27 | model: 28 | name: "maf" 29 | dropout: 0.1 30 | n_classes: 2 31 | n_blocks: 5 32 | n_components: 1 33 | input_size: 784 34 | hidden_size: 1024 35 | n_hidden: 1 36 | activation_fn: 'relu' 37 | input_order: 'sequential' 38 | conditional: false 39 | no_batch_norm: false 40 | cond_label_size: 10 41 | ema: false 42 | 43 | optim: 44 | weight_decay: 0.000001 45 | optimizer: "Adam" 46 | lr: 0.0001 47 | beta1: 0.9 48 | amsgrad: false 49 | eps: 0.00000001 50 | 51 | loss: 52 | name: "cross_entropy" 53 | -------------------------------------------------------------------------------- /src/configs/flows/mnist/diff_bkgd/perc0.25.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 100 3 | n_epochs: 200 4 | ngpu: 1 5 | log_interval: 100 6 | iter_save: 100 7 | out_dir: "flows/results" 8 | data_dir: "../data/" 9 | 10 | data: 11 | dataset: "BackgroundMNIST" 12 | perc: 0.25 13 | image_size: 28 14 | channels: 1 15 | random_flip: false 16 | num_workers: 4 17 | flip_toy_var_order: true 18 | 19 | dre: 20 | alpha: 0 21 | 22 | sampling: 23 | n_samples: 1000 24 | n_sir: 1000 25 | # sir: 1000 26 | 27 | model: 28 | name: "maf" 29 | dropout: 0.1 30 | n_classes: 2 31 | n_blocks: 5 32 | n_components: 1 
33 | input_size: 784 34 | hidden_size: 1024 35 | n_hidden: 1 36 | activation_fn: 'relu' 37 | input_order: 'sequential' 38 | conditional: false 39 | no_batch_norm: false 40 | cond_label_size: 10 41 | ema: false 42 | 43 | optim: 44 | weight_decay: 0.000001 45 | optimizer: "Adam" 46 | lr: 0.0001 47 | beta1: 0.9 48 | amsgrad: false 49 | eps: 0.00000001 50 | 51 | loss: 52 | name: "cross_entropy" 53 | -------------------------------------------------------------------------------- /src/configs/flows/mnist/diff_bkgd/perc1.0.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 100 3 | n_epochs: 200 4 | ngpu: 1 5 | log_interval: 100 6 | iter_save: 100 7 | out_dir: "flows/results" 8 | data_dir: "../data/" 9 | 10 | data: 11 | dataset: "BackgroundMNIST" 12 | perc: 1.0 13 | image_size: 28 14 | channels: 1 15 | random_flip: false 16 | num_workers: 4 17 | flip_toy_var_order: true 18 | 19 | dre: 20 | alpha: 0 21 | 22 | sampling: 23 | n_samples: 1000 24 | n_sir: 1000 25 | # sir: 1000 26 | 27 | model: 28 | name: "maf" 29 | dropout: 0.1 30 | n_classes: 2 31 | n_blocks: 5 32 | n_components: 1 33 | input_size: 784 34 | hidden_size: 1024 35 | n_hidden: 1 36 | activation_fn: 'relu' 37 | input_order: 'sequential' 38 | conditional: false 39 | no_batch_norm: false 40 | cond_label_size: 10 41 | ema: false 42 | 43 | optim: 44 | weight_decay: 0.000001 45 | optimizer: "Adam" 46 | lr: 0.0001 47 | beta1: 0.9 48 | amsgrad: false 49 | eps: 0.00000001 50 | 51 | loss: 52 | name: "cross_entropy" 53 | -------------------------------------------------------------------------------- /src/configs/flows/mi/mi_maf.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 100 3 | n_epochs: 200 4 | ngpu: 1 5 | log_interval: 500 6 | iter_save: 100 7 | exp_id: "mi_flow" 8 | out_dir: "/path/to/f-dre/src/flows/results" 9 | data_dir: "/path/to/f-dre/data/" 10 | 11 | data: 12 | dataset: "MI_flow" 13 | subset: false 14 | x_space: false 15 | perc: 1.0 16 | mus: [0, 3] 17 | input_size: 40 18 | num_workers: 4 19 | flip_toy_var_order: true 20 | rho: 0.9 21 | 22 | dre: 23 | alpha: 0.06 24 | 25 | sampling: 26 | n_samples: 50000 27 | # sir: 1000 28 | 29 | model: 30 | name: "maf" 31 | dropout: 0.1 32 | n_classes: 2 33 | n_blocks: 5 34 | n_components: 1 35 | input_size: 40 36 | hidden_size: 100 37 | n_hidden: 1 38 | activation_fn: 'relu' 39 | input_order: 'sequential' 40 | conditional: false 41 | no_batch_norm: false 42 | cond_label_size: 1 43 | ema: false 44 | 45 | optim: 46 | weight_decay: 0.000001 47 | optimizer: "Adam" 48 | lr: 0.0001 49 | beta1: 0.9 50 | amsgrad: false 51 | eps: 0.00000001 52 | 53 | loss: 54 | name: "cross_entropy" -------------------------------------------------------------------------------- /src/flows/models/glow/actnorm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class ActNorm(nn.Module): 5 | """ ActNorm layer; cf Glow section 3.1 """ 6 | def __init__(self, param_dim=(1,3,1,1)): 7 | super().__init__() 8 | self.scale = nn.Parameter(torch.ones(param_dim)) 9 | self.bias = nn.Parameter(torch.zeros(param_dim)) 10 | self.register_buffer('initialized', torch.tensor(0).byte()) 11 | 12 | def forward(self, x): 13 | if not self.initialized: 14 | # per channel mean and variance where x.shape = (B, C, H, W) 15 | self.bias.squeeze().data.copy_(x.transpose(0,1).flatten(1).mean(1)).view_as(self.scale) 16 | 
self.scale.squeeze().data.copy_(x.transpose(0,1).flatten(1).std(1, False) + 1e-6).view_as(self.bias) 17 | self.initialized += 1 18 | 19 | z = (x - self.bias) / self.scale 20 | logdet = - self.scale.abs().log().sum() * x.shape[2] * x.shape[3] 21 | return z, logdet 22 | 23 | def inverse(self, z): 24 | return z * self.scale + self.bias, self.scale.abs().log().sum() * z.shape[2] * z.shape[3] 25 | 26 | -------------------------------------------------------------------------------- /src/configs/flows/mnist/diff_digits/perc0.1.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 100 3 | n_epochs: 200 4 | ngpu: 1 5 | log_interval: 100 6 | iter_save: 100 7 | out_dir: "flows/results" 8 | data_dir: "../data/" 9 | 10 | data: 11 | dataset: "DigitMNISTSubset" 12 | perc: 0.1 13 | biased_digits: [0, 7] 14 | biased_digit_percs: [0.5, 0.5] 15 | ref_digits: [1, 2] 16 | ref_digit_percs: [0.5, 0.5] 17 | image_size: 28 18 | channels: 1 19 | random_flip: false 20 | num_workers: 4 21 | flip_toy_var_order: true 22 | 23 | dre: 24 | alpha: 0 25 | 26 | sampling: 27 | n_samples: 1000 28 | n_sir: 1000 29 | 30 | model: 31 | name: "maf" 32 | dropout: 0.1 33 | n_classes: 2 34 | n_blocks: 5 35 | n_components: 1 36 | input_size: 784 37 | hidden_size: 1024 38 | n_hidden: 1 39 | activation_fn: 'relu' 40 | input_order: 'sequential' 41 | conditional: false 42 | no_batch_norm: false 43 | cond_label_size: 10 44 | ema: false 45 | 46 | optim: 47 | weight_decay: 0.000001 48 | optimizer: "Adam" 49 | lr: 0.0001 50 | beta1: 0.9 51 | amsgrad: false 52 | eps: 0.00000001 53 | 54 | loss: 55 | name: "cross_entropy" 56 | -------------------------------------------------------------------------------- /src/configs/flows/mnist/diff_digits/perc0.25.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 100 3 | n_epochs: 200 4 | ngpu: 1 5 | log_interval: 100 6 | iter_save: 100 7 | out_dir: "flows/results" 8 | data_dir: "../data/" 9 | 10 | data: 11 | dataset: "DigitMNISTSubset" 12 | perc: 0.25 13 | biased_digits: [0, 7] 14 | biased_digit_percs: [0.5, 0.5] 15 | ref_digits: [1, 2] 16 | ref_digit_percs: [0.5, 0.5] 17 | image_size: 28 18 | channels: 1 19 | random_flip: false 20 | num_workers: 4 21 | flip_toy_var_order: true 22 | 23 | dre: 24 | alpha: 0 25 | 26 | sampling: 27 | n_samples: 1000 28 | n_sir: 1000 29 | # sir: 1000 30 | 31 | model: 32 | name: "maf" 33 | dropout: 0.1 34 | n_classes: 2 35 | n_blocks: 5 36 | n_components: 1 37 | input_size: 784 38 | hidden_size: 1024 39 | n_hidden: 1 40 | activation_fn: 'relu' 41 | input_order: 'sequential' 42 | conditional: false 43 | no_batch_norm: false 44 | cond_label_size: 10 45 | ema: false 46 | 47 | optim: 48 | weight_decay: 0.000001 49 | optimizer: "Adam" 50 | lr: 0.0001 51 | beta1: 0.9 52 | amsgrad: false 53 | eps: 0.00000001 54 | 55 | loss: 56 | name: "cross_entropy" 57 | -------------------------------------------------------------------------------- /src/configs/flows/mnist/diff_digits/perc0.5.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 100 3 | n_epochs: 200 4 | ngpu: 1 5 | log_interval: 100 6 | iter_save: 100 7 | out_dir: "flows/results" 8 | data_dir: "../data/" 9 | 10 | data: 11 | dataset: "DigitMNISTSubset" 12 | perc: 0.5 13 | biased_digits: [0, 7] 14 | biased_digit_percs: [0.5, 0.5] 15 | ref_digits: [1, 2] 16 | ref_digit_percs: [0.5, 0.5] 17 | image_size: 28 18 | channels: 1 19 | random_flip: 
false 20 | num_workers: 4 21 | flip_toy_var_order: true 22 | 23 | dre: 24 | alpha: 0 25 | 26 | sampling: 27 | n_samples: 1000 28 | n_sir: 1000 29 | # sir: 1000 30 | 31 | model: 32 | name: "maf" 33 | dropout: 0.1 34 | n_classes: 2 35 | n_blocks: 5 36 | n_components: 1 37 | input_size: 784 38 | hidden_size: 1024 39 | n_hidden: 1 40 | activation_fn: 'relu' 41 | input_order: 'sequential' 42 | conditional: false 43 | no_batch_norm: false 44 | cond_label_size: 10 45 | ema: false 46 | 47 | optim: 48 | weight_decay: 0.000001 49 | optimizer: "Adam" 50 | lr: 0.0001 51 | beta1: 0.9 52 | amsgrad: false 53 | eps: 0.00000001 54 | 55 | loss: 56 | name: "cross_entropy" 57 | -------------------------------------------------------------------------------- /src/configs/flows/mnist/diff_digits/perc1.0.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 100 3 | n_epochs: 200 4 | ngpu: 1 5 | log_interval: 100 6 | iter_save: 100 7 | out_dir: "flows/results" 8 | data_dir: "../data/" 9 | 10 | data: 11 | dataset: "DigitMNISTSubset" 12 | perc: 1.0 13 | biased_digits: [0, 7] 14 | biased_digit_percs: [0.5, 0.5] 15 | ref_digits: [1, 2] 16 | ref_digit_percs: [0.5, 0.5] 17 | image_size: 28 18 | channels: 1 19 | random_flip: false 20 | num_workers: 4 21 | flip_toy_var_order: true 22 | 23 | dre: 24 | alpha: 0 25 | 26 | sampling: 27 | n_samples: 1000 28 | n_sir: 1000 29 | # sir: 1000 30 | 31 | model: 32 | name: "maf" 33 | dropout: 0.1 34 | n_classes: 2 35 | n_blocks: 5 36 | n_components: 1 37 | input_size: 784 38 | hidden_size: 1024 39 | n_hidden: 1 40 | activation_fn: 'relu' 41 | input_order: 'sequential' 42 | conditional: false 43 | no_batch_norm: false 44 | cond_label_size: 10 45 | ema: false 46 | 47 | optim: 48 | weight_decay: 0.000001 49 | optimizer: "Adam" 50 | lr: 0.0001 51 | beta1: 0.9 52 | amsgrad: false 53 | eps: 0.00000001 54 | 55 | loss: 56 | name: "cross_entropy" 57 | -------------------------------------------------------------------------------- /src/configs/flows/omniglot/omniglot_maf.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 100 3 | n_epochs: 200 4 | ngpu: 1 5 | log_interval: 100 6 | iter_save: 100 7 | exp_id: "omniglot_maf" 8 | out_dir: "/path/to/f-dre/src/flows/results/" 9 | data_dir: "/path/to/f-dre/data/" 10 | 11 | sampling: 12 | generate: true 13 | fair: true 14 | n_samples: 3000 15 | encode: false 16 | 17 | data: 18 | dataset: "Omniglot_Mixture" 19 | subset: false 20 | perc: 0.1 21 | image_size: 28 22 | channels: 1 23 | input_size: 784 24 | random_flip: false 25 | x_space: false 26 | class_idx: 20 27 | num_workers: 4 28 | flip_toy_var_order: true 29 | 30 | model: 31 | name: "maf" 32 | dropout: 0.1 33 | n_classes: 2 34 | n_blocks: 5 35 | n_components: 1 36 | input_size: 784 37 | hidden_size: 1024 38 | n_hidden: 2 39 | activation_fn: 'relu' 40 | input_order: 'sequential' 41 | conditional: false 42 | no_batch_norm: false 43 | cond_label_size: 10 44 | ema: false 45 | 46 | dre: 47 | alpha: 0.06 48 | 49 | sampling: 50 | n_samples: 50000 51 | sir: 1000 52 | 53 | optim: 54 | weight_decay: 0.0005 55 | optimizer: "Adam" 56 | lr: 0.0001 57 | beta1: 0.9 58 | amsgrad: false 59 | eps: 0.00000001 60 | 61 | loss: 62 | name: "cross_entropy" -------------------------------------------------------------------------------- /src/classification/models/tre.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 
import torch.nn.functional as F 4 | from functools import partial 5 | from synthetic_clf import Classifier 6 | 7 | 8 | class TREClassifier(nn.Module): 9 | def __init__(self, config): 10 | super().__init__() 11 | self.in_dim = config.data.channels 12 | self.m = config.dre.m 13 | self.p = config.dre.p 14 | self.spectral_norm = config.model.spectral_norm 15 | 16 | # data-dependent params 17 | self.mu_0 = config.data.mus[:-1] 18 | self.mu_m = config.data.mus[-1] 19 | 20 | 21 | # multi-head logistic regressor 22 | self.linear_fc = nn.ModuleList( 23 | [nn.Linear(1, 1) for _ in range(self.m)]) 24 | for w in self.linear_fc: 25 | nn.init.xavier_normal_(w.weight) 26 | 27 | def forward(self, x): 28 | """Summary 29 | """ 30 | # construct pairs per minibatch 31 | xs = [torch.cat([x[i], x[i+1]]).unsqueeze(1) for i in range(self.m)] 32 | 33 | # right now testing out binary cross entropy with logits 34 | out = [] 35 | for i, fc in enumerate(self.linear_fc): 36 | out.append(fc(xs[i]**2)) 37 | out = torch.stack(out) 38 | return out 39 | 40 | def loss(self, x, out, y): 41 | """Summary 42 | """ 43 | raise NotImplementedError -------------------------------------------------------------------------------- /src/datasets/kmm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.distributions as dist 5 | from torch.distributions import Normal 6 | from torch.utils.data import Dataset, TensorDataset 7 | from .looping import LoopingDataset 8 | 9 | 10 | class KMM(Dataset): 11 | def __init__(self, config, split='train'): 12 | self.config = config 13 | self.split = split 14 | 15 | self.perc = config.data.perc 16 | self.input_size = config.data.input_size 17 | self.label_size = 1 18 | 19 | if self.split == 'train': 20 | source_record = np.load('/atlas/u/kechoi/multi-fairgen/data/kmm/source.npz') 21 | target_record = np.load('/atlas/u/kechoi/multi-fairgen/data/kmm/target.npz') 22 | data = np.vstack([source_record['x'], target_record['x']]) 23 | labels = np.hstack([source_record['y'], target_record['y']]) 24 | else: 25 | record = np.load('/atlas/u/kechoi/multi-fairgen/data/kmm/target_test.npz') 26 | data = record['x'] 27 | labels = record['y'] 28 | self.dataset = torch.from_numpy(data).float() 29 | self.labels = torch.from_numpy(labels).float() 30 | 31 | def __len__(self): 32 | return len(self.dataset) 33 | 34 | def __getitem__(self, i): 35 | item = self.dataset[i] 36 | label = self.labels[i] 37 | 38 | return item, label -------------------------------------------------------------------------------- /src/flows/models/maf/realnvp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.distributions as D 4 | 5 | from .layers import * 6 | 7 | class RealNVP(nn.Module): 8 | def __init__(self, n_blocks, input_size, hidden_size, n_hidden, cond_label_size=None, batch_norm=True): 9 | super().__init__() 10 | 11 | # base distribution for calculation of log prob under the model 12 | self.register_buffer('base_dist_mean', torch.zeros(input_size)) 13 | self.register_buffer('base_dist_var', torch.ones(input_size)) 14 | 15 | # construct model 16 | modules = [] 17 | mask = torch.arange(input_size).float() % 2 18 | for i in range(n_blocks): 19 | modules += [LinearMaskedCoupling(input_size, hidden_size, n_hidden, mask, cond_label_size)] 20 | mask = 1 - mask 21 | modules += batch_norm * [BatchNorm(input_size)] 22 | 23 | self.net = 
FlowSequential(*modules) 24 | 25 | @property 26 | def base_dist(self): 27 | return D.Normal(self.base_dist_mean, self.base_dist_var) 28 | 29 | def forward(self, x, y=None): 30 | return self.net(x, y) 31 | 32 | def inverse(self, u, y=None): 33 | return self.net.inverse(u, y) 34 | 35 | def log_prob(self, x, y=None): 36 | u, sum_log_abs_det_jacobians = self.forward(x, y) 37 | return torch.sum(self.base_dist.log_prob(u) + sum_log_abs_det_jacobians, dim=1) -------------------------------------------------------------------------------- /src/flows/trainers/glow1/learning_rate_schedule.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def default(init_lr, global_step): 5 | return init_lr 6 | 7 | 8 | # https://github.com/tensorflow/tensor2tensor/issues/280#issuecomment-339110329 9 | def noam_learning_rate_decay(init_lr, global_step, warmup_steps=4000, minimum=None): 10 | # Noam scheme from tensor2tensor: 11 | warmup_steps = float(warmup_steps) 12 | step = global_step + 1. 13 | lr = init_lr * warmup_steps**0.5 * np.minimum( 14 | step * warmup_steps**-1.5, step**-0.5) 15 | if minimum is not None and global_step > warmup_steps: 16 | if lr < minimum: 17 | lr = minimum 18 | return lr 19 | 20 | 21 | def step_learning_rate_decay(init_lr, global_step, 22 | anneal_rate=0.98, 23 | anneal_interval=30000): 24 | return init_lr * anneal_rate ** (global_step // anneal_interval) 25 | 26 | 27 | def cyclic_cosine_annealing(init_lr, global_step, T, M): 28 | """Cyclic cosine annealing 29 | 30 | https://arxiv.org/pdf/1704.00109.pdf 31 | 32 | Args: 33 | init_lr (float): Initial learning rate 34 | global_step (int): Current iteration number 35 | T (int): Total number of training iterations (i.e., n_epochs) 36 | M (int): Number of ensembles we want 37 | 38 | Returns: 39 | float: Annealed learning rate 40 | """ 41 | TdivM = T // M 42 | return init_lr / 2.0 * (np.cos(np.pi * ((global_step - 1) % TdivM) / TdivM) + 1.0) 43 | -------------------------------------------------------------------------------- /src/flows/functions/utils.py: -------------------------------------------------------------------------------- 1 | # TODO: (1) classifier outputs for attribute labels; (2) FID scores 2 | # NOTE: this code is untested and needs to be adapted for our setup!
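The `fairness_discrepancy` helper below scores how far the empirical attribute proportions of a sample fall from the uniform target `1 / n_classes`, using L2, L1, and KL distances. A standalone numeric check of the same arithmetic (a sketch, not part of the repo):

```python
import numpy as np

props = np.array([0.7, 0.3])  # empirical proportions over n_classes = 2
truth = 1. / 2                # uniform target

l2 = np.sqrt(((props - truth) ** 2).sum())            # ~= 0.283
l1 = np.abs(props - truth).sum()                      # = 0.4
kl = (props * (np.log(props) - np.log(truth))).sum()  # ~= 0.082
print(l2, l1, kl)
```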
3 | import math 4 | import functools 5 | import numpy as np 6 | from tqdm import tqdm, trange 7 | import os 8 | import glob 9 | import pickle 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.nn import init 14 | import torch.optim as optim 15 | import torch.nn.functional as F 16 | import torchvision 17 | 18 | 19 | def get_ratio_estimates(model, z): 20 | """ 21 | get density ratio estimates in z-space 22 | """ 23 | model.eval() 24 | logits, probas = model(z) 25 | ratios = probas[:,1]/probas[:,0] 26 | 27 | return ratios 28 | 29 | 30 | def fairness_discrepancy(data, n_classes): 31 | """ 32 | computes fairness discrepancy metric for single or multi-attribute 33 | this metric computes L2, L1, AND KL-total variation distance 34 | """ 35 | unique, freq = torch.unique(data, return_counts=True) 36 | props = freq / len(data) 37 | truth = 1./n_classes 38 | 39 | # L2 and L1 40 | l2_fair_d = torch.sqrt(((props - truth)**2).sum()) 41 | l1_fair_d = abs(props - truth).sum() 42 | 43 | # q = props, p = truth; props is a tensor, so use torch.log 44 | kl_fair_d = (props * (torch.log(props) - math.log(truth))).sum() 45 | 46 | return l2_fair_d, l1_fair_d, kl_fair_d 47 | 48 | 49 | def classify_examples(model, x): 50 | model.eval() 51 | 52 | logits, probas = model(x) 53 | _, pred = torch.max(probas, 1) 54 | 55 | return pred, probas -------------------------------------------------------------------------------- /src/classification/models/flow_mlp.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from src.flows.models.maf import MAF 6 | 7 | 8 | class FlowClassifier(nn.Module): 9 | """ 10 | Deeper classifier that uses a normalizing flow to map data points into z-space before performing classification. 11 | """ 12 | def __init__(self, config): 13 | super(FlowClassifier, self).__init__() 14 | self.config = config 15 | self.h_dim = config.model.h_dim 16 | self.n_classes = config.model.n_classes 17 | self.in_dim = config.model.in_dim 18 | 19 | # HACK: hardcoded flow architecture that we've been using!
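        # NOTE: matching the (n_blocks, input_size, hidden_size, n_hidden,
        # cond_label_size, ...) ordering used by RealNVP in this repo, the calls
        # below build a 5-block unconditional MAF ('relu', 'sequential' order):
        # 1024-dim hiddens with n_hidden=2 for CIFAR, 100-dim with n_hidden=1 otherwise.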
20 | # TODO: fix flow architecture to vary with dataset size 21 | if 'CIFAR' in self.config.data.dataset: 22 | self.flow = MAF(5, self.in_dim, 1024, 2, None, 'relu', 'sequential', batch_norm=True) 23 | else: 24 | self.flow = MAF(5, self.in_dim, 100, 1, None, 'relu', 'sequential', batch_norm=True) 25 | 26 | self.fc1 = nn.Linear(self.in_dim, self.h_dim) 27 | self.fc2 = nn.Linear(self.h_dim, self.h_dim) 28 | self.fc3 = nn.Linear(self.h_dim, self.h_dim) 29 | self.fc4 = nn.Linear(self.h_dim, 1) 30 | 31 | def forward(self, x): 32 | # map data into z-space 33 | z, _ = self.flow.forward(x) 34 | 35 | # then train classifier 36 | z = F.relu(self.fc1(z)) 37 | z = F.relu(self.fc2(z)) 38 | z = F.relu(self.fc3(z)) 39 | logits = self.fc4(z) 40 | probas = torch.sigmoid(logits) 41 | 42 | return logits, probas 43 | 44 | @torch.no_grad() 45 | def flow_encode(self, x): 46 | z, _ = self.flow.forward(x) 47 | return z -------------------------------------------------------------------------------- /src/flows/models/maf/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def create_masks(input_size, hidden_size, n_hidden, input_order='sequential', input_degrees=None): 4 | # MADE paper sec 4: 5 | # degrees of connections between layers -- ensure at most in_degree - 1 connections 6 | degrees = [] 7 | 8 | # set input degrees to what is provided in args (the flipped order of the previous layer in a stack of mades); 9 | # else init input degrees based on strategy in input_order (sequential or random) 10 | if input_order == 'sequential': 11 | degrees += [torch.arange(input_size)] if input_degrees is None else [input_degrees] 12 | for _ in range(n_hidden + 1): 13 | degrees += [torch.arange(hidden_size) % (input_size - 1)] 14 | degrees += [torch.arange(input_size) % input_size - 1] if input_degrees is None else [input_degrees % input_size - 1] 15 | 16 | elif input_order == 'random': 17 | degrees += [torch.randperm(input_size)] if input_degrees is None else [input_degrees] 18 | for _ in range(n_hidden + 1): 19 | min_prev_degree = min(degrees[-1].min().item(), input_size - 1) 20 | degrees += [torch.randint(min_prev_degree, input_size, (hidden_size,))] 21 | min_prev_degree = min(degrees[-1].min().item(), input_size - 1) 22 | degrees += [torch.randint(min_prev_degree, input_size, (input_size,)) - 1] if input_degrees is None else [input_degrees - 1] 23 | 24 | # construct masks 25 | masks = [] 26 | for (d0, d1) in zip(degrees[:-1], degrees[1:]): 27 | masks += [(d1.unsqueeze(-1) >= d0.unsqueeze(0)).float()] 28 | 29 | return masks, degrees[0] -------------------------------------------------------------------------------- /scripts/create_yaml_files.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | 4 | 5 | # alphas = [0.1, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] 6 | # alphas = [0.01, 0.02, 0.05, 0.07] 7 | alphas = [0.001, 0.005, 0.007] 8 | for alpha in alphas: 9 | exp_id = 'joint_gmm_flow_mlp_perc1.0_alpha{}'.format(alpha) 10 | 11 | dict_file = [ 12 | {'training': { 13 | 'batch_size': 128, 14 | 'n_epochs': 200, 15 | 'ngpu': 1, 16 | 'iter_log': 1000, 17 | 'iter_save': 100, 18 | 'exp_id': exp_id, 19 | 'out_dir': "/atlas/u/kechoi/multi-fairgen/src/classification/results/", 20 | 'data_dir': "/atlas/u/kechoi/multi-fairgen/data/" 21 | }}, 22 | {'data': { 23 | 'dataset': "GMM", 24 | 'subset': False, 25 | 'x_space': True, 26 | 'input_size': 2, 27 | 'perc': 1.0, 28 | 'mus': [0, 3], 29 | 'class_idx': 20, 30 | 
'num_workers': 4, 31 | }}, 32 | {'model': { 33 | 'name': "flow_mlp", 34 | 'spectral_norm': True, 35 | 'batch_norm': True, 36 | 'in_dim': 2, 37 | 'h_dim': 200, 38 | 'dropout': 0.1, 39 | 'n_classes': 2, 40 | }}, 41 | {'optim': { 42 | 'weight_decay': 0.000001, 43 | 'optimizer': "Adam", 44 | 'lr': 0.0001, 45 | 'beta1': 0.9, 46 | 'amsgrad': False, 47 | }}, 48 | {'loss': { 49 | 'name': "joint", 50 | 'alpha': alpha, 51 | }} 52 | ] 53 | 54 | # filenames 55 | fname = exp_id = 'joint_flow_mlp_perc1.0_alpha{}.yaml'.format(alpha) 56 | fpath = os.path.join('/atlas/u/kechoi/multi-fairgen/src/configs/classification/gmm/joint_sweep/', fname) 57 | 58 | with open(fpath, 'w') as file: 59 | documents = yaml.dump(dict_file, file) -------------------------------------------------------------------------------- /src/flows/trainers/hparams/celeba.json: -------------------------------------------------------------------------------- 1 | { 2 | "Dir": { 3 | "log_root": "results/celeba" 4 | }, 5 | "Glow" : { 6 | "image_shape": [64, 64, 3], 7 | "hidden_channels": 512, 8 | "K": 32, 9 | "L": 3, 10 | "actnorm_scale": 1.0, 11 | "flow_permutation": "invconv", 12 | "flow_coupling": "affine", 13 | "LU_decomposed": false, 14 | "learn_top": false, 15 | "y_condition": false, 16 | "y_classes": 40 17 | }, 18 | "Criterion" : { 19 | "y_condition": "multi-classes" 20 | }, 21 | "Data" : { 22 | "center_crop": 160, 23 | "resize": 64 24 | }, 25 | "Optim": { 26 | "name": "adam", 27 | "args": { 28 | "lr": 1e-3, 29 | "betas": [0.9, 0.9999], 30 | "eps": 1e-8 31 | }, 32 | "Schedule": { 33 | "name": "noam_learning_rate_decay", 34 | "args": { 35 | "warmup_steps": 4000, 36 | "minimum": 1e-4 37 | } 38 | } 39 | }, 40 | "Device": { 41 | "glow": ["cuda:0", "cuda:1", "cuda:2", "cuda:3"], 42 | "data": "cuda:0" 43 | }, 44 | "Train": { 45 | "batch_size": 12, 46 | "num_batches": 1000000, 47 | "max_grad_clip": 5, 48 | "max_grad_norm": 100, 49 | "max_checkpoints": 20, 50 | "checkpoints_gap": 5000, 51 | "num_plot_samples": 1, 52 | "scalar_log_gap": 50, 53 | "plot_gap": 50, 54 | "inference_gap": 50, 55 | "warm_start": "", 56 | "weight_y": 0.5 57 | }, 58 | "Infer": { 59 | "pre_trained": "../results/trained.pkg" 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /src/flows/models/flowplusplus/inv_conv.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class InvConv(nn.Module): 8 | """Invertible 1x1 Convolution for 2D inputs. Originally described in Glow 9 | (https://arxiv.org/abs/1807.03039). Does not support LU-decomposed version. 10 | 11 | Args: 12 | num_channels (int): Number of channels in the input and output. 13 | random_init (bool): Initialize with a random orthogonal matrix. 14 | Otherwise initialize with noisy identity. 
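        The log-determinant contribution of this layer is
        H * W * log|det W| per example, computed in forward()
        below via torch.slogdet.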
15 | """ 16 | def __init__(self, num_channels, random_init=False): 17 | super(InvConv, self).__init__() 18 | self.num_channels = 2 * num_channels 19 | 20 | if random_init: 21 | # Initialize with a random orthogonal matrix 22 | w_init = np.random.randn(self.num_channels, self.num_channels) 23 | w_init = np.linalg.qr(w_init)[0] 24 | else: 25 | # Initialize as identity permutation with some noise 26 | w_init = np.eye(self.num_channels, self.num_channels) \ 27 | + 1e-3 * np.random.randn(self.num_channels, self.num_channels) 28 | self.weight = nn.Parameter(torch.from_numpy(w_init.astype(np.float32))) 29 | 30 | def forward(self, x, sldj, reverse=False): 31 | x = torch.cat(x, dim=1) 32 | 33 | ldj = torch.slogdet(self.weight)[1] * x.size(2) * x.size(3) 34 | 35 | if reverse: 36 | weight = torch.inverse(self.weight.double()).float() 37 | sldj = sldj - ldj 38 | else: 39 | weight = self.weight 40 | sldj = sldj + ldj 41 | 42 | weight = weight.view(self.num_channels, self.num_channels, 1, 1) 43 | x = F.conv2d(x, weight) 44 | x = x.chunk(2, dim=1) 45 | 46 | return x, sldj 47 | -------------------------------------------------------------------------------- /src/flows/models/ema.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class EMAHelper(object): 5 | def __init__(self, mu=0.999): 6 | self.mu = mu 7 | self.shadow = {} 8 | 9 | def register(self, module): 10 | if isinstance(module, nn.DataParallel): 11 | module = module.module 12 | for name, param in module.named_parameters(): 13 | if param.requires_grad: 14 | self.shadow[name] = param.data.clone() 15 | 16 | def update(self, module): 17 | if isinstance(module, nn.DataParallel): 18 | module = module.module 19 | for name, param in module.named_parameters(): 20 | if param.requires_grad: 21 | self.shadow[name].data = ( 22 | 1. 
- self.mu) * param.data + self.mu * self.shadow[name].data 23 | 24 | def ema(self, module): 25 | if isinstance(module, nn.DataParallel): 26 | module = module.module 27 | for name, param in module.named_parameters(): 28 | if param.requires_grad: 29 | param.data.copy_(self.shadow[name].data) 30 | 31 | def ema_copy(self, module): 32 | if isinstance(module, nn.DataParallel): 33 | inner_module = module.module 34 | module_copy = type(inner_module)( 35 | inner_module.config).to(inner_module.config.device) 36 | module_copy.load_state_dict(inner_module.state_dict()) 37 | module_copy = nn.DataParallel(module_copy) 38 | else: 39 | module_copy = type(module)(module.config).to(module.config.device) 40 | module_copy.load_state_dict(module.state_dict()) 41 | # module_copy = copy.deepcopy(module) 42 | self.ema(module_copy) 43 | return module_copy 44 | 45 | def state_dict(self): 46 | return self.shadow 47 | 48 | def load_state_dict(self, state_dict): 49 | self.shadow = state_dict 50 | -------------------------------------------------------------------------------- /src/flows/trainers/glow1/thops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def onehot(y, num_classes): 5 | y_onehot = torch.zeros(y.size(0), num_classes).to(y.device) 6 | if len(y.size()) == 1: 7 | y_onehot = y_onehot.scatter_(1, y.unsqueeze(-1), 1) 8 | elif len(y.size()) == 2: 9 | y_onehot = y_onehot.scatter_(1, y, 1) 10 | else: 11 | raise ValueError("[onehot]: y should be in shape [B], or [B, C]") 12 | return y_onehot 13 | 14 | 15 | def sum(tensor, dim=None, keepdim=False): 16 | if dim is None: 17 | # sum up all dim 18 | return torch.sum(tensor) 19 | else: 20 | if isinstance(dim, int): 21 | dim = [dim] 22 | dim = sorted(dim) 23 | for d in dim: 24 | tensor = tensor.sum(dim=d, keepdim=True) 25 | if not keepdim: 26 | for i, d in enumerate(dim): 27 | tensor.squeeze_(d-i) 28 | return tensor 29 | 30 | 31 | def mean(tensor, dim=None, keepdim=False): 32 | if dim is None: 33 | # mean all dim 34 | return torch.mean(tensor) 35 | else: 36 | if isinstance(dim, int): 37 | dim = [dim] 38 | dim = sorted(dim) 39 | for d in dim: 40 | tensor = tensor.mean(dim=d, keepdim=True) 41 | if not keepdim: 42 | for i, d in enumerate(dim): 43 | tensor.squeeze_(d-i) 44 | return tensor 45 | 46 | 47 | def split_feature(tensor, type="split"): 48 | """ 49 | type = ["split", "cross"] 50 | """ 51 | C = tensor.size(1) 52 | if type == "split": 53 | return tensor[:, :C // 2, ...], tensor[:, C // 2:, ...] 54 | elif type == "cross": 55 | return tensor[:, 0::2, ...], tensor[:, 1::2, ...] 56 | 57 | 58 | def cat_feature(tensor_a, tensor_b): 59 | return torch.cat((tensor_a, tensor_b), dim=1) 60 | 61 | 62 | def pixels(tensor): 63 | return int(tensor.size(2) * tensor.size(3)) 64 | -------------------------------------------------------------------------------- /src/util/optim_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | import torch.nn.utils as utils 4 | 5 | 6 | def bits_per_dim(x, nll): 7 | """Get the bits per dimension implied by using model with `loss` 8 | for compressing `x`, assuming each entry can take on `k` discrete values. 9 | 10 | Args: 11 | x (torch.Tensor): Input to the model. Just used for dimensions. 12 | nll (torch.Tensor): Scalar negative log-likelihood loss tensor. 13 | 14 | Returns: 15 | bpd (torch.Tensor): Bits per dimension implied if compressing `x`. 
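        Example: for CIFAR-sized inputs x of shape (B, 3, 32, 32),
        dim = 3 * 32 * 32 = 3072, so a per-example nll of 6000 nats
        gives bpd = 6000 / (ln(2) * 3072) ~= 2.82.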
16 | """ 17 | dim = np.prod(x.size()[1:]) 18 | bpd = nll / (np.log(2) * dim) 19 | 20 | return bpd 21 | 22 | 23 | def clip_grad_norm(optimizer, max_norm, norm_type=2): 24 | """Clip the norm of the gradients for all parameters under `optimizer`. 25 | 26 | Args: 27 | optimizer (torch.optim.Optimizer): 28 | max_norm (float): The maximum allowable norm of gradients. 29 | norm_type (int): The type of norm to use in computing gradient norms. 30 | """ 31 | for group in optimizer.param_groups: 32 | utils.clip_grad_norm_(group['params'], max_norm, norm_type) 33 | 34 | 35 | class NLLLoss(nn.Module): 36 | """Negative log-likelihood loss assuming isotropic gaussian with unit norm. 37 | 38 | Args: 39 | k (int or float): Number of discrete values in each input dimension. 40 | E.g., `k` is 256 for natural images. 41 | 42 | See Also: 43 | Equation (3) in the RealNVP paper: https://arxiv.org/abs/1605.08803 44 | """ 45 | def __init__(self, k=256): 46 | super(NLLLoss, self).__init__() 47 | self.k = k 48 | 49 | def forward(self, z, sldj): 50 | prior_ll = -0.5 * (z ** 2 + np.log(2 * np.pi)) 51 | prior_ll = prior_ll.flatten(1).sum(-1) \ 52 | - np.log(self.k) * np.prod(z.size()[1:]) 53 | ll = prior_ll + sldj 54 | nll = -ll.mean() 55 | 56 | return nll 57 | -------------------------------------------------------------------------------- /src/flows/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def logsumexp_1p(s): 6 | # numerically stable implementation of log sigmoid via logsumexp 7 | # NOTE: this assumes that you feed in -s for the positive probabilities 8 | if len(s.size()) == 2: 9 | s = s.squeeze() 10 | x = torch.stack([s, torch.zeros_like(s)], -1) 11 | val, _ = torch.max(x, 1) 12 | val = val.repeat(2,1).T 13 | logsigmoid = torch.clamp(s, 0) + torch.log( 14 | torch.sum(torch.exp(x - val), 1)) 15 | 16 | return -logsigmoid 17 | 18 | 19 | def classify_examples(model, config): 20 | """ 21 | classifies generated samples into appropriate classes 22 | NOTE: unused atm 23 | """ 24 | model.eval() 25 | preds = [] 26 | samples = np.load(config['sample_path'])['x'] 27 | n_batches = samples.shape[0] // 1000 28 | 29 | with torch.no_grad(): 30 | # generate 10K samples 31 | for i in range(n_batches): 32 | x = samples[i*1000:(i+1)*1000] 33 | samp = x / 255. 
# renormalize to feed into classifier 34 | samp = torch.from_numpy(samp).to('cuda').float() 35 | 36 | # get classifier predictions 37 | logits, probas = model(samp) 38 | _, pred = torch.max(probas, 1) 39 | preds.append(pred) 40 | preds = torch.cat(preds).data.cpu().numpy() 41 | 42 | return preds 43 | 44 | 45 | def fairness_discrepancy(data, n_classes): 46 | """ 47 | computes fairness discrepancy metric for single or multi-attribute 48 | this metric computes L2, L1, AND KL-total variation distance 49 | """ 50 | unique, freq = np.unique(data, return_counts=True) 51 | props = freq / len(data) 52 | truth = 1./n_classes 53 | 54 | # L2 and L1 55 | l2_fair_d = np.sqrt(((props - truth)**2).sum()) 56 | l1_fair_d = abs(props - truth).sum() 57 | 58 | # q = props, p = truth 59 | kl_fair_d = (props * (np.log(props) - np.log(truth))).sum() 60 | 61 | return l2_fair_d, l1_fair_d, kl_fair_d -------------------------------------------------------------------------------- /src/flows/models/flowplusplus/coupling.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from models.flowplusplus import log_dist as logistic 4 | from models.flowplusplus.nn import NN 5 | 6 | 7 | class Coupling(nn.Module): 8 | """Mixture-of-Logistics Coupling layer in Flow++ 9 | 10 | Args: 11 | in_channels (int): Number of channels in the input. 12 | mid_channels (int): Number of channels in the transformation network. 13 | num_blocks (int): Number of residual blocks in the transformation network. 14 | num_components (int): Number of components in the mixture. 15 | drop_prob (float): Dropout probability. 16 | use_attn (bool): Use attention in the NN blocks. 17 | aux_channels (int): Number of channels in optional auxiliary input. 18 | """ 19 | def __init__(self, in_channels, mid_channels, num_blocks, num_components, drop_prob, 20 | use_attn=True, aux_channels=None): 21 | super(Coupling, self).__init__() 22 | self.nn = NN(in_channels, mid_channels, num_blocks, num_components, drop_prob, use_attn, aux_channels) 23 | 24 | def forward(self, x, sldj=None, reverse=False, aux=None): 25 | x_change, x_id = x 26 | a, b, pi, mu, s = self.nn(x_id, aux) 27 | 28 | if reverse: 29 | out = x_change * a.mul(-1).exp() - b 30 | out, scale_ldj = logistic.inverse(out, reverse=True) 31 | out = out.clamp(1e-5, 1. - 1e-5) 32 | out = logistic.mixture_inv_cdf(out, pi, mu, s) 33 | logistic_ldj = logistic.mixture_log_pdf(out, pi, mu, s) 34 | sldj = sldj - (a + scale_ldj + logistic_ldj).flatten(1).sum(-1) 35 | else: 36 | out = logistic.mixture_log_cdf(x_change, pi, mu, s).exp() 37 | out, scale_ldj = logistic.inverse(out) 38 | out = (out + b) * a.exp() 39 | logistic_ldj = logistic.mixture_log_pdf(x_change, pi, mu, s) 40 | sldj = sldj + (logistic_ldj + scale_ldj + a).flatten(1).sum(-1) 41 | 42 | x = (out, x_id) 43 | 44 | return x, sldj 45 | -------------------------------------------------------------------------------- /src/util/norm_util.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def get_param_groups(net, weight_decay, norm_suffix='weight_g', verbose=False): 5 | """Get two parameter groups from `net`: One named "normalized" which will 6 | override the optimizer with `weight_decay`, and one named "unnormalized" 7 | which will inherit all hyperparameters from the optimizer. 8 | Args: 9 | net (torch.nn.Module): Network to get parameters from 10 | weight_decay (float): Weight decay to apply to normalized weights. 
11 | norm_suffix (str): Suffix to select weights that should be normalized. 12 | For WeightNorm, using 'weight_g' normalizes the scale variables. 13 | verbose (bool): Print out number of normalized and unnormalized parameters. 14 | """ 15 | norm_params = [] 16 | unnorm_params = [] 17 | for n, p in net.named_parameters(): 18 | if n.endswith(norm_suffix): 19 | norm_params.append(p) 20 | else: 21 | unnorm_params.append(p) 22 | 23 | param_groups = [{'name': 'normalized', 'params': norm_params, 'weight_decay': weight_decay}, 24 | {'name': 'unnormalized', 'params': unnorm_params}] 25 | 26 | if verbose: 27 | print('{} normalized parameters'.format(len(norm_params))) 28 | print('{} unnormalized parameters'.format(len(unnorm_params))) 29 | 30 | return param_groups 31 | 32 | 33 | class WNConv2d(nn.Module): 34 | """Weight-normalized 2d convolution. 35 | Args: 36 | in_channels (int): Number of channels in the input. 37 | out_channels (int): Number of channels in the output. 38 | kernel_size (int): Side length of each convolutional kernel. 39 | padding (int): Padding to add on edges of input. 40 | bias (bool): Use bias in the convolution operation. 41 | """ 42 | def __init__(self, in_channels, out_channels, kernel_size, padding, bias=True): 43 | super(WNConv2d, self).__init__() 44 | self.conv = nn.utils.weight_norm( 45 | nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=bias)) 46 | 47 | def forward(self, x): 48 | x = self.conv(x) 49 | 50 | return x 51 | -------------------------------------------------------------------------------- /src/classification/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import numpy as np 4 | from collections import OrderedDict 5 | from pathlib import Path  # OrderedDict and Path are needed by read_json/write_json below 6 | def logit_transform(image, lam=1e-6): 7 | image = lam + (1 - 2 * lam) * image 8 | return torch.log(image) - torch.log1p(-image) 9 | 10 | 11 | def maf_preprocess(x): 12 | x = (x * 255).byte() 13 | # performs dequantization, rescaling, then logit transform 14 | x = (x + torch.rand(x.size()).to(x.device)) / 256.
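    # NOTE: adding Uniform(0, 1) noise to the integer pixel values and dividing
    # by 256 is standard uniform dequantization; the logit transform on the next
    # line then maps (0, 1) onto the real line, so the flow models an unbounded
    # density.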
15 | x = logit_transform(x) 16 | return x 17 | 18 | 19 | def glow_preprocess(x): 20 | # Follows: 21 | # https://github.com/tensorflow/tensor2tensor/blob/e48cf23c505565fd63378286d9722a1632f4bef7/tensor2tensor/models/research/glow.py#L78 22 | n_bits = 8 23 | x = x * 255 # undo ToTensor scaling to [0,1] 24 | 25 | n_bins = 2 ** n_bits 26 | if n_bits < 8: 27 | x = torch.floor(x / 2 ** (8 - n_bits)) 28 | x = x / n_bins - 0.5 29 | 30 | return x 31 | 32 | 33 | def logsumexp_1p(s): 34 | # numerically stable implementation of log sigmoid via logsumexp 35 | # NOTE: this assumes that you feed in -s for the positive probabilities 36 | if len(s.size()) == 2: 37 | s = s.squeeze() 38 | x = torch.stack([s, torch.zeros_like(s)], -1) 39 | val, _ = torch.max(x, 1) 40 | val = val.repeat(2,1).T 41 | logsigmoid = torch.clamp(s, 0) + torch.log( 42 | torch.sum(torch.exp(x - val), 1)) 43 | 44 | return -logsigmoid 45 | 46 | 47 | class AverageMeter(object): 48 | """Computes and stores the average and current value""" 49 | 50 | def __init__(self): 51 | self.reset() 52 | 53 | def reset(self): 54 | self.val = 0 55 | self.avg = 0 56 | self.sum = 0 57 | self.count = 0 58 | 59 | def update(self, val, n=1): 60 | self.val = val 61 | self.sum += val * n 62 | self.count += n 63 | self.avg = self.sum / self.count 64 | 65 | 66 | def read_json(fname): 67 | fname = Path(fname) 68 | with fname.open('rt') as handle: 69 | return json.load(handle, object_hook=OrderedDict) 70 | 71 | 72 | def write_json(content, fname): 73 | fname = Path(fname) 74 | with fname.open('wt') as handle: 75 | json.dump(content, handle, indent=4, sort_keys=False) -------------------------------------------------------------------------------- /src/flows/models/glow/invconv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class Invertible1x1Conv(nn.Module): 6 | """ Invertible 1x1 convolution layer; cf Glow section 3.2 """ 7 | def __init__(self, n_channels=3, lu_factorize=False): 8 | super().__init__() 9 | self.lu_factorize = lu_factorize 10 | 11 | # initialize a 1x1 convolution weight matrix 12 | w = torch.randn(n_channels, n_channels) 13 | w = torch.qr(w)[0] # note: nn.init.orthogonal_ returns orth matrices with dets +/- 1 which complicates the inverse call below 14 | 15 | if lu_factorize: 16 | # compute LU factorization 17 | p, l, u = torch.btriunpack(*w.unsqueeze(0).btrifact()) 18 | # initialize model parameters 19 | self.p, self.l, self.u = nn.Parameter(p.squeeze()), nn.Parameter(l.squeeze()), nn.Parameter(u.squeeze()) 20 | s = self.u.diag() 21 | self.log_s = nn.Parameter(s.abs().log()) 22 | self.register_buffer('sign_s', s.sign()) # note: not optimizing the sign; det W remains the same sign 23 | self.register_buffer('l_mask', torch.tril(torch.ones_like(self.l), -1)) # store mask to compute LU in forward/inverse pass 24 | else: 25 | self.w = nn.Parameter(w) 26 | 27 | def forward(self, x): 28 | B,C,H,W = x.shape 29 | if self.lu_factorize: 30 | l = self.l * self.l_mask + torch.eye(C).to(self.l.device) 31 | u = self.u * self.l_mask.t() + torch.diag(self.sign_s * self.log_s.exp()) 32 | self.w = self.p @ l @ u 33 | logdet = self.log_s.sum() * H * W 34 | else: 35 | logdet = torch.slogdet(self.w)[-1] * H * W 36 | 37 | return F.conv2d(x, self.w.view(C,C,1,1)), logdet 38 | 39 | def inverse(self, z): 40 | B,C,H,W = z.shape 41 | if self.lu_factorize: 42 | l = torch.inverse(self.l * self.l_mask + torch.eye(C).to(self.l.device)) 43 | u =
torch.inverse(self.u * self.l_mask.t() + torch.diag(self.sign_s * self.log_s.exp())) 44 | w_inv = u @ l @ self.p.inverse() 45 | logdet = - self.log_s.sum() * H * W 46 | else: 47 | w_inv = self.w.inverse() 48 | logdet = - torch.slogdet(self.w)[-1] * H * W 49 | 50 | return F.conv2d(z, w_inv.view(C,C,1,1)), logdet -------------------------------------------------------------------------------- /src/datasets/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Select dataset functions from MAF repo 3 | https://github.com/gpapamak/maf/blob/master/util.py 4 | """ 5 | 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | 9 | def plot_hist_marginals(data, lims=None, gt=None): 10 | """ 11 | Plots marginal histograms and pairwise scatter plots of a dataset. 12 | """ 13 | 14 | n_bins = int(np.sqrt(data.shape[0])) 15 | 16 | if data.ndim == 1: 17 | 18 | fig, ax = plt.subplots(1, 1) 19 | ax.hist(data, n_bins, density=True) 20 | ax.set_ylim([0, ax.get_ylim()[1]]) 21 | if lims is not None: ax.set_xlim(lims) 22 | if gt is not None: ax.vlines(gt, 0, ax.get_ylim()[1], color='r') 23 | 24 | else: 25 | 26 | n_dim = data.shape[1] 27 | fig, ax = plt.subplots(n_dim, n_dim) 28 | ax = np.array([[ax]]) if n_dim == 1 else ax 29 | 30 | if lims is not None: 31 | lims = np.asarray(lims) 32 | lims = np.tile(lims, [n_dim, 1]) if lims.ndim == 1 else lims 33 | 34 | for i in range(n_dim): 35 | for j in range(n_dim): 36 | 37 | if i == j: 38 | ax[i, j].hist(data[:, i], n_bins, density=True) 39 | ax[i, j].set_ylim([0, ax[i, j].get_ylim()[1]]) 40 | if lims is not None: ax[i, j].set_xlim(lims[i]) 41 | if gt is not None: ax[i, j].vlines(gt[i], 0, ax[i, j].get_ylim()[1], color='r') 42 | 43 | else: 44 | ax[i, j].plot(data[:, i], data[:, j], 'k.', ms=2) 45 | if lims is not None: 46 | ax[i, j].set_xlim(lims[i]) 47 | ax[i, j].set_ylim(lims[j]) 48 | if gt is not None: ax[i, j].plot(gt[i], gt[j], 'r.', ms=8) 49 | 50 | plt.show(block=False) 51 | 52 | return fig, ax 53 | 54 | 55 | def one_hot_encode(labels, n_labels): 56 | """ 57 | Transforms numeric labels to 1-hot encoded labels. Assumes numeric labels are in the range 0, 1, ..., n_labels-1. 58 | """ 59 | 60 | assert np.min(labels) >= 0 and np.max(labels) < n_labels 61 | 62 | y = np.zeros([labels.size, n_labels]) 63 | y[range(labels.size), labels] = 1 64 | 65 | return y 66 | 67 | def logit(x): 68 | """ 69 | Elementwise logit (inverse logistic sigmoid).
70 | :param x: numpy array 71 | :return: numpy array 72 | """ 73 | return np.log(x / (1.0 - x)) 74 | -------------------------------------------------------------------------------- /src/flows/models/glow/coupling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from models.glow.actnorm import ActNorm 6 | 7 | class AffineCoupling(nn.Module): 8 | """ Affine coupling layer; cf Glow section 3.3; RealNVP figure 2 """ 9 | def __init__(self, n_channels, width): 10 | super().__init__() 11 | # network layers; 12 | # per realnvp, network splits input, operates on half of it, and returns shift and scale of dim = half the input channels 13 | self.conv1 = nn.Conv2d(n_channels//2, width, kernel_size=3, padding=1, bias=False) # input is split along channel dim 14 | self.actnorm1 = ActNorm(param_dim=(1, width, 1, 1)) 15 | self.conv2 = nn.Conv2d(width, width, kernel_size=1, padding=1, bias=False) 16 | self.actnorm2 = ActNorm(param_dim=(1, width, 1, 1)) 17 | self.conv3 = nn.Conv2d(width, n_channels, kernel_size=3) # output is split into scale and shift components 18 | self.log_scale_factor = nn.Parameter(torch.zeros(n_channels,1,1)) # learned scale (cf RealNVP sec 4.1 / Glow official code 19 | 20 | # initialize last convolution with zeros, such that each affine coupling layer performs an identity function 21 | self.conv3.weight.data.zero_() 22 | self.conv3.bias.data.zero_() 23 | 24 | def forward(self, x): 25 | x_a, x_b = x.chunk(2, 1) # split along channel dim 26 | 27 | h = F.relu(self.actnorm1(self.conv1(x_b))[0]) 28 | h = F.relu(self.actnorm2(self.conv2(h))[0]) 29 | h = self.conv3(h) * self.log_scale_factor.exp() 30 | t = h[:,0::2,:,:] # shift; take even channels 31 | s = h[:,1::2,:,:] # scale; take odd channels 32 | s = torch.sigmoid(s + 2.) # at initalization, s is 0 and sigmoid(2) is near identity 33 | 34 | z_a = s * x_a + t 35 | z_b = x_b 36 | z = torch.cat([z_a, z_b], dim=1) # concat along channel dim 37 | 38 | logdet = s.log().sum([1, 2, 3]) 39 | 40 | return z, logdet 41 | 42 | def inverse(self, z): 43 | z_a, z_b = z.chunk(2, 1) # split along channel dim 44 | 45 | h = F.relu(self.actnorm1(self.conv1(z_b))[0]) 46 | h = F.relu(self.actnorm2(self.conv2(h))[0]) 47 | h = self.conv3(h) * self.log_scale_factor.exp() 48 | t = h[:,0::2,:,:] # shift; take even channels 49 | s = h[:,1::2,:,:] # scale; take odd channels 50 | s = torch.sigmoid(s + 2.) 51 | 52 | x_a = (z_a - t) / s 53 | x_b = z_b 54 | x = torch.cat([x_a, x_b], dim=1) # concat along channel dim 55 | 56 | logdet = - s.log().sum([1, 2, 3]) 57 | 58 | return x, logdet -------------------------------------------------------------------------------- /src/flows/trainers/glow1.py: -------------------------------------------------------------------------------- 1 | """Train script. 
2 | 3 | Usage: 4 | train.py <hparams> <dataset> <dataset_root> <z_dir> [--encode] 5 | """ 6 | import os 7 | import torch 8 | import vision 9 | from docopt import docopt 10 | from torchvision import transforms 11 | from torch.utils.data import DataLoader 12 | 13 | from glow1.builder import build 14 | from glow1.trainer import Trainer 15 | from glow1.config import JsonConfig 16 | 17 | 18 | if __name__ == "__main__": 19 | args = docopt(__doc__) 20 | hparams = args["<hparams>"] 21 | dataset = args["<dataset>"] 22 | dataset_root = args["<dataset_root>"] 23 | z_dir = args["<z_dir>"] 24 | encode = args["--encode"] 25 | 26 | assert dataset in vision.Datasets, ( 27 | "`{}` is not supported, use `{}`".format(dataset, vision.Datasets.keys())) 28 | assert os.path.exists(dataset_root), ( 29 | "Failed to find root dir `{}` of dataset.".format(dataset_root)) 30 | assert os.path.exists(hparams), ( 31 | "Failed to find hparams json `{}`".format(hparams)) 32 | hparams = JsonConfig(hparams) 33 | dataset = vision.Datasets[dataset] 34 | # set transform of dataset 35 | transform = transforms.Compose([ 36 | transforms.CenterCrop(hparams.Data.center_crop), 37 | transforms.Resize(hparams.Data.resize), 38 | transforms.ToTensor()]) 39 | 40 | is_train = not encode 41 | # build graph and dataset 42 | built = build(hparams, is_train) 43 | dataset = dataset(dataset_root, transform=transform) 44 | 45 | if is_train: 46 | # begin to train 47 | trainer = Trainer(**built, dataset=dataset, hparams=hparams) 48 | trainer.train() 49 | elif encode: 50 | from copy import deepcopy 51 | 52 | # pretrained model 53 | model = built["graph"] 54 | 55 | loader = DataLoader(dataset, batch_size=hparams.Train.batch_size) 56 | ys = [] 57 | idx = 1 58 | save_dir = os.path.join(z_dir, 'encodings') 59 | os.makedirs(save_dir, exist_ok=True) 60 | for i, batch in enumerate(loader): 61 | if i % 20 == 0: 62 | print(f'Encoding batch #{i}') 63 | x_batch = batch["x"] 64 | print('x_batch.shape: ', x_batch.shape) 65 | y_batch = batch["y_onehot"] 66 | for x in x_batch: 67 | z = model.generate_z(x) 68 | save_file = os.path.join(save_dir, f'{idx}.pt') 69 | torch.save(z, save_file) 70 | idx += 1 71 | ys.append(deepcopy(y_batch)) 72 | save_file = os.path.join(z_dir, f'labels.pt') 73 | ys = torch.cat(ys) 74 | torch.save(ys, save_file) 75 | 76 | -------------------------------------------------------------------------------- /src/classification/models/omni_classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import numpy as np 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | 8 | 9 | def convLayer(in_channels, out_channels, keep_prob=0.0): 10 | """3x3 convolution with padding; each call halves the spatial size via the max-pool""" 11 | cnn_seq = nn.Sequential( 12 | nn.Conv2d(in_channels, out_channels, 3, 1, 1), 13 | nn.ReLU(True), 14 | nn.BatchNorm2d(out_channels), 15 | nn.MaxPool2d(kernel_size=2, stride=2), 16 | nn.Dropout(keep_prob) 17 | ) 18 | return cnn_seq 19 | 20 | 21 | class Classifier(nn.Module): 22 | def __init__(self, layer_size=64, num_channels=1, keep_prob=1.0, image_size=28): 23 | super(Classifier, self).__init__() 24 | """ 25 | Build a CNN to produce embeddings 26 | :param layer_size:64(default) 27 | :param num_channels: 28 | :param keep_prob: 29 | :param image_size: 30 | """ 31 | self.layer1 = convLayer(num_channels, layer_size, keep_prob) 32 | self.layer2 = convLayer(layer_size, layer_size, keep_prob) 33 | self.layer3 = convLayer(layer_size, layer_size, keep_prob) 34 | self.layer4 =
convLayer(layer_size, layer_size, keep_prob) 35 | 36 | finalSize = int(math.floor(image_size / (2 * 2 * 2 * 2))) 37 | self.outSize = finalSize * finalSize * layer_size 38 | 39 | def forward(self, image_input): 40 | """ 41 | Use CNN defined above 42 | :param image_input: 43 | :return: 44 | """ 45 | x = self.layer1(image_input) 46 | x = self.layer2(x) 47 | x = self.layer3(x) 48 | x = self.layer4(x) 49 | x = x.view(x.size()[0], -1) 50 | return x 51 | 52 | class TREClassifier(nn.Module): 53 | def __init__(self, layer_size=64, num_channels=1, keep_prob=1.0, image_size=28): 54 | super(TREClassifier, self).__init__() 55 | 56 | self.layer1 = convLayer(num_channels, layer_size, keep_prob) 57 | self.layer2 = convLayer(layer_size, layer_size, keep_prob) 58 | self.layer3 = convLayer(layer_size, layer_size, keep_prob) 59 | self.layer4 = convLayer(layer_size, layer_size, keep_prob) 60 | 61 | finalSize = int(math.floor(image_size / (2 * 2 * 2 * 2))) 62 | self.outSize = finalSize * finalSize * layer_size 63 | 64 | def forward(self, image_input): 65 | """ 66 | Use CNN defined above 67 | :param image_input: 68 | :return: 69 | """ 70 | x = self.layer1(image_input) 71 | x = self.layer2(x) 72 | x = self.layer3(x) 73 | x = self.layer4(x) 74 | x = x.view(x.size()[0], -1) 75 | return x -------------------------------------------------------------------------------- /src/flows/models/flowplusplus/log_dist.py: -------------------------------------------------------------------------------- 1 | """Logistic distribution functions.""" 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from util import safe_log 6 | 7 | 8 | def _log_pdf(x, mean, log_scale): 9 | """Element-wise log density of the logistic distribution.""" 10 | z = (x - mean) * torch.exp(-log_scale) 11 | log_p = z - log_scale - 2 * F.softplus(z) 12 | 13 | return log_p 14 | 15 | 16 | def _log_cdf(x, mean, log_scale): 17 | """Element-wise log CDF of the logistic distribution.""" 18 | z = (x - mean) * torch.exp(-log_scale) 19 | log_p = F.logsigmoid(z) 20 | 21 | return log_p 22 | 23 | 24 | def mixture_log_pdf(x, prior_logits, means, log_scales): 25 | """Log PDF of a mixture of logistic distributions.""" 26 | log_ps = F.log_softmax(prior_logits, dim=1) \ 27 | + _log_pdf(x.unsqueeze(1), means, log_scales) 28 | log_p = torch.logsumexp(log_ps, dim=1) 29 | 30 | return log_p 31 | 32 | 33 | def mixture_log_cdf(x, prior_logits, means, log_scales): 34 | """Log CDF of a mixture of logistic distributions.""" 35 | log_ps = F.log_softmax(prior_logits, dim=1) \ 36 | + _log_cdf(x.unsqueeze(1), means, log_scales) 37 | log_p = torch.logsumexp(log_ps, dim=1) 38 | 39 | return log_p 40 | 41 | 42 | def mixture_inv_cdf(y, prior_logits, means, log_scales, 43 | eps=1e-10, max_iters=100): 44 | """Inverse CDF of a mixture of logisitics. Iterative algorithm.""" 45 | if y.min() <= 0 or y.max() >= 1: 46 | raise RuntimeError('Inverse logisitic CDF got y outside (0, 1)') 47 | 48 | def body(x_, lb_, ub_): 49 | cur_y = torch.exp(mixture_log_cdf(x_, prior_logits, means, 50 | log_scales)) 51 | gt = (cur_y > y).type(y.dtype) 52 | lt = 1 - gt 53 | new_x_ = gt * (x_ + lb_) / 2. + lt * (x_ + ub_) / 2. 
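        # NOTE: standard bisection update -- where the current CDF value
        # overshoots y (gt == 1), the midpoint above moves into the lower half
        # [lb_, x_]; where it undershoots, it moves into the upper half [x_, ub_].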
54 | new_lb = gt * lb_ + lt * x_ 55 | new_ub = gt * x_ + lt * ub_ 56 | return new_x_, new_lb, new_ub 57 | 58 | x = torch.zeros_like(y) 59 | max_scales = torch.sum(torch.exp(log_scales), dim=1, keepdim=True) 60 | lb, _ = (means - 20 * max_scales).min(dim=1) 61 | ub, _ = (means + 20 * max_scales).max(dim=1) 62 | diff = float('inf') 63 | 64 | i = 0 65 | while diff > eps and i < max_iters: 66 | new_x, lb, ub = body(x, lb, ub) 67 | diff = (new_x - x).abs().max() 68 | x = new_x 69 | i += 1 70 | 71 | return x 72 | 73 | 74 | def inverse(x, reverse=False): 75 | """Inverse logistic function.""" 76 | if reverse: 77 | z = torch.sigmoid(x) 78 | ldj = F.softplus(x) + F.softplus(-x) 79 | else: 80 | z = -safe_log(x.reciprocal() - 1.) 81 | ldj = -safe_log(x) - safe_log(1. - x) 82 | 83 | return z, ldj 84 | -------------------------------------------------------------------------------- /src/classification/models/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class MLPClassifierv2(nn.Module): 7 | """ 8 | simple MLP classifier (e.g. for classifying in z-space) 9 | slightly deeper than MLPClassifier 10 | """ 11 | def __init__(self, config): 12 | super(MLPClassifierv2, self).__init__() 13 | self.config = config 14 | self.h_dim = config.model.h_dim 15 | self.n_classes = config.model.n_classes 16 | self.in_dim = config.model.in_dim 17 | 18 | self.fc1 = nn.Linear(self.in_dim, self.h_dim) 19 | self.fc2 = nn.Linear(self.h_dim, self.h_dim) 20 | self.fc3 = nn.Linear(self.h_dim, self.h_dim) 21 | self.fc4 = nn.Linear(self.h_dim, 1) 22 | 23 | def forward(self, x): 24 | x = F.relu(self.fc1(x)) 25 | x = F.relu(self.fc2(x)) 26 | x = F.relu(self.fc3(x)) 27 | logits = self.fc4(x) 28 | probas = torch.sigmoid(logits) 29 | 30 | return logits, probas 31 | 32 | 33 | class TREMLPClassifier(nn.Module): 34 | def __init__(self, config): 35 | super(TREMLPClassifier, self).__init__() 36 | self.config = config 37 | self.h_dim = config.model.h_dim 38 | self.n_classes = config.model.n_classes 39 | self.in_dim = config.model.in_dim 40 | 41 | # m = number of bridges (intermediate ratios) 42 | self.m = config.tre.m 43 | self.p = config.tre.p 44 | 45 | self.fc1 = nn.Linear(self.in_dim, self.h_dim) 46 | self.fc2 = nn.Linear(self.h_dim, self.h_dim) 47 | self.fc3 = nn.Linear(self.h_dim, self.h_dim) 48 | 49 | # bridge-specific heads 50 | self.fc4s = nn.ModuleList( 51 | [nn.Linear(self.h_dim, 1) for _ in range(self.m)]) 52 | 53 | # do we need this? 54 | # for w in self.fc4s: 55 | # nn.init.xavier_normal_(w.weight) 56 | 57 | def forward(self, x): 58 | ''' 59 | Returns logits, probas where len(logits) = len(probas) = m 60 | ''' 61 | 62 | x = F.relu(self.fc1(x)) 63 | x = F.relu(self.fc2(x)) 64 | x = F.relu(self.fc3(x)) 65 | 66 | # separate xs into m chunks 67 | # xs = [torch.cat([x[i], x[i+1]]) for i in range(self.m)] 68 | xs = [x for _ in range(self.m)] 69 | logits = [] 70 | probas = [] 71 | for x, fc4 in zip(xs, self.fc4s): 72 | curr_logits = F.relu(fc4(x)) # quadratic head 73 | curr_probas = torch.sigmoid(curr_logits) 74 | 75 | logits.append(curr_logits) 76 | 77 | logits = torch.stack(logits) 78 | probas = torch.sigmoid(logits) 79 | 80 | return logits, probas 81 | 82 | 83 | class MLPClassifier(nn.Module): 84 | """ 85 | simple MLP classifier (e.g. 
for classifying in z-space) 86 | """ 87 | def __init__(self, config): 88 | super(MLPClassifier, self).__init__() 89 | self.config = config 90 | self.h_dim = config.model.h_dim 91 | self.n_classes = config.model.n_classes 92 | self.in_dim = config.model.in_dim 93 | 94 | self.fc1 = nn.Linear(self.in_dim, self.h_dim) 95 | self.fc2 = nn.Linear(self.h_dim, self.h_dim) 96 | self.fc3 = nn.Linear(self.h_dim, self.n_classes) 97 | 98 | def forward(self, x): 99 | x = F.relu(self.fc1(x)) 100 | x = F.relu(self.fc2(x)) 101 | logits = self.fc3(x) 102 | probas = F.softmax(logits, dim=1) 103 | 104 | return logits, probas -------------------------------------------------------------------------------- /src/flows/trainers/vision/datasets/celeba.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import numpy as np 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | 8 | 9 | IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".bmp"] 10 | ATTR_ANNO = "list_attr_celeba.txt" 11 | 12 | def _is_image(fname): 13 | _, ext = os.path.splitext(fname) 14 | return ext.lower() in IMAGE_EXTENSIONS 15 | 16 | 17 | def _find_images_and_annotation(root_dir): 18 | images = {} 19 | attr = None 20 | assert os.path.exists(root_dir), "{} does not exist".format(root_dir) 21 | for root, _, fnames in sorted(os.walk(root_dir)): 22 | for fname in sorted(fnames): 23 | if _is_image(fname): 24 | path = os.path.join(root, fname) 25 | images[os.path.splitext(fname)[0]] = path 26 | elif fname.lower() == ATTR_ANNO: 27 | attr = os.path.join(root, fname) 28 | 29 | assert attr is not None, "Failed to find `list_attr_celeba.txt`" 30 | 31 | # parse the attribute annotations for all images 32 | print("Begin to parse all image attrs") 33 | final = [] 34 | with open(attr, "r") as fin: 35 | image_total = 0 36 | attrs = [] 37 | for i_line, line in enumerate(fin): 38 | line = line.strip() 39 | if i_line == 0: 40 | image_total = int(line) 41 | elif i_line == 1: 42 | attrs = line.split(" ") 43 | else: 44 | line = re.sub("[ ]+", " ", line) 45 | line = line.split(" ") 46 | fname = os.path.splitext(line[0])[0] 47 | onehot = [int(int(d) > 0) for d in line[1:]] 48 | assert len(onehot) == len(attrs), "{} only has {} attrs < {}".format( 49 | fname, len(onehot), len(attrs)) 50 | final.append({ 51 | "path": images[fname], 52 | "attr": onehot 53 | }) 54 | print("Found {} images with {} attrs".format(len(final), len(attrs))) 55 | return final, attrs 56 | 57 | 58 | class CelebADataset(Dataset): 59 | def __init__(self, root_dir, transform=transforms.Compose([ 60 | transforms.CenterCrop(160), 61 | transforms.Resize(32), 62 | transforms.ToTensor()])): 63 | super().__init__() 64 | dicts, attrs = _find_images_and_annotation(root_dir) 65 | self.data = dicts 66 | self.attrs = attrs 67 | self.transform = transform 68 | 69 | def __getitem__(self, index): 70 | data = self.data[index] 71 | path = data["path"] 72 | attr = data["attr"] 73 | image = Image.open(path).convert("RGB") 74 | if self.transform is not None: 75 | image = self.transform(image) 76 | return { 77 | "x": image, 78 | "y_onehot": np.asarray(attr, dtype=np.float32) 79 | } 80 | 81 | def __len__(self): 82 | return len(self.data) 83 | 84 | 85 | if __name__ == "__main__": 86 | import cv2 87 | celeba = CelebADataset("/home/chaiyujin/Downloads/Dataset/CelebA") 88 | d = celeba[0] 89 | print(d["x"].size()) 90 | img = d["x"].permute(1, 2, 0).contiguous().numpy() 91 | print(np.min(img), np.max(img)) 92 | cv2.imshow("img", img) 93 |
cv2.waitKey() 94 | -------------------------------------------------------------------------------- /src/flows/functions/ckpt_util.py: -------------------------------------------------------------------------------- 1 | import os, hashlib 2 | import requests 3 | from tqdm import tqdm 4 | 5 | URL_MAP = { 6 | "cifar10": "https://heibox.uni-heidelberg.de/f/869980b53bf5416c8a28/?dl=1", 7 | "ema_cifar10": "https://heibox.uni-heidelberg.de/f/2e4f01e2d9ee49bab1d5/?dl=1", 8 | "lsun_bedroom": "https://heibox.uni-heidelberg.de/f/f179d4f21ebc4d43bbfe/?dl=1", 9 | "ema_lsun_bedroom": "https://heibox.uni-heidelberg.de/f/b95206528f384185889b/?dl=1", 10 | "lsun_cat": "https://heibox.uni-heidelberg.de/f/fac870bd988348eab88e/?dl=1", 11 | "ema_lsun_cat": "https://heibox.uni-heidelberg.de/f/0701aac3aa69457bbe34/?dl=1", 12 | "lsun_church": "https://heibox.uni-heidelberg.de/f/2711a6f712e34b06b9d8/?dl=1", 13 | "ema_lsun_church": "https://heibox.uni-heidelberg.de/f/44ccb50ef3c6436db52e/?dl=1", 14 | } 15 | CKPT_MAP = { 16 | "cifar10": "diffusion_cifar10_model/model-790000.ckpt", 17 | "ema_cifar10": "ema_diffusion_cifar10_model/model-790000.ckpt", 18 | "lsun_bedroom": "diffusion_lsun_bedroom_model/model-2388000.ckpt", 19 | "ema_lsun_bedroom": "ema_diffusion_lsun_bedroom_model/model-2388000.ckpt", 20 | "lsun_cat": "diffusion_lsun_cat_model/model-1761000.ckpt", 21 | "ema_lsun_cat": "ema_diffusion_lsun_cat_model/model-1761000.ckpt", 22 | "lsun_church": "diffusion_lsun_church_model/model-4432000.ckpt", 23 | "ema_lsun_church": "ema_diffusion_lsun_church_model/model-4432000.ckpt", 24 | } 25 | MD5_MAP = { 26 | "cifar10": "82ed3067fd1002f5cf4c339fb80c4669", 27 | "ema_cifar10": "1fa350b952534ae442b1d5235cce5cd3", 28 | "lsun_bedroom": "f70280ac0e08b8e696f42cb8e948ff1c", 29 | "ema_lsun_bedroom": "1921fa46b66a3665e450e42f36c2720f", 30 | "lsun_cat": "bbee0e7c3d7abfb6e2539eaf2fb9987b", 31 | "ema_lsun_cat": "646f23f4821f2459b8bafc57fd824558", 32 | "lsun_church": "eb619b8a5ab95ef80f94ce8a5488dae3", 33 | "ema_lsun_church": "fdc68a23938c2397caba4a260bc2445f", 34 | } 35 | 36 | 37 | def download(url, local_path, chunk_size=1024): 38 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 39 | with requests.get(url, stream=True) as r: 40 | total_size = int(r.headers.get("content-length", 0)) 41 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 42 | with open(local_path, "wb") as f: 43 | for data in r.iter_content(chunk_size=chunk_size): 44 | if data: 45 | f.write(data) 46 | pbar.update(chunk_size) 47 | 48 | 49 | def md5_hash(path): 50 | with open(path, "rb") as f: 51 | content = f.read() 52 | return hashlib.md5(content).hexdigest() 53 | 54 | 55 | def get_ckpt_path(name, root=None, check=False): 56 | if 'church_outdoor' in name: 57 | name = name.replace('church_outdoor', 'church') 58 | assert name in URL_MAP 59 | # Modify the path when necessary 60 | cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("/atlas/u/tsong/.cache")) 61 | root = ( 62 | root 63 | if root is not None 64 | else os.path.join(cachedir, "diffusion_models_converted") 65 | ) 66 | path = os.path.join(root, CKPT_MAP[name]) 67 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 68 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 69 | download(URL_MAP[name], path) 70 | md5 = md5_hash(path) 71 | assert md5 == MD5_MAP[name], md5 72 | return path 73 | -------------------------------------------------------------------------------- /src/datasets/vision.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.utils.data as data 4 | 5 | 6 | class VisionDataset(data.Dataset): 7 | _repr_indent = 4 8 | 9 | def __init__(self, root, transforms=None, transform=None, target_transform=None): 10 | if isinstance(root, torch._six.string_classes): 11 | root = os.path.expanduser(root) 12 | self.root = root 13 | 14 | has_transforms = transforms is not None 15 | has_separate_transform = transform is not None or target_transform is not None 16 | if has_transforms and has_separate_transform: 17 | raise ValueError("Only transforms or transform/target_transform can " 18 | "be passed as argument") 19 | 20 | # for backwards-compatibility 21 | self.transform = transform 22 | self.target_transform = target_transform 23 | 24 | if has_separate_transform: 25 | transforms = StandardTransform(transform, target_transform) 26 | self.transforms = transforms 27 | 28 | def __getitem__(self, index): 29 | raise NotImplementedError 30 | 31 | def __len__(self): 32 | raise NotImplementedError 33 | 34 | def __repr__(self): 35 | head = "Dataset " + self.__class__.__name__ 36 | body = ["Number of datapoints: {}".format(self.__len__())] 37 | if self.root is not None: 38 | body.append("Root location: {}".format(self.root)) 39 | body += self.extra_repr().splitlines() 40 | if hasattr(self, 'transform') and self.transform is not None: 41 | body += self._format_transform_repr(self.transform, 42 | "Transforms: ") 43 | if hasattr(self, 'target_transform') and self.target_transform is not None: 44 | body += self._format_transform_repr(self.target_transform, 45 | "Target transforms: ") 46 | lines = [head] + [" " * self._repr_indent + line for line in body] 47 | return '\n'.join(lines) 48 | 49 | def _format_transform_repr(self, transform, head): 50 | lines = transform.__repr__().splitlines() 51 | return (["{}{}".format(head, lines[0])] + 52 | ["{}{}".format(" " * len(head), line) for line in lines[1:]]) 53 | 54 | def extra_repr(self): 55 | return "" 56 | 57 | 58 | class StandardTransform(object): 59 | def __init__(self, transform=None, target_transform=None): 60 | self.transform = transform 61 | self.target_transform = target_transform 62 | 63 | def __call__(self, input, target): 64 | if self.transform is not None: 65 | input = self.transform(input) 66 | if self.target_transform is not None: 67 | target = self.target_transform(target) 68 | return input, target 69 | 70 | def _format_transform_repr(self, transform, head): 71 | lines = transform.__repr__().splitlines() 72 | return (["{}{}".format(head, lines[0])] + 73 | ["{}{}".format(" " * len(head), line) for line in lines[1:]]) 74 | 75 | def __repr__(self): 76 | body = [self.__class__.__name__] 77 | if self.transform is not None: 78 | body += self._format_transform_repr(self.transform, 79 | "Transform: ") 80 | if self.target_transform is not None: 81 | body += self._format_transform_repr(self.target_transform, 82 | "Target transform: ") 83 | 84 | return '\n'.join(body) 85 | -------------------------------------------------------------------------------- /src/datasets/mnist.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import gzip 3 | import pickle 4 | import matplotlib.pyplot as plt 5 | import os 6 | 7 | import datasets 8 | import datasets.util as util 9 | 10 | 11 | class MNIST: 12 | """ 13 | The MNIST dataset of handwritten digits. 
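Pixels are optionally dequantized with uniform noise and logit-transformed (see the `dequantize` and `logit` flags below) so that densities are modeled on an unconstrained space.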
14 | """ 15 | 16 | alpha = 1.0e-6 17 | 18 | class Data: 19 | """ 20 | Constructs the dataset. 21 | """ 22 | 23 | def __init__(self, data, logit, dequantize, rng, rgb): 24 | 25 | x = self._dequantize(data[0], rng) if dequantize else data[0] # dequantize pixels 26 | self.x = self._logit_transform(x) if logit else x # logit 27 | if rgb: 28 | self.x = np.stack([self.x, self.x, self.x], axis=1) 29 | self.labels = data[1] # numeric labels 30 | self.y = util.one_hot_encode(self.labels, 10) # 1-hot encoded labels 31 | self.N = self.x.shape[0] # number of datapoints 32 | 33 | @staticmethod 34 | def _dequantize(x, rng): 35 | """ 36 | Adds noise to pixels to dequantize them. 37 | """ 38 | return x + rng.rand(*x.shape) / 256.0 39 | 40 | @staticmethod 41 | def _logit_transform(x): 42 | """ 43 | Transforms pixel values with logit to be unconstrained. 44 | """ 45 | return util.logit(MNIST.alpha + (1 - 2*MNIST.alpha) * x) 46 | 47 | def __init__(self, logit=True, dequantize=True, rgb=False): 48 | 49 | # load dataset 50 | f = gzip.open(os.path.join(datasets.root, 'mnist/mnist.pkl.gz'), 'rb') 51 | trn, val, tst = pickle.load(f, encoding='latin1') 52 | f.close() 53 | 54 | rng = np.random.RandomState(42) 55 | self.trn = self.Data(trn, logit, dequantize, rng, rgb) 56 | self.val = self.Data(val, logit, dequantize, rng, rgb) 57 | self.tst = self.Data(tst, logit, dequantize, rng, rgb) 58 | 59 | im_dim_idx = 2 if rgb else 1 60 | im_dim = int(np.sqrt(self.trn.x.shape[im_dim_idx])) 61 | self.n_dims = (3, im_dim, im_dim) if rgb else (1, im_dim, im_dim) 62 | self.n_labels = self.trn.y.shape[1] 63 | self.image_size = [im_dim, im_dim] 64 | 65 | def show_pixel_histograms(self, split, pixel=None): 66 | """ 67 | Shows the histogram of pixel values, or of a specific pixel if given. 68 | """ 69 | 70 | data_split = getattr(self, split, None) 71 | if data_split is None: 72 | raise ValueError('Invalid data split') 73 | 74 | if pixel is None: 75 | data = data_split.x.flatten() 76 | 77 | else: 78 | row, col = pixel 79 | idx = row * self.image_size[0] + col 80 | data = data_split.x[:, idx] 81 | 82 | n_bins = int(np.sqrt(data_split.N)) 83 | fig, ax = plt.subplots(1, 1) 84 | ax.hist(data, n_bins, density=True) # 'normed' kwarg was removed in newer matplotlib 85 | plt.show() 86 | 87 | def show_images(self, split): 88 | """ 89 | Displays the images in a given split. 90 | :param split: string 91 | """ 92 | 93 | # get split 94 | data_split = getattr(self, split, None) 95 | if data_split is None: 96 | raise ValueError('Invalid data split') 97 | 98 | # display images 99 | util.disp_imdata(data_split.x, self.image_size, [6, 10]) 100 | 101 | plt.show() 102 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Featurized Density Ratio Estimation 2 | This repo contains a reference implementation for featurized density ratio estimation (f-dre) as described in the paper: 3 | > Featurized Density Ratio Estimation
4 | > [Kristy Choi*](http://kristychoi.com/), [Madeline Liao*](https://www.linkedin.com/in/madelineliao/), [Stefano Ermon](https://cs.stanford.edu/~ermon/)
5 | > Uncertainty in Artificial Intelligence (UAI), 2021.
6 | > Paper: https://arxiv.org/abs/2107.02212
7 | 8 | 9 | ## For non-KMM/KLIEP experiments: 10 | ### 1) Environment setup: 11 | (a) Necessary packages can be found in `src/requirements.txt`. 12 | 13 | (b) Set the correct Python path using the following command: 14 | ``` 15 | source init_env.sh 16 | ``` 17 | Once this step is completed, all further steps for running experiments should be launched from the `src/` directory. 18 | 19 | ### 2) Pre-train the normalizing flow (Masked Autoregressive Flow). 20 | As the experimental workflow is quite similar across datasets, we'll use the toy Gaussian mixtures as a concrete example. To first train the flow, run: 21 | ``` 22 | python3 main.py --config flows/gmm/maf.yaml --exp_id=gmm_flow --ni 23 | ``` 24 | Config files for other datasets can be found in `src/configs/flows/<dataset>/`. The trained flow will be saved in `src/flows/results/` (note the path where it is saved) and will be used for downstream z-space density ratio estimation. 25 | 26 | 27 | ### 3) Generate encodings prior to running binary classification. 28 | The following script will use the pre-trained flow in Step #2 to generate encodings of the data points: 29 | ``` 30 | python3 main.py --config flows/gmm/maf.yaml --exp_id encode_gmm_z --restore_file=./flows/results/gmm_flow/ --sample --encode_z --ni 31 | ``` 32 | 33 | ### 4) Train density ratio estimator (classifier) on the encoded data points. 34 | Running the following script will estimate density ratios in feature space: 35 | ``` 36 | python3 main.py --classify --config classification/gmm/mlp_z.yaml --exp_id gmm_z --ni 37 | ``` 38 | Note that config files for other baselines, such as training on the x's directly and jointly training the flow and classifier, can be found in `classification/gmm/<config>.yaml`. The scripts to run these methods are the same as the above, just with a modification of the `exp_id`. A consolidated command sequence for steps 2-4 is sketched at the end of this section. 39 | 40 | ### 5) (For MNIST targeted generation experiment only) Generate samples! 41 | #### 5.1) Train an attribute classifier 42 | In order to get stats on the generated samples, train an attribute classifier using the following script (use `--attr digits` with the corresponding config for the digits attribute): 43 | ``` 44 | python3 main.py \ 45 | --classify \ 46 | --attr background \ 47 | --config classification/mnist/diff_bkgd/attr_bkgd.yaml \ 48 | --exp_id classify_attr_bkgd \ 49 | --ni 50 | ``` 51 | #### 5.2) Generate the samples 52 | The following script performs regular, unweighted generation. To perform reweighted generation with DRE in z-space, replace `--generate_samples` with `--fair_generate`. To do the same with DRE in x-space, replace `--generate_samples` with `--fair_generate` and additionally pass `--dre_x`. 53 | ``` 54 | python3 main.py \ 55 | --sample \ 56 | --seed 10 \ 57 | --generate_samples \ 58 | --config flows/mnist/diff_bkgd/perc0.1.yaml \ 59 | --exp_id regular_generation_perc0.1 \ 60 | --attr_clf_ckpt=/classification/results/{path to attr classifier checkpoint.pth} \ 61 | --restore_file=/flows/results/{directory with flow checkpoint} \ 62 | --ni 63 | ``` 64 | ## For KMM/KLIEP experiments: 65 | Each of these experiments is self-contained within one Jupyter notebook. Simply run the corresponding notebook cell by cell in `/notebooks`.
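For reference, here is the full GMM pipeline from steps 2-4 above as one uninterrupted command sequence (run from `src/` after `source init_env.sh`; the `exp_id` values are just experiment labels and can be renamed):
```
# 2) pre-train the MAF flow
python3 main.py --config flows/gmm/maf.yaml --exp_id=gmm_flow --ni

# 3) encode the data points with the pre-trained flow
python3 main.py --config flows/gmm/maf.yaml --exp_id encode_gmm_z --restore_file=./flows/results/gmm_flow/ --sample --encode_z --ni

# 4) train the z-space density ratio estimator (classifier) on the encodings
python3 main.py --classify --config classification/gmm/mlp_z.yaml --exp_id gmm_z --ni
```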
66 | 67 | ## References 68 | If you find this work useful in your research, please consider citing the following paper: 69 | ``` 70 | @article{choi2021featurized, 71 | title={Featurized Density Ratio Estimation}, 72 | author={Choi, Kristy and Liao, Madeline and Ermon, Stefano}, 73 | journal={arXiv preprint arXiv:2107.02212}, 74 | year={2021} 75 | } 76 | ``` 77 | -------------------------------------------------------------------------------- /src/flows/models/glow/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | class Squeeze(nn.Module): 6 | """ RealNVP squeezing operation layer (cf RealNVP section 3.6; Glow figure 2b): 7 | For each channel, it divides the image into subsquares of shape 2 × 2 × c, then reshapes them into subsquares of 8 | shape 1 × 1 × 4c. The squeezing operation transforms an s × s × c tensor into an s/2 × s/2 × 4c tensor """ 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def forward(self, x): 13 | B,C,H,W = x.shape 14 | x = x.reshape(B, C, H//2, 2, W//2, 2) # factor spatial dim 15 | x = x.permute(0, 1, 3, 5, 2, 4) # transpose to (B, C, 2, 2, H//2, W//2) 16 | x = x.reshape(B, 4*C, H//2, W//2) # aggregate spatial dim factors into channels 17 | return x 18 | 19 | def inverse(self, x): 20 | B,C,H,W = x.shape 21 | x = x.reshape(B, C//4, 2, 2, H, W) # factor channel dim 22 | x = x.permute(0, 1, 4, 2, 5, 3) # transpose to (B, C//4, H, 2, W, 2) 23 | x = x.reshape(B, C//4, 2*H, 2*W) # aggregate channel dim factors into spatial dims 24 | return x 25 | 26 | 27 | class Split(nn.Module): 28 | """ Split layer; cf Glow figure 2 / RealNVP figure 4b 29 | Based on RealNVP multi-scale architecture: splits an input in half along the channel dim; half the vars are 30 | directly modeled as Gaussians while the other half undergo further transformations (cf RealNVP figure 4b). 31 | """ 32 | def __init__(self, n_channels): 33 | super().__init__() 34 | self.gaussianize = Gaussianize(n_channels//2) 35 | 36 | def forward(self, x): 37 | x1, x2 = x.chunk(2, dim=1) # split input along channel dim 38 | z2, logdet = self.gaussianize(x1, x2) 39 | return x1, z2, logdet 40 | 41 | def inverse(self, x1, z2): 42 | x2, logdet = self.gaussianize.inverse(x1, z2) 43 | x = torch.cat([x1, x2], dim=1) # cat along channel dim 44 | return x, logdet 45 | 46 | 47 | class Gaussianize(nn.Module): 48 | """ Gaussianization per RealNVP sec 3.6 / fig 4b -- at each step half the variables are directly modeled as Gaussians. 49 | Model as Gaussians: 50 | x2 = z2 * exp(logs) + mu, so x2 ~ N(mu, exp(logs)^2) where mu, logs = f(x1) 51 | then to recover the random numbers z driving the model: 52 | z2 = (x2 - mu) * exp(-logs) 53 | Here f(x1) is a conv layer initialized to identity.
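(Since the conv weights and bias are zero-initialized, mu = logs = 0 at the start of training and the transform begins as the identity map z2 = x2.)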
54 | """ 55 | def __init__(self, n_channels): 56 | super().__init__() 57 | self.net = nn.Conv2d(n_channels, 2*n_channels, kernel_size=3, padding=1) # computes the parameters of Gaussian 58 | self.log_scale_factor = nn.Parameter(torch.zeros(2*n_channels,1,1)) # learned scale (cf RealNVP sec 4.1 / Glow official code) 59 | # initialize to identity 60 | self.net.weight.data.zero_() 61 | self.net.bias.data.zero_() 62 | 63 | def forward(self, x1, x2): 64 | h = self.net(x1) * self.log_scale_factor.exp() # use x1 to model x2 as Gaussians; learnable scale 65 | m, logs = h[:,0::2,:,:], h[:,1::2,:,:] # split along channel dims 66 | z2 = (x2 - m) * torch.exp(-logs) # center and scale; log prob is computed at the model forward 67 | logdet = - logs.sum([1,2,3]) 68 | return z2, logdet 69 | 70 | def inverse(self, x1, z2): 71 | h = self.net(x1) * self.log_scale_factor.exp() 72 | m, logs = h[:,0::2,:,:], h[:,1::2,:,:] 73 | x2 = m + z2 * torch.exp(logs) 74 | logdet = logs.sum([1,2,3]) 75 | return x2, logdet 76 | 77 | 78 | class Preprocess(nn.Module): 79 | def __init__(self): 80 | super().__init__() 81 | 82 | def forward(self, x): 83 | logdet = - math.log(256) * x[0].numel() # log-det of rescaling each pixel from [0, 255] to [0, 1]; cf RealNVP sec 4.1 84 | return x - 0.5, logdet # center x at 0 85 | 86 | def inverse(self, x): 87 | logdet = math.log(256) * x[0].numel() 88 | return x + 0.5, logdet 89 | -------------------------------------------------------------------------------- /src/util/array_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Flip(nn.Module): 7 | def forward(self, x, sldj, reverse=False): 8 | assert isinstance(x, tuple) and len(x) == 2 9 | return (x[1], x[0]), sldj 10 | 11 | 12 | def mean_dim(tensor, dim=None, keepdims=False): 13 | """Take the mean along multiple dimensions. 14 | 15 | Args: 16 | tensor (torch.Tensor): Tensor of values to average. 17 | dim (list): List of dimensions along which to take the mean. 18 | keepdims (bool): Keep dimensions rather than squeezing. 19 | 20 | Returns: 21 | mean (torch.Tensor): New tensor of mean value(s). 22 | """ 23 | if dim is None: 24 | return tensor.mean() 25 | else: 26 | if isinstance(dim, int): 27 | dim = [dim] 28 | dim = sorted(dim) 29 | for d in dim: 30 | tensor = tensor.mean(dim=d, keepdim=True) 31 | if not keepdims: 32 | for i, d in enumerate(dim): 33 | tensor.squeeze_(d-i) 34 | return tensor 35 | 36 | 37 | def checkerboard(x, reverse=False): 38 | """Split x in a checkerboard pattern.
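Entries whose row and column indices have equal parity go to y; the remaining entries go to z.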
Collapse horizontally.""" 39 | # Get dimensions 40 | if reverse: 41 | b, c, h, w = x[0].size() 42 | w *= 2 43 | device = x[0].device 44 | else: 45 | b, c, h, w = x.size() 46 | device = x.device 47 | 48 | # Get list of indices in alternating checkerboard pattern 49 | y_idx = [] 50 | z_idx = [] 51 | for i in range(h): 52 | for j in range(w): 53 | if (i % 2) == (j % 2): 54 | y_idx.append(i * w + j) 55 | else: 56 | z_idx.append(i * w + j) 57 | y_idx = torch.tensor(y_idx, dtype=torch.int64, device=device) 58 | z_idx = torch.tensor(z_idx, dtype=torch.int64, device=device) 59 | 60 | if reverse: 61 | y, z = (t.contiguous().view(b, c, h * w // 2) for t in x) 62 | x = torch.zeros(b, c, h * w, dtype=y.dtype, device=y.device) 63 | x[:, :, y_idx] += y 64 | x[:, :, z_idx] += z 65 | x = x.view(b, c, h, w) 66 | 67 | return x 68 | else: 69 | if w % 2 != 0: 70 | raise RuntimeError('Checkerboard got odd width input: {}'.format(w)) 71 | 72 | x = x.view(b, c, h * w) 73 | y = x[:, :, y_idx].view(b, c, h, w // 2) 74 | z = x[:, :, z_idx].view(b, c, h, w // 2) 75 | 76 | return y, z 77 | 78 | 79 | def channelwise(x, reverse=False): 80 | """Split x channel-wise.""" 81 | if reverse: 82 | x = torch.cat(x, dim=1) 83 | return x 84 | else: 85 | y, z = x.chunk(2, dim=1) 86 | return y, z 87 | 88 | 89 | def squeeze(x): 90 | """Trade spatial extent for channels. I.e., convert each 91 | 1x2x2 volume of input into a 4x1x1 volume of output. 92 | 93 | Args: 94 | x (torch.Tensor): Input to squeeze. 95 | 96 | Returns: 97 | x (torch.Tensor): Squeezed tensor. 98 | """ 99 | b, c, h, w = x.size() 100 | x = x.view(b, c, h // 2, 2, w // 2, 2) 101 | x = x.permute(0, 1, 3, 5, 2, 4).contiguous() 102 | x = x.view(b, c * 2 * 2, h // 2, w // 2) 103 | 104 | return x 105 | 106 | 107 | def unsqueeze(x): 108 | """Trade channels for spatial extent. I.e., convert each 109 | 4x1x1 volume of input into a 1x2x2 volume of output. 110 | 111 | Args: 112 | x (torch.Tensor): Input to unsqueeze. 113 | 114 | Returns: 115 | x (torch.Tensor): Unsqueezed tensor. 116 | """ 117 | b, c, h, w = x.size() 118 | x = x.view(b, c // 4, 2, 2, h, w) 119 | x = x.permute(0, 1, 4, 2, 5, 3).contiguous() 120 | x = x.view(b, c // 4, h * 2, w * 2) 121 | 122 | return x 123 | 124 | 125 | def concat_elu(x): 126 | """Concatenated ReLU (http://arxiv.org/abs/1603.05201), but with ELU.""" 127 | return F.elu(torch.cat((x, -x), dim=1)) 128 | 129 | 130 | def safe_log(x): 131 | return torch.log(x.clamp(min=1e-22)) 132 | -------------------------------------------------------------------------------- /src/flows/models/flowplusplus/act_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from util import mean_dim 5 | 6 | 7 | class _BaseNorm(nn.Module): 8 | """Base class for ActNorm (Glow) and PixNorm (Flow++). 9 | 10 | The mean and inv_std get initialized using the mean and variance of the 11 | first mini-batch. After the init, mean and inv_std are trainable parameters.
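This data-dependent initialization yields roughly zero-mean, unit-variance activations on the first batch (see `_get_moments` in the subclasses).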
12 | 13 | Adapted from: 14 | > https://github.com/openai/glow 15 | """ 16 | def __init__(self, num_channels, height, width): 17 | super(_BaseNorm, self).__init__() 18 | 19 | # Input gets concatenated along channel axis 20 | num_channels *= 2 21 | 22 | self.register_buffer('is_initialized', torch.zeros(1)) 23 | self.mean = nn.Parameter(torch.zeros(1, num_channels, height, width)) 24 | self.inv_std = nn.Parameter(torch.zeros(1, num_channels, height, width)) 25 | self.eps = 1e-6 26 | 27 | def initialize_parameters(self, x): 28 | if not self.training: 29 | return 30 | 31 | with torch.no_grad(): 32 | mean, inv_std = self._get_moments(x) 33 | self.mean.data.copy_(mean.data) 34 | self.inv_std.data.copy_(inv_std.data) 35 | self.is_initialized += 1. 36 | 37 | def _center(self, x, reverse=False): 38 | if reverse: 39 | return x + self.mean 40 | else: 41 | return x - self.mean 42 | 43 | def _get_moments(self, x): 44 | raise NotImplementedError('Subclass of _BaseNorm must implement _get_moments') 45 | 46 | def _scale(self, x, sldj, reverse=False): 47 | raise NotImplementedError('Subclass of _BaseNorm must implement _scale') 48 | 49 | def forward(self, x, ldj=None, reverse=False): 50 | x = torch.cat(x, dim=1) 51 | if not self.is_initialized: 52 | self.initialize_parameters(x) 53 | 54 | if reverse: 55 | x, ldj = self._scale(x, ldj, reverse) 56 | x = self._center(x, reverse) 57 | else: 58 | x = self._center(x, reverse) 59 | x, ldj = self._scale(x, ldj, reverse) 60 | x = x.chunk(2, dim=1) 61 | 62 | return x, ldj 63 | 64 | 65 | class ActNorm(_BaseNorm): 66 | """Activation Normalization used in Glow 67 | 68 | The mean and inv_std get initialized using the mean and variance of the 69 | first mini-batch. After the init, mean and inv_std are trainable parameters. 70 | """ 71 | def __init__(self, num_channels): 72 | super(ActNorm, self).__init__(num_channels, 1, 1) 73 | 74 | def _get_moments(self, x): 75 | mean = mean_dim(x.clone(), dim=[0, 2, 3], keepdims=True) 76 | var = mean_dim((x.clone() - mean) ** 2, dim=[0, 2, 3], keepdims=True) 77 | inv_std = 1. / (var.sqrt() + self.eps) 78 | 79 | return mean, inv_std 80 | 81 | def _scale(self, x, sldj, reverse=False): 82 | if reverse: 83 | x = x / self.inv_std 84 | sldj = sldj - self.inv_std.log().sum() * x.size(2) * x.size(3) 85 | else: 86 | x = x * self.inv_std 87 | sldj = sldj + self.inv_std.log().sum() * x.size(2) * x.size(3) 88 | 89 | return x, sldj 90 | 91 | 92 | class PixNorm(_BaseNorm): 93 | """Pixel-wise Activation Normalization used in Flow++ 94 | 95 | Normalizes every activation independently (note this differs from the variant 96 | used in Glow, where they normalize each channel). The mean and stddev get 97 | initialized using the mean and stddev of the first mini-batch. After the 98 | initialization, `mean` and `inv_std` become trainable parameters. 99 | """ 100 | def _get_moments(self, x): 101 | mean = torch.mean(x.clone(), dim=0, keepdim=True) 102 | var = torch.mean((x.clone() - mean) ** 2, dim=0, keepdim=True) 103 | inv_std = 1.
/ (var.sqrt() + self.eps) 104 | 105 | return mean, inv_std 106 | 107 | def _scale(self, x, sldj, reverse=False): 108 | if reverse: 109 | x = x / self.inv_std 110 | sldj = sldj - self.inv_std.log().sum() 111 | else: 112 | x = x * self.inv_std 113 | sldj = sldj + self.inv_std.log().sum() 114 | 115 | return x, sldj 116 | -------------------------------------------------------------------------------- /src/flows/trainers/glow1/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import datetime 4 | 5 | 6 | class JsonConfig(dict): 7 | """ 8 | The configuration will be loaded and dumped as a json file. 9 | The structure will be maintained as json. 10 | [TODO]: Some `asserts` can be made by key `__assert__` 11 | """ 12 | Indent = 2 13 | 14 | def __init__(self, *argv, **kwargs): 15 | super().__init__() 16 | super().__setitem__("__name", "default") 17 | # check input 18 | assert len(argv) == 0 or len(kwargs) == 0, ( 19 | "[JsonConfig]: Cannot initialize with" 20 | " positional parameters (json file or a dict)" 21 | " and named parameters (key and values) at the same time.") 22 | if len(argv) > 0: 23 | # init from a json or dict 24 | assert len(argv) == 1, ( 25 | "[JsonConfig]: Need exactly one positional parameter.") 26 | arg = argv[0] 27 | else: 28 | arg = kwargs 29 | # begin initialization 30 | if isinstance(arg, str): 31 | super().__setitem__("__name", 32 | os.path.splitext(os.path.basename(arg))[0]) 33 | with open(arg, "r") as load_f: 34 | arg = json.load(load_f) 35 | if isinstance(arg, dict): 36 | # case 1: init from dict 37 | for key in arg: 38 | value = arg[key] 39 | if isinstance(value, dict): 40 | value = JsonConfig(value) 41 | super().__setitem__(key, value) 42 | else: 43 | raise TypeError(("[JsonConfig]: Do not support given input" 44 | " with type {}").format(type(arg))) 45 | 46 | def __setattr__(self, attr, value): 47 | raise Exception("[JsonConfig]: Can't set constant key {}".format(attr)) 48 | 49 | def __setitem__(self, item, value): 50 | raise Exception("[JsonConfig]: Can't set constant key {}".format(item)) 51 | 52 | def __getattr__(self, attr): 53 | return super().__getitem__(attr) 54 | 55 | def __str__(self): 56 | return self.__to_string("", 0) 57 | 58 | def __to_string(self, name, indent): 59 | ret = " " * indent + name + " {\n" 60 | for key in self: 61 | if key.find("__") == 0: 62 | continue 63 | value = self[key] 64 | line = " " * indent 65 | if isinstance(value, JsonConfig): 66 | line += value.__to_string(key, indent + JsonConfig.Indent) 67 | else: 68 | line += " " * JsonConfig.Indent + key + ": " + str(value) 69 | ret += line + "\n" 70 | ret += " " * indent + "}" 71 | return ret 72 | 73 | def __add__(self, b): 74 | assert isinstance(b, JsonConfig) 75 | for k in b: 76 | v = b[k] 77 | if k in self: 78 | if isinstance(v, JsonConfig): 79 | super().__setitem__(k, self[k] + v) 80 | else: 81 | if k == "__name": 82 | super().__setitem__(k, self[k] + "&" + v) 83 | else: 84 | assert v == self[k], ( 85 | "[JsonConfig]: Two config conflicts at " 86 | "`{}`, {} != {}".format(k, self[k], v)) 87 | else: 88 | # new key, directly add 89 | super().__setitem__(k, v) 90 | return self 91 | 92 | def date_name(self): 93 | date = str(datetime.datetime.now()) 94 | date = date[:date.rfind(":")].replace("-", "")\ 95 | .replace(":", "")\ 96 | .replace(" ", "_") 97 | return date + "_" + super().__getitem__("__name") + ".json" 98 | 99 | def dump(self, dir_path, json_name=None): 100 | if json_name is None: 101 | json_name =
self.date_name() 102 | json_path = os.path.join(dir_path, json_name) 103 | with open(json_path, "w") as fout: 104 | print(str(self)) 105 | json.dump(self.to_dict(), fout, indent=JsonConfig.Indent) 106 | 107 | def to_dict(self): 108 | ret = {} 109 | for k in self: 110 | if k.find("__") == 0: 111 | continue 112 | v = self[k] 113 | if isinstance(v, JsonConfig): 114 | ret[k] = v.to_dict() 115 | else: 116 | ret[k] = v 117 | return ret 118 | -------------------------------------------------------------------------------- /src/flows/trainers/glow1/builder.py: -------------------------------------------------------------------------------- 1 | import re, os 2 | import copy 3 | import torch 4 | from collections import defaultdict 5 | from . import learning_rate_schedule 6 | from .config import JsonConfig 7 | from .models import Glow 8 | from .utils import load, save, get_proper_device 9 | 10 | 11 | def build_adam(params, args): 12 | return torch.optim.Adam(params, **args) 13 | 14 | 15 | __build_optim_dict = { 16 | "adam": build_adam 17 | } 18 | 19 | 20 | def build(hparams, is_training): 21 | if isinstance(hparams, str): 22 | hparams = JsonConfig(hparams) 23 | # get graph and criterions from build function 24 | graph, optim, lrschedule, criterion_dict = None, None, None, None # init with None 25 | cpu, devices = "cpu", None 26 | get_loss = None 27 | # 1. build graph and criterion_dict, (on cpu) 28 | # build and append `device attr` to graph 29 | graph = Glow(hparams) 30 | graph.device = hparams.Device.glow 31 | if graph is not None: 32 | # get device 33 | devices = get_proper_device(graph.device) 34 | graph.device = devices 35 | graph.to(cpu) 36 | # 2. get optim (on cpu) 37 | try: 38 | if graph is not None and is_training: 39 | optim_name = hparams.Optim.name 40 | optim = __build_optim_dict[optim_name](graph.parameters(), hparams.Optim.args.to_dict()) 41 | print("[Builder]: Using optimizer `{}`, with args:{}".format(optim_name, hparams.Optim.args)) 42 | # get lrschedule 43 | schedule_name = "default" 44 | schedule_args = {} 45 | if "Schedule" in hparams.Optim: 46 | schedule_name = hparams.Optim.Schedule.name 47 | schedule_args = hparams.Optim.Schedule.args.to_dict() 48 | if not ("init_lr" in schedule_args): 49 | schedule_args["init_lr"] = hparams.Optim.args.lr 50 | assert schedule_args["init_lr"] == hparams.Optim.args.lr,\ 51 | "Optim lr {} != Schedule init_lr {}".format(hparams.Optim.args.lr, schedule_args["init_lr"]) 52 | lrschedule = { 53 | "func": getattr(learning_rate_schedule, schedule_name), 54 | "args": schedule_args 55 | } 56 | except KeyError: 57 | raise ValueError("[Builder]: Optimizer `{}` is not supported.".format(optim_name)) 58 | # 3. warm start and move to devices 59 | if graph is not None: 60 | # 1. warm start from pre-trained model (on cpu) 61 | pre_trained = None 62 | loaded_step = 0 63 | if is_training: 64 | if "warm_start" in hparams.Train and len(hparams.Train.warm_start) > 0: 65 | pre_trained = hparams.Train.warm_start 66 | else: 67 | pre_trained = hparams.Infer.pre_trained 68 | if pre_trained is not None: 69 | loaded_step = load(os.path.basename(pre_trained), 70 | graph=graph, optim=optim, criterion_dict=None, 71 | pkg_dir=os.path.dirname(pre_trained), 72 | device=cpu) 73 | # 2. 
move graph to device (to cpu or cuda) 74 | use_cpu = any([isinstance(d, str) and d.find("cpu") >= 0 for d in devices]) 75 | if use_cpu: 76 | graph = graph.cpu() 77 | print("[Builder]: Use cpu to train.") 78 | else: 79 | if "data" in hparams.Device: 80 | data_gpu = hparams.Device.data 81 | if isinstance(data_gpu, str): 82 | data_gpu = int(data_gpu[5:]) 83 | else: 84 | data_gpu = devices[0] 85 | # move to first 86 | graph = graph.cuda(device=devices[0]) 87 | if is_training and pre_trained is not None: 88 | # note that it may also be necessary to move the optimizer state 89 | if hasattr(optim, "state"): 90 | def move_to(D, device): 91 | for k in D: 92 | if isinstance(D[k], dict) or isinstance(D[k], defaultdict): 93 | move_to(D[k], device) 94 | elif torch.is_tensor(D[k]): 95 | D[k] = D[k].cuda(device) 96 | move_to(optim.state, devices[0]) 97 | print("[Builder]: Use cuda {} to train, use {} to load data and get loss.".format(devices, data_gpu)) 98 | 99 | return { 100 | "graph": graph, 101 | "optim": optim, 102 | "lrschedule": lrschedule, 103 | "devices": devices, 104 | "data_device": data_gpu if not use_cpu else "cpu", 105 | "loaded_step": loaded_step 106 | } 107 | -------------------------------------------------------------------------------- /src/classification/models/cnn.py: -------------------------------------------------------------------------------- 1 | # https://github.com/BoyuanJiang/matching-networks-pytorch/blob/master/matching_networks.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import math 6 | import numpy as np 7 | import torch.nn.functional as F 8 | 9 | 10 | def convLayer(in_channels, out_channels, keep_prob=0.0): 11 | """3x3 convolution with padding; each call halves the spatial size via the 2x2 max-pool""" 12 | cnn_seq = nn.Sequential( 13 | nn.Conv2d(in_channels, out_channels, 3, 1, 1), 14 | nn.BatchNorm2d(out_channels), 15 | nn.ReLU(True), 16 | nn.MaxPool2d(2) 17 | # nn.Dropout(keep_prob) 18 | ) 19 | return cnn_seq 20 | 21 | 22 | class CNNClassifier(nn.Module): 23 | def __init__(self, layer_size=64, num_channels=1, keep_prob=1.0, image_size=28): 24 | super(CNNClassifier, self).__init__() 25 | """ 26 | Build a CNN to produce embeddings 27 | :param layer_size:64(default) 28 | :param num_channels: 29 | :param keep_prob: 30 | :param image_size: 31 | """ 32 | self.layer1 = convLayer(num_channels, layer_size, keep_prob) 33 | self.layer2 = convLayer(layer_size, layer_size, keep_prob) 34 | self.layer3 = convLayer(layer_size, layer_size, keep_prob) 35 | self.layer4 = convLayer(layer_size, layer_size, keep_prob) 36 | 37 | finalSize = int(math.floor(image_size / (2 * 2 * 2 * 2))) 38 | self.outSize = finalSize * finalSize * layer_size 39 | 40 | # TODO: how many classes???
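# NOTE: the 1622 output size below is hard-coded; Omniglot has 1623 character classes, so this may be an off-by-one (hence the TODO above).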
41 | self.fc = nn.Linear(self.outSize, 1622) 42 | 43 | def forward(self, image_input): 44 | """ 45 | Use CNN defined above 46 | :param image_input: 47 | :return: 48 | """ 49 | x = self.layer1(image_input) 50 | x = self.layer2(x) 51 | x = self.layer3(x) 52 | x = self.layer4(x) 53 | x = x.view(x.size()[0], -1) 54 | 55 | logits = self.fc(x) 56 | probas = F.softmax(logits, dim=-1) 57 | 58 | return logits, probas 59 | 60 | class TRECNNClassifier(nn.Module): 61 | def __init__(self, config, layer_size=64, num_channels=1, keep_prob=1.0, image_size=28): 62 | super(TRECNNClassifier, self).__init__() 63 | 64 | self.m = config.tre.m 65 | self.p = config.tre.p 66 | 67 | self.layer1 = convLayer(num_channels, layer_size, keep_prob) 68 | self.layer2 = convLayer(layer_size, layer_size, keep_prob) 69 | self.layer3 = convLayer(layer_size, layer_size, keep_prob) 70 | self.layer4 = convLayer(layer_size, layer_size, keep_prob) 71 | 72 | finalSize = int(math.floor(image_size / (2 * 2 * 2 * 2))) 73 | self.outSize = finalSize * finalSize * layer_size 74 | self.fc_list = nn.ModuleList([nn.Linear(self.outSize, 1622) for _ in range(self.m)]) # ModuleList so the head params are registered 75 | 76 | def forward(self, image_input): 77 | """ 78 | Use CNN defined above 79 | :param image_input: 80 | :return: 81 | """ 82 | x = self.layer1(image_input) 83 | x = self.layer2(x) 84 | x = self.layer3(x) 85 | x = self.layer4(x) 86 | 87 | x = x.view(self.m, x.size()[0], -1) # TODO 88 | 89 | logits = [] 90 | for x_i, fc in zip(x, self.fc_list): 91 | logits.append(fc(x_i)) 92 | 93 | logits = torch.stack(logits) 94 | probas = F.softmax(logits, dim=-1) 95 | 96 | return logits, probas 97 | 98 | 99 | class BinaryCNNClassifier(nn.Module): 100 | def __init__(self, layer_size=64, num_channels=1, keep_prob=1.0, image_size=28): 101 | super(BinaryCNNClassifier, self).__init__() 102 | """ 103 | Build a CNN to produce embeddings 104 | :param layer_size:64(default) 105 | :param num_channels: 106 | :param keep_prob: 107 | :param image_size: 108 | """ 109 | self.layer1 = convLayer(num_channels, layer_size, keep_prob) 110 | self.layer2 = convLayer(layer_size, layer_size, keep_prob) 111 | self.layer3 = convLayer(layer_size, layer_size, keep_prob) 112 | self.layer4 = convLayer(layer_size, layer_size, keep_prob) 113 | 114 | finalSize = int(math.floor(image_size / (2 * 2 * 2 * 2))) 115 | self.outSize = finalSize * finalSize * layer_size 116 | 117 | self.fc = nn.Linear(self.outSize, 1) 118 | 119 | def forward(self, image_input): 120 | """ 121 | Use CNN defined above 122 | :param image_input: 123 | :return: 124 | """ 125 | x = self.layer1(image_input) 126 | x = self.layer2(x) 127 | x = self.layer3(x) 128 | x = self.layer4(x) 129 | x = x.view(x.size()[0], -1) 130 | 131 | logits = self.fc(x) 132 | probas = torch.sigmoid(logits) 133 | 134 | return logits, probas 135 | -------------------------------------------------------------------------------- /src/flows/results/kmm_kliep_synthetic_maf/results.txt: -------------------------------------------------------------------------------- 1 | Evaluate (epoch 0) -- logp(x) = -3.936 +/- 0.046 2 | Evaluate (epoch 1) -- logp(x) = -3.630 +/- 0.046 3 | Evaluate (epoch 2) -- logp(x) = -3.570 +/- 0.048 4 | Evaluate (epoch 3) -- logp(x) = -3.568 +/- 0.047 5 | Evaluate (epoch 4) -- logp(x) = -3.557 +/- 0.049 6 | Evaluate (epoch 5) -- logp(x) = -3.547 +/- 0.051 7 | Evaluate (epoch 6) -- logp(x) = -3.542 +/- 0.051 8 | Evaluate (epoch 7) -- logp(x) = -3.537 +/- 0.053 9 | Evaluate (epoch 8) -- logp(x) = -3.518 +/- 0.052 10 | Evaluate (epoch 9) -- logp(x) = -3.535 +/- 0.056 11 |
Evaluate (epoch 10) -- logp(x) = -3.516 +/- 0.055 12 | Evaluate (epoch 11) -- logp(x) = -3.512 +/- 0.054 13 | Evaluate (epoch 12) -- logp(x) = -3.515 +/- 0.053 14 | Evaluate (epoch 13) -- logp(x) = -3.508 +/- 0.056 15 | Evaluate (epoch 14) -- logp(x) = -3.515 +/- 0.057 16 | Evaluate (epoch 15) -- logp(x) = -3.515 +/- 0.058 17 | Evaluate (epoch 16) -- logp(x) = -3.505 +/- 0.056 18 | Evaluate (epoch 17) -- logp(x) = -3.505 +/- 0.055 19 | Evaluate (epoch 18) -- logp(x) = -3.514 +/- 0.057 20 | Evaluate (epoch 19) -- logp(x) = -3.493 +/- 0.057 21 | Evaluate (epoch 20) -- logp(x) = -3.506 +/- 0.057 22 | Evaluate (epoch 21) -- logp(x) = -3.495 +/- 0.056 23 | Evaluate (epoch 22) -- logp(x) = -3.495 +/- 0.056 24 | Evaluate (epoch 23) -- logp(x) = -3.489 +/- 0.056 25 | Evaluate (epoch 24) -- logp(x) = -3.483 +/- 0.055 26 | Evaluate (epoch 25) -- logp(x) = -3.484 +/- 0.058 27 | Evaluate (epoch 26) -- logp(x) = -3.488 +/- 0.054 28 | Evaluate (epoch 27) -- logp(x) = -3.496 +/- 0.060 29 | Evaluate (epoch 28) -- logp(x) = -3.488 +/- 0.056 30 | Evaluate (epoch 29) -- logp(x) = -3.478 +/- 0.056 31 | Evaluate (epoch 30) -- logp(x) = -3.498 +/- 0.059 32 | Evaluate (epoch 31) -- logp(x) = -3.493 +/- 0.056 33 | Evaluate (epoch 32) -- logp(x) = -3.475 +/- 0.058 34 | Evaluate (epoch 33) -- logp(x) = -3.485 +/- 0.061 35 | Evaluate (epoch 34) -- logp(x) = -3.483 +/- 0.056 36 | Evaluate (epoch 35) -- logp(x) = -3.473 +/- 0.058 37 | Evaluate (epoch 36) -- logp(x) = -3.483 +/- 0.059 38 | Evaluate (epoch 37) -- logp(x) = -3.500 +/- 0.055 39 | Evaluate (epoch 38) -- logp(x) = -3.531 +/- 0.070 40 | Evaluate (epoch 39) -- logp(x) = -3.490 +/- 0.055 41 | Evaluate (epoch 40) -- logp(x) = -3.481 +/- 0.056 42 | Evaluate (epoch 41) -- logp(x) = -3.486 +/- 0.059 43 | Evaluate (epoch 42) -- logp(x) = -3.474 +/- 0.056 44 | Evaluate (epoch 43) -- logp(x) = -3.474 +/- 0.058 45 | Evaluate (epoch 44) -- logp(x) = -3.470 +/- 0.057 46 | Evaluate (epoch 45) -- logp(x) = -3.469 +/- 0.055 47 | Evaluate (epoch 46) -- logp(x) = -3.497 +/- 0.062 48 | Evaluate (epoch 47) -- logp(x) = -3.537 +/- 0.056 49 | Evaluate (epoch 48) -- logp(x) = -3.484 +/- 0.062 50 | Evaluate (epoch 49) -- logp(x) = -3.472 +/- 0.058 51 | Evaluate (epoch 50) -- logp(x) = -3.473 +/- 0.055 52 | Evaluate (epoch 51) -- logp(x) = -3.467 +/- 0.058 53 | Evaluate (epoch 52) -- logp(x) = -3.474 +/- 0.059 54 | Evaluate (epoch 53) -- logp(x) = -3.472 +/- 0.056 55 | Evaluate (epoch 54) -- logp(x) = -3.478 +/- 0.058 56 | Evaluate (epoch 55) -- logp(x) = -3.481 +/- 0.056 57 | Evaluate (epoch 56) -- logp(x) = -3.490 +/- 0.064 58 | Evaluate (epoch 57) -- logp(x) = -3.484 +/- 0.055 59 | Evaluate (epoch 58) -- logp(x) = -3.470 +/- 0.057 60 | Evaluate (epoch 59) -- logp(x) = -3.487 +/- 0.055 61 | Evaluate (epoch 60) -- logp(x) = -3.466 +/- 0.058 62 | Evaluate (epoch 61) -- logp(x) = -3.474 +/- 0.055 63 | Evaluate (epoch 62) -- logp(x) = -3.472 +/- 0.060 64 | Evaluate (epoch 63) -- logp(x) = -3.467 +/- 0.059 65 | Evaluate (epoch 64) -- logp(x) = -3.478 +/- 0.056 66 | Evaluate (epoch 65) -- logp(x) = -3.465 +/- 0.057 67 | Evaluate (epoch 66) -- logp(x) = -3.475 +/- 0.057 68 | Evaluate (epoch 67) -- logp(x) = -3.471 +/- 0.059 69 | Evaluate (epoch 68) -- logp(x) = -3.458 +/- 0.056 70 | Evaluate (epoch 69) -- logp(x) = -3.466 +/- 0.056 71 | Evaluate (epoch 70) -- logp(x) = -3.490 +/- 0.055 72 | Evaluate (epoch 71) -- logp(x) = -3.477 +/- 0.060 73 | Evaluate (epoch 72) -- logp(x) = -3.471 +/- 0.057 74 | Evaluate (epoch 73) -- logp(x) = -3.482 +/- 0.056 75 | Evaluate (epoch 74) -- logp(x) = 
-3.466 +/- 0.057 76 | Evaluate (epoch 75) -- logp(x) = -3.476 +/- 0.058 77 | Evaluate (epoch 76) -- logp(x) = -3.466 +/- 0.058 78 | Evaluate (epoch 77) -- logp(x) = -3.475 +/- 0.057 79 | Evaluate (epoch 78) -- logp(x) = -3.469 +/- 0.058 80 | Evaluate (epoch 79) -- logp(x) = -3.463 +/- 0.059 81 | Evaluate (epoch 80) -- logp(x) = -3.477 +/- 0.056 82 | Evaluate (epoch 81) -- logp(x) = -3.474 +/- 0.059 83 | Evaluate (epoch 82) -- logp(x) = -3.468 +/- 0.057 84 | Evaluate (epoch 83) -- logp(x) = -3.465 +/- 0.058 85 | Evaluate (epoch 84) -- logp(x) = -3.462 +/- 0.056 86 | Evaluate (epoch 85) -- logp(x) = -3.477 +/- 0.061 87 | Evaluate (epoch 86) -- logp(x) = -3.468 +/- 0.056 88 | Evaluate (epoch 87) -- logp(x) = -3.465 +/- 0.057 89 | Evaluate (epoch 88) -- logp(x) = -3.464 +/- 0.059 90 | Evaluate (epoch 89) -- logp(x) = -3.466 +/- 0.057 91 | Evaluate (epoch 90) -- logp(x) = -3.469 +/- 0.057 92 | Evaluate (epoch 91) -- logp(x) = -3.488 +/- 0.056 93 | Evaluate (epoch 92) -- logp(x) = -3.466 +/- 0.058 94 | Evaluate (epoch 93) -- logp(x) = -3.473 +/- 0.059 95 | Evaluate (epoch 94) -- logp(x) = -3.465 +/- 0.058 96 | Evaluate (epoch 95) -- logp(x) = -3.474 +/- 0.062 97 | Evaluate (epoch 96) -- logp(x) = -3.471 +/- 0.058 98 | Evaluate (epoch 97) -- logp(x) = -3.473 +/- 0.057 99 | Evaluate (epoch 98) -- logp(x) = -3.464 +/- 0.057 100 | Evaluate (epoch 99) -- logp(x) = -3.461 +/- 0.056 101 | -------------------------------------------------------------------------------- /src/flows/results/uci_breast_cancer_maf/results.txt: -------------------------------------------------------------------------------- 1 | Evaluate (epoch 0) -- logp(x) = -12.644 +/- 0.407 2 | Evaluate (epoch 1) -- logp(x) = -12.440 +/- 0.392 3 | Evaluate (epoch 2) -- logp(x) = -12.295 +/- 0.387 4 | Evaluate (epoch 3) -- logp(x) = -12.174 +/- 0.380 5 | Evaluate (epoch 4) -- logp(x) = -12.059 +/- 0.378 6 | Evaluate (epoch 5) -- logp(x) = -11.944 +/- 0.382 7 | Evaluate (epoch 6) -- logp(x) = -11.826 +/- 0.390 8 | Evaluate (epoch 7) -- logp(x) = -11.712 +/- 0.393 9 | Evaluate (epoch 8) -- logp(x) = -11.603 +/- 0.395 10 | Evaluate (epoch 9) -- logp(x) = -11.494 +/- 0.391 11 | Evaluate (epoch 10) -- logp(x) = -11.385 +/- 0.395 12 | Evaluate (epoch 11) -- logp(x) = -11.280 +/- 0.397 13 | Evaluate (epoch 12) -- logp(x) = -11.164 +/- 0.410 14 | Evaluate (epoch 13) -- logp(x) = -11.052 +/- 0.417 15 | Evaluate (epoch 14) -- logp(x) = -10.952 +/- 0.414 16 | Evaluate (epoch 15) -- logp(x) = -10.844 +/- 0.419 17 | Evaluate (epoch 16) -- logp(x) = -10.739 +/- 0.427 18 | Evaluate (epoch 17) -- logp(x) = -10.627 +/- 0.444 19 | Evaluate (epoch 18) -- logp(x) = -10.530 +/- 0.441 20 | Evaluate (epoch 19) -- logp(x) = -10.433 +/- 0.446 21 | Evaluate (epoch 20) -- logp(x) = -10.335 +/- 0.451 22 | Evaluate (epoch 21) -- logp(x) = -10.228 +/- 0.470 23 | Evaluate (epoch 22) -- logp(x) = -10.135 +/- 0.477 24 | Evaluate (epoch 23) -- logp(x) = -10.037 +/- 0.488 25 | Evaluate (epoch 24) -- logp(x) = -9.943 +/- 0.498 26 | Evaluate (epoch 25) -- logp(x) = -9.856 +/- 0.499 27 | Evaluate (epoch 26) -- logp(x) = -9.764 +/- 0.523 28 | Evaluate (epoch 27) -- logp(x) = -9.675 +/- 0.533 29 | Evaluate (epoch 28) -- logp(x) = -9.587 +/- 0.550 30 | Evaluate (epoch 29) -- logp(x) = -9.507 +/- 0.565 31 | Evaluate (epoch 30) -- logp(x) = -9.423 +/- 0.562 32 | Evaluate (epoch 31) -- logp(x) = -9.345 +/- 0.568 33 | Evaluate (epoch 32) -- logp(x) = -9.259 +/- 0.595 34 | Evaluate (epoch 33) -- logp(x) = -9.180 +/- 0.591 35 | Evaluate (epoch 34) -- logp(x) = -9.105 +/- 0.594 36 | 
Evaluate (epoch 35) -- logp(x) = -9.023 +/- 0.601 37 | Evaluate (epoch 36) -- logp(x) = -8.948 +/- 0.634 38 | Evaluate (epoch 37) -- logp(x) = -8.869 +/- 0.614 39 | Evaluate (epoch 38) -- logp(x) = -8.790 +/- 0.607 40 | Evaluate (epoch 39) -- logp(x) = -8.708 +/- 0.610 41 | Evaluate (epoch 40) -- logp(x) = -8.625 +/- 0.622 42 | Evaluate (epoch 41) -- logp(x) = -8.546 +/- 0.608 43 | Evaluate (epoch 42) -- logp(x) = -8.463 +/- 0.609 44 | Evaluate (epoch 43) -- logp(x) = -8.374 +/- 0.621 45 | Evaluate (epoch 44) -- logp(x) = -8.286 +/- 0.627 46 | Evaluate (epoch 45) -- logp(x) = -8.203 +/- 0.627 47 | Evaluate (epoch 46) -- logp(x) = -8.115 +/- 0.626 48 | Evaluate (epoch 47) -- logp(x) = -8.034 +/- 0.618 49 | Evaluate (epoch 48) -- logp(x) = -7.923 +/- 0.623 50 | Evaluate (epoch 49) -- logp(x) = -7.823 +/- 0.625 51 | Evaluate (epoch 50) -- logp(x) = -7.727 +/- 0.625 52 | Evaluate (epoch 51) -- logp(x) = -7.644 +/- 0.629 53 | Evaluate (epoch 52) -- logp(x) = -7.553 +/- 0.626 54 | Evaluate (epoch 53) -- logp(x) = -7.467 +/- 0.631 55 | Evaluate (epoch 54) -- logp(x) = -7.404 +/- 0.635 56 | Evaluate (epoch 55) -- logp(x) = -7.308 +/- 0.639 57 | Evaluate (epoch 56) -- logp(x) = -7.218 +/- 0.645 58 | Evaluate (epoch 57) -- logp(x) = -7.141 +/- 0.649 59 | Evaluate (epoch 58) -- logp(x) = -7.060 +/- 0.649 60 | Evaluate (epoch 59) -- logp(x) = -7.060 +/- 0.658 61 | Evaluate (epoch 60) -- logp(x) = -6.907 +/- 0.660 62 | Evaluate (epoch 61) -- logp(x) = -6.828 +/- 0.667 63 | Evaluate (epoch 62) -- logp(x) = -6.753 +/- 0.668 64 | Evaluate (epoch 63) -- logp(x) = -6.682 +/- 0.668 65 | Evaluate (epoch 64) -- logp(x) = -6.620 +/- 0.679 66 | Evaluate (epoch 65) -- logp(x) = -6.678 +/- 0.686 67 | Evaluate (epoch 66) -- logp(x) = -6.522 +/- 0.683 68 | Evaluate (epoch 67) -- logp(x) = -6.464 +/- 0.689 69 | Evaluate (epoch 68) -- logp(x) = -6.391 +/- 0.692 70 | Evaluate (epoch 69) -- logp(x) = -6.315 +/- 0.698 71 | Evaluate (epoch 70) -- logp(x) = -6.272 +/- 0.699 72 | Evaluate (epoch 71) -- logp(x) = -6.192 +/- 0.709 73 | Evaluate (epoch 72) -- logp(x) = -6.152 +/- 0.712 74 | Evaluate (epoch 73) -- logp(x) = -6.084 +/- 0.711 75 | Evaluate (epoch 74) -- logp(x) = -6.048 +/- 0.714 76 | Evaluate (epoch 75) -- logp(x) = -6.007 +/- 0.722 77 | Evaluate (epoch 76) -- logp(x) = -5.936 +/- 0.726 78 | Evaluate (epoch 77) -- logp(x) = -5.862 +/- 0.732 79 | Evaluate (epoch 78) -- logp(x) = -5.785 +/- 0.738 80 | Evaluate (epoch 79) -- logp(x) = -5.724 +/- 0.735 81 | Evaluate (epoch 80) -- logp(x) = -5.674 +/- 0.734 82 | Evaluate (epoch 81) -- logp(x) = -5.579 +/- 0.739 83 | Evaluate (epoch 82) -- logp(x) = -5.515 +/- 0.740 84 | Evaluate (epoch 83) -- logp(x) = -5.469 +/- 0.741 85 | Evaluate (epoch 84) -- logp(x) = -5.422 +/- 0.739 86 | Evaluate (epoch 85) -- logp(x) = -5.408 +/- 0.733 87 | Evaluate (epoch 86) -- logp(x) = -5.302 +/- 0.743 88 | Evaluate (epoch 87) -- logp(x) = -5.257 +/- 0.742 89 | Evaluate (epoch 88) -- logp(x) = -5.309 +/- 0.725 90 | Evaluate (epoch 89) -- logp(x) = -5.219 +/- 0.733 91 | Evaluate (epoch 90) -- logp(x) = -5.091 +/- 0.746 92 | Evaluate (epoch 91) -- logp(x) = -5.058 +/- 0.748 93 | Evaluate (epoch 92) -- logp(x) = -4.988 +/- 0.749 94 | Evaluate (epoch 93) -- logp(x) = -4.942 +/- 0.752 95 | Evaluate (epoch 94) -- logp(x) = -4.855 +/- 0.756 96 | Evaluate (epoch 95) -- logp(x) = -4.786 +/- 0.760 97 | Evaluate (epoch 96) -- logp(x) = -4.751 +/- 0.758 98 | Evaluate (epoch 97) -- logp(x) = -4.696 +/- 0.767 99 | Evaluate (epoch 98) -- logp(x) = -4.616 +/- 0.769 100 | Evaluate (epoch 99) -- logp(x) = 
-4.658 +/- 0.764 101 | -------------------------------------------------------------------------------- /src/flows/results/uci_wine_quality_maf/results.txt: -------------------------------------------------------------------------------- 1 | Evaluate (epoch 0) -- logp(x) = -15.039 +/- 0.168 2 | Evaluate (epoch 1) -- logp(x) = -14.559 +/- 0.155 3 | Evaluate (epoch 2) -- logp(x) = -14.117 +/- 0.140 4 | Evaluate (epoch 3) -- logp(x) = -13.707 +/- 0.154 5 | Evaluate (epoch 4) -- logp(x) = -13.279 +/- 0.138 6 | Evaluate (epoch 5) -- logp(x) = -12.837 +/- 0.122 7 | Evaluate (epoch 6) -- logp(x) = -12.378 +/- 0.117 8 | Evaluate (epoch 7) -- logp(x) = -11.996 +/- 0.118 9 | Evaluate (epoch 8) -- logp(x) = -11.694 +/- 0.115 10 | Evaluate (epoch 9) -- logp(x) = -11.472 +/- 0.115 11 | Evaluate (epoch 10) -- logp(x) = -11.287 +/- 0.110 12 | Evaluate (epoch 11) -- logp(x) = -11.164 +/- 0.105 13 | Evaluate (epoch 12) -- logp(x) = -11.046 +/- 0.101 14 | Evaluate (epoch 13) -- logp(x) = -10.932 +/- 0.102 15 | Evaluate (epoch 14) -- logp(x) = -10.832 +/- 0.097 16 | Evaluate (epoch 15) -- logp(x) = -10.709 +/- 0.101 17 | Evaluate (epoch 16) -- logp(x) = -10.649 +/- 0.098 18 | Evaluate (epoch 17) -- logp(x) = -10.582 +/- 0.096 19 | Evaluate (epoch 18) -- logp(x) = -10.486 +/- 0.095 20 | Evaluate (epoch 19) -- logp(x) = -10.426 +/- 0.095 21 | Evaluate (epoch 20) -- logp(x) = -10.336 +/- 0.097 22 | Evaluate (epoch 21) -- logp(x) = -10.283 +/- 0.092 23 | Evaluate (epoch 22) -- logp(x) = -10.245 +/- 0.092 24 | Evaluate (epoch 23) -- logp(x) = -10.204 +/- 0.091 25 | Evaluate (epoch 24) -- logp(x) = -10.186 +/- 0.093 26 | Evaluate (epoch 25) -- logp(x) = -10.137 +/- 0.089 27 | Evaluate (epoch 26) -- logp(x) = -10.050 +/- 0.089 28 | Evaluate (epoch 27) -- logp(x) = -10.022 +/- 0.089 29 | Evaluate (epoch 28) -- logp(x) = -9.970 +/- 0.090 30 | Evaluate (epoch 29) -- logp(x) = -9.975 +/- 0.086 31 | Evaluate (epoch 30) -- logp(x) = -9.932 +/- 0.088 32 | Evaluate (epoch 31) -- logp(x) = -9.895 +/- 0.086 33 | Evaluate (epoch 32) -- logp(x) = -9.874 +/- 0.086 34 | Evaluate (epoch 33) -- logp(x) = -9.839 +/- 0.087 35 | Evaluate (epoch 34) -- logp(x) = -9.848 +/- 0.087 36 | Evaluate (epoch 35) -- logp(x) = -9.803 +/- 0.087 37 | Evaluate (epoch 36) -- logp(x) = -9.771 +/- 0.086 38 | Evaluate (epoch 37) -- logp(x) = -9.758 +/- 0.085 39 | Evaluate (epoch 38) -- logp(x) = -9.795 +/- 0.086 40 | Evaluate (epoch 39) -- logp(x) = -9.715 +/- 0.085 41 | Evaluate (epoch 40) -- logp(x) = -9.679 +/- 0.086 42 | Evaluate (epoch 41) -- logp(x) = -9.701 +/- 0.083 43 | Evaluate (epoch 42) -- logp(x) = -9.648 +/- 0.084 44 | Evaluate (epoch 43) -- logp(x) = -9.665 +/- 0.084 45 | Evaluate (epoch 44) -- logp(x) = -9.636 +/- 0.084 46 | Evaluate (epoch 45) -- logp(x) = -9.662 +/- 0.084 47 | Evaluate (epoch 46) -- logp(x) = -9.622 +/- 0.083 48 | Evaluate (epoch 47) -- logp(x) = -9.646 +/- 0.082 49 | Evaluate (epoch 48) -- logp(x) = -9.541 +/- 0.083 50 | Evaluate (epoch 49) -- logp(x) = -9.565 +/- 0.082 51 | Evaluate (epoch 50) -- logp(x) = -9.514 +/- 0.083 52 | Evaluate (epoch 51) -- logp(x) = -9.530 +/- 0.082 53 | Evaluate (epoch 52) -- logp(x) = -9.570 +/- 0.082 54 | Evaluate (epoch 53) -- logp(x) = -9.475 +/- 0.083 55 | Evaluate (epoch 54) -- logp(x) = -9.462 +/- 0.083 56 | Evaluate (epoch 55) -- logp(x) = -9.546 +/- 0.081 57 | Evaluate (epoch 56) -- logp(x) = -9.460 +/- 0.082 58 | Evaluate (epoch 57) -- logp(x) = -9.439 +/- 0.081 59 | Evaluate (epoch 58) -- logp(x) = -9.429 +/- 0.082 60 | Evaluate (epoch 59) -- logp(x) = -9.462 +/- 0.081 61 | 
Evaluate (epoch 60) -- logp(x) = -9.433 +/- 0.083 62 | Evaluate (epoch 61) -- logp(x) = -9.414 +/- 0.081 63 | Evaluate (epoch 62) -- logp(x) = -9.384 +/- 0.081 64 | Evaluate (epoch 63) -- logp(x) = -9.383 +/- 0.081 65 | Evaluate (epoch 64) -- logp(x) = -9.386 +/- 0.082 66 | Evaluate (epoch 65) -- logp(x) = -9.361 +/- 0.082 67 | Evaluate (epoch 66) -- logp(x) = -9.326 +/- 0.082 68 | Evaluate (epoch 67) -- logp(x) = -9.340 +/- 0.083 69 | Evaluate (epoch 68) -- logp(x) = -9.350 +/- 0.082 70 | Evaluate (epoch 69) -- logp(x) = -9.320 +/- 0.081 71 | Evaluate (epoch 70) -- logp(x) = -9.318 +/- 0.080 72 | Evaluate (epoch 71) -- logp(x) = -9.474 +/- 0.080 73 | Evaluate (epoch 72) -- logp(x) = -9.283 +/- 0.081 74 | Evaluate (epoch 73) -- logp(x) = -9.335 +/- 0.080 75 | Evaluate (epoch 74) -- logp(x) = -9.277 +/- 0.082 76 | Evaluate (epoch 75) -- logp(x) = -9.263 +/- 0.082 77 | Evaluate (epoch 76) -- logp(x) = -9.255 +/- 0.081 78 | Evaluate (epoch 77) -- logp(x) = -9.228 +/- 0.081 79 | Evaluate (epoch 78) -- logp(x) = -9.248 +/- 0.080 80 | Evaluate (epoch 79) -- logp(x) = -9.240 +/- 0.080 81 | Evaluate (epoch 80) -- logp(x) = -9.211 +/- 0.080 82 | Evaluate (epoch 81) -- logp(x) = -9.225 +/- 0.082 83 | Evaluate (epoch 82) -- logp(x) = -9.228 +/- 0.080 84 | Evaluate (epoch 83) -- logp(x) = -9.235 +/- 0.081 85 | Evaluate (epoch 84) -- logp(x) = -9.240 +/- 0.080 86 | Evaluate (epoch 85) -- logp(x) = -9.175 +/- 0.080 87 | Evaluate (epoch 86) -- logp(x) = -9.184 +/- 0.080 88 | Evaluate (epoch 87) -- logp(x) = -9.180 +/- 0.079 89 | Evaluate (epoch 88) -- logp(x) = -9.149 +/- 0.081 90 | Evaluate (epoch 89) -- logp(x) = -9.152 +/- 0.080 91 | Evaluate (epoch 90) -- logp(x) = -9.138 +/- 0.079 92 | Evaluate (epoch 91) -- logp(x) = -9.137 +/- 0.079 93 | Evaluate (epoch 92) -- logp(x) = -9.201 +/- 0.079 94 | Evaluate (epoch 93) -- logp(x) = -9.167 +/- 0.079 95 | Evaluate (epoch 94) -- logp(x) = -9.114 +/- 0.080 96 | Evaluate (epoch 95) -- logp(x) = -9.108 +/- 0.079 97 | Evaluate (epoch 96) -- logp(x) = -9.124 +/- 0.079 98 | Evaluate (epoch 97) -- logp(x) = -9.117 +/- 0.080 99 | Evaluate (epoch 98) -- logp(x) = -9.100 +/- 0.078 100 | Evaluate (epoch 99) -- logp(x) = -9.099 +/- 0.080 101 | -------------------------------------------------------------------------------- /src/flows/results/uci_blood_transfusion_maf/results.txt: -------------------------------------------------------------------------------- 1 | Evaluate (epoch 0) -- logp(x) = -5.561 +/- 0.229 2 | Evaluate (epoch 1) -- logp(x) = -5.485 +/- 0.223 3 | Evaluate (epoch 2) -- logp(x) = -5.412 +/- 0.220 4 | Evaluate (epoch 3) -- logp(x) = -5.340 +/- 0.220 5 | Evaluate (epoch 4) -- logp(x) = -5.282 +/- 0.211 6 | Evaluate (epoch 5) -- logp(x) = -5.222 +/- 0.208 7 | Evaluate (epoch 6) -- logp(x) = -5.157 +/- 0.207 8 | Evaluate (epoch 7) -- logp(x) = -5.093 +/- 0.207 9 | Evaluate (epoch 8) -- logp(x) = -5.032 +/- 0.204 10 | Evaluate (epoch 9) -- logp(x) = -4.973 +/- 0.201 11 | Evaluate (epoch 10) -- logp(x) = -4.905 +/- 0.206 12 | Evaluate (epoch 11) -- logp(x) = -4.846 +/- 0.201 13 | Evaluate (epoch 12) -- logp(x) = -4.786 +/- 0.199 14 | Evaluate (epoch 13) -- logp(x) = -4.726 +/- 0.196 15 | Evaluate (epoch 14) -- logp(x) = -4.658 +/- 0.198 16 | Evaluate (epoch 15) -- logp(x) = -4.592 +/- 0.199 17 | Evaluate (epoch 16) -- logp(x) = -4.531 +/- 0.193 18 | Evaluate (epoch 17) -- logp(x) = -4.468 +/- 0.191 19 | Evaluate (epoch 18) -- logp(x) = -4.398 +/- 0.192 20 | Evaluate (epoch 19) -- logp(x) = -4.328 +/- 0.192 21 | Evaluate (epoch 20) -- logp(x) = -4.260 +/- 
0.191 22 | Evaluate (epoch 21) -- logp(x) = -4.188 +/- 0.191 23 | Evaluate (epoch 22) -- logp(x) = -4.114 +/- 0.192 24 | Evaluate (epoch 23) -- logp(x) = -4.037 +/- 0.192 25 | Evaluate (epoch 24) -- logp(x) = -3.962 +/- 0.189 26 | Evaluate (epoch 25) -- logp(x) = -3.880 +/- 0.190 27 | Evaluate (epoch 26) -- logp(x) = -3.798 +/- 0.188 28 | Evaluate (epoch 27) -- logp(x) = -3.706 +/- 0.190 29 | Evaluate (epoch 28) -- logp(x) = -3.609 +/- 0.192 30 | Evaluate (epoch 29) -- logp(x) = -3.509 +/- 0.194 31 | Evaluate (epoch 30) -- logp(x) = -3.402 +/- 0.196 32 | Evaluate (epoch 31) -- logp(x) = -3.294 +/- 0.198 33 | Evaluate (epoch 32) -- logp(x) = -3.184 +/- 0.198 34 | Evaluate (epoch 33) -- logp(x) = -3.065 +/- 0.199 35 | Evaluate (epoch 34) -- logp(x) = -2.946 +/- 0.200 36 | Evaluate (epoch 35) -- logp(x) = -2.818 +/- 0.203 37 | Evaluate (epoch 36) -- logp(x) = -2.693 +/- 0.204 38 | Evaluate (epoch 37) -- logp(x) = -2.558 +/- 0.207 39 | Evaluate (epoch 38) -- logp(x) = -2.419 +/- 0.210 40 | Evaluate (epoch 39) -- logp(x) = -2.284 +/- 0.214 41 | Evaluate (epoch 40) -- logp(x) = -2.142 +/- 0.214 42 | Evaluate (epoch 41) -- logp(x) = -2.001 +/- 0.216 43 | Evaluate (epoch 42) -- logp(x) = -1.869 +/- 0.215 44 | Evaluate (epoch 43) -- logp(x) = -1.743 +/- 0.212 45 | Evaluate (epoch 44) -- logp(x) = -1.574 +/- 0.215 46 | Evaluate (epoch 45) -- logp(x) = -1.398 +/- 0.215 47 | Evaluate (epoch 46) -- logp(x) = -1.331 +/- 0.206 48 | Evaluate (epoch 47) -- logp(x) = -1.152 +/- 0.216 49 | Evaluate (epoch 48) -- logp(x) = -0.972 +/- 0.211 50 | Evaluate (epoch 49) -- logp(x) = -0.860 +/- 0.210 51 | Evaluate (epoch 50) -- logp(x) = -0.772 +/- 0.213 52 | Evaluate (epoch 51) -- logp(x) = -0.803 +/- 0.202 53 | Evaluate (epoch 52) -- logp(x) = -0.578 +/- 0.209 54 | Evaluate (epoch 53) -- logp(x) = -0.533 +/- 0.209 55 | Evaluate (epoch 54) -- logp(x) = -0.334 +/- 0.212 56 | Evaluate (epoch 55) -- logp(x) = -0.542 +/- 0.196 57 | Evaluate (epoch 56) -- logp(x) = -0.397 +/- 0.200 58 | Evaluate (epoch 57) -- logp(x) = -0.172 +/- 0.214 59 | Evaluate (epoch 58) -- logp(x) = -0.217 +/- 0.209 60 | Evaluate (epoch 59) -- logp(x) = -0.193 +/- 0.204 61 | Evaluate (epoch 60) -- logp(x) = -0.120 +/- 0.211 62 | Evaluate (epoch 61) -- logp(x) = -0.120 +/- 0.217 63 | Evaluate (epoch 62) -- logp(x) = 0.049 +/- 0.211 64 | Evaluate (epoch 63) -- logp(x) = -0.092 +/- 0.209 65 | Evaluate (epoch 64) -- logp(x) = -0.093 +/- 0.199 66 | Evaluate (epoch 65) -- logp(x) = -0.133 +/- 0.203 67 | Evaluate (epoch 66) -- logp(x) = -0.401 +/- 0.192 68 | Evaluate (epoch 67) -- logp(x) = -0.011 +/- 0.219 69 | Evaluate (epoch 68) -- logp(x) = 0.067 +/- 0.204 70 | Evaluate (epoch 69) -- logp(x) = 0.031 +/- 0.214 71 | Evaluate (epoch 70) -- logp(x) = 0.087 +/- 0.208 72 | Evaluate (epoch 71) -- logp(x) = 0.010 +/- 0.203 73 | Evaluate (epoch 72) -- logp(x) = 0.118 +/- 0.209 74 | Evaluate (epoch 73) -- logp(x) = 0.188 +/- 0.208 75 | Evaluate (epoch 74) -- logp(x) = 0.020 +/- 0.204 76 | Evaluate (epoch 75) -- logp(x) = 0.199 +/- 0.211 77 | Evaluate (epoch 76) -- logp(x) = 0.049 +/- 0.208 78 | Evaluate (epoch 77) -- logp(x) = 0.219 +/- 0.201 79 | Evaluate (epoch 78) -- logp(x) = 0.421 +/- 0.214 80 | Evaluate (epoch 79) -- logp(x) = 0.410 +/- 0.213 81 | Evaluate (epoch 80) -- logp(x) = 0.389 +/- 0.214 82 | Evaluate (epoch 81) -- logp(x) = 0.269 +/- 0.211 83 | Evaluate (epoch 82) -- logp(x) = 0.158 +/- 0.190 84 | Evaluate (epoch 83) -- logp(x) = 0.265 +/- 0.199 85 | Evaluate (epoch 84) -- logp(x) = 0.407 +/- 0.212 86 | Evaluate (epoch 85) -- logp(x) = 0.209 
+/- 0.216 87 | Evaluate (epoch 86) -- logp(x) = 0.207 +/- 0.219 88 | Evaluate (epoch 87) -- logp(x) = 0.294 +/- 0.214 89 | Evaluate (epoch 88) -- logp(x) = 0.353 +/- 0.203 90 | Evaluate (epoch 89) -- logp(x) = 0.601 +/- 0.220 91 | Evaluate (epoch 90) -- logp(x) = 0.567 +/- 0.214 92 | Evaluate (epoch 91) -- logp(x) = 0.542 +/- 0.220 93 | Evaluate (epoch 92) -- logp(x) = 0.623 +/- 0.219 94 | Evaluate (epoch 93) -- logp(x) = 0.658 +/- 0.217 95 | Evaluate (epoch 94) -- logp(x) = 0.711 +/- 0.216 96 | Evaluate (epoch 95) -- logp(x) = 0.656 +/- 0.211 97 | Evaluate (epoch 96) -- logp(x) = 0.360 +/- 0.210 98 | Evaluate (epoch 97) -- logp(x) = 0.671 +/- 0.219 99 | Evaluate (epoch 98) -- logp(x) = 0.715 +/- 0.221 100 | Evaluate (epoch 99) -- logp(x) = 0.705 +/- 0.221 101 | Evaluate (epoch 0) -- logp(x) = -5.561 +/- 0.229 102 | Evaluate (epoch 1) -- logp(x) = -5.485 +/- 0.223 103 | Evaluate (epoch 2) -- logp(x) = -5.412 +/- 0.220 104 | Evaluate (epoch 3) -- logp(x) = -5.340 +/- 0.220 105 | Evaluate (epoch 4) -- logp(x) = -5.282 +/- 0.211 106 | Evaluate (epoch 5) -- logp(x) = -5.222 +/- 0.208 107 | -------------------------------------------------------------------------------- /src/classification/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import traceback 3 | import time 4 | import shutil 5 | import logging 6 | import yaml 7 | import sys 8 | import os 9 | import torch 10 | import numpy as np 11 | sys.path.append(os.path.abspath(os.getcwd())) 12 | from src.classification.trainers import * 13 | import getpass 14 | 15 | 16 | def parse_args_and_config(): 17 | parser = argparse.ArgumentParser(description=globals()['__doc__']) 18 | 19 | parser.add_argument('--trainer', type=str, default='Classifier', help='The runner to execute') 20 | parser.add_argument('--config', type=str, default='mnist/resnet.yml', help='Path to the config file') 21 | parser.add_argument('--seed', type=int, default=7777, help='Random seed') 22 | parser.add_argument('--run', type=str, default='run', help='Path for saving running related data.') 23 | parser.add_argument('--doc', type=str, default='0', help='A string for documentation purpose') 24 | parser.add_argument('--comment', type=str, default='', help='A string for experiment comment') 25 | parser.add_argument('--verbose', type=str, default='info', help='Verbose level: info | debug | warning | critical') 26 | parser.add_argument('--test', action='store_true', help='Whether to test the model') 27 | parser.add_argument('--resume_training', action='store_true', help='Whether to resume training') 28 | 29 | args = parser.parse_args() 30 | run_id = str(os.getpid()) 31 | run_time = time.strftime('%Y-%b-%d-%H-%M-%S') 32 | # args.doc = '_'.join([args.doc, run_id, run_time]) 33 | args.log = os.path.join(args.run, 'logs', args.doc) 34 | 35 | # parse config file 36 | if not args.test: 37 | with open(os.path.join('src/classification/configs', args.config), 'r') as f: 38 | config = yaml.safe_load(f) # hand-written config: safe_load avoids constructing arbitrary objects 39 | new_config = dict2namespace(config) 40 | else: 41 | with open(os.path.join(args.log, 'config.yml'), 'r') as f: 42 | config = yaml.load(f, Loader=yaml.Loader) # config.yml was dumped from a Namespace, so the full Loader is needed to reconstruct it 43 | new_config = config 44 | 45 | # make output directories 46 | output_dir = os.path.join( 47 | new_config.training.out_dir, 'results', f'{new_config.training.exp_id}_perc{new_config.data.perc}') 48 | 49 | os.makedirs(output_dir, exist_ok=True) 50 | new_config.out_dir = output_dir 51 | ckpt_dir = os.path.join(new_config.out_dir, 'checkpoints') 52 | os.makedirs(ckpt_dir, exist_ok=True) 53 | 54
| new_config.ckpt_dir = ckpt_dir 55 | if not args.test: 56 | if not args.resume_training: 57 | if os.path.exists(args.log): 58 | shutil.rmtree(args.log) 59 | os.makedirs(args.log) 60 | 61 | with open(os.path.join(args.log, 'config.yml'), 'w') as f: 62 | yaml.dump(new_config, f, default_flow_style=False) 63 | 64 | # setup logger 65 | level = getattr(logging, args.verbose.upper(), None) 66 | if not isinstance(level, int): 67 | raise ValueError('level {} not supported'.format(args.verbose)) 68 | 69 | handler1 = logging.StreamHandler() 70 | handler2 = logging.FileHandler(os.path.join(args.log, 'stdout.txt')) 71 | formatter = logging.Formatter('%(levelname)s - %(filename)s - %(asctime)s - %(message)s') 72 | handler1.setFormatter(formatter) 73 | handler2.setFormatter(formatter) 74 | logger = logging.getLogger() 75 | logger.addHandler(handler1) 76 | logger.addHandler(handler2) 77 | logger.setLevel(level) 78 | 79 | else: 80 | level = getattr(logging, args.verbose.upper(), None) 81 | if not isinstance(level, int): 82 | raise ValueError('level {} not supported'.format(args.verbose)) 83 | 84 | handler1 = logging.StreamHandler() 85 | formatter = logging.Formatter('%(levelname)s - %(filename)s - %(asctime)s - %(message)s') 86 | handler1.setFormatter(formatter) 87 | logger = logging.getLogger() 88 | logger.addHandler(handler1) 89 | logger.setLevel(level) 90 | 91 | # add device 92 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 93 | logging.info("Using device: {}".format(device)) 94 | new_config.device = device 95 | 96 | # set random seed 97 | torch.manual_seed(args.seed) 98 | np.random.seed(args.seed) 99 | if torch.cuda.is_available(): 100 | torch.cuda.manual_seed_all(args.seed) 101 | 102 | torch.backends.cudnn.benchmark = True 103 | 104 | return args, new_config 105 | 106 | 107 | def dict2namespace(config): 108 | namespace = argparse.Namespace() 109 | for key, value in config.items(): 110 | if isinstance(value, dict): 111 | new_value = dict2namespace(value) 112 | else: 113 | new_value = value 114 | setattr(namespace, key, new_value) 115 | return namespace 116 | 117 | 118 | def main(): 119 | args, config = parse_args_and_config() 120 | logging.info("Writing log file to {}".format(args.log)) 121 | logging.info("Exp instance id = {}".format(os.getpid())) 122 | logging.info("Exp comment = {}".format(args.comment)) 123 | logging.info("Config =") 124 | print(">" * 80) 125 | print(config) 126 | print("<" * 80) 127 | 128 | try: 129 | runner = eval(args.trainer)(args, config) 130 | if not args.test: 131 | runner.train() 132 | test_loss, test_acc, test_labels, test_probs, test_ratios = runner.test(runner.test_dataloader, 'test') 133 | runner.clf_diagnostics(test_labels, test_probs, test_ratios, 'test') 134 | except Exception: # don't silently swallow KeyboardInterrupt/SystemExit 135 | logging.error(traceback.format_exc()) 136 | 137 | return 0 138 | 139 | 140 | if __name__ == '__main__': 141 | sys.exit(main()) 142 | -------------------------------------------------------------------------------- /src/classification/models/networks.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Properly implemented ResNet-s for CIFAR10 as described in paper [1]. 3 | 4 | The implementation and structure of this file are hugely influenced by [2], 5 | which is implemented for ImageNet and doesn't have option A for identity. 6 | Moreover, most of the implementations on the web are copy-pasted from 7 | torchvision's resnet and have the wrong number of params.
8 | 9 | Proper ResNet-s for CIFAR10 (for fair comparison, etc.) have the following 10 | number of layers and parameters: 11 | 12 | name | layers | params 13 | ResNet20 | 20 | 0.27M 14 | ResNet32 | 32 | 0.46M 15 | ResNet44 | 44 | 0.66M 16 | ResNet56 | 56 | 0.85M 17 | ResNet110 | 110 | 1.7M 18 | ResNet1202| 1202 | 19.4M 19 | 20 | which this implementation indeed has. 21 | 22 | Reference: 23 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 24 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 25 | [2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 26 | 27 | If you use this implementation in your work, please don't forget to mention the 28 | author, Yerlan Idelbayev. 29 | ''' 30 | import torch 31 | import torch.nn as nn 32 | import torch.nn.functional as F 33 | import torch.nn.init as init 34 | 35 | from torch.autograd import Variable 36 | 37 | __all__ = ['ResNet_v2', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202'] 38 | 39 | def _weights_init(m): 40 | classname = m.__class__.__name__ 41 | #print(classname) 42 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 43 | init.kaiming_normal_(m.weight) 44 | 45 | class LambdaLayer(nn.Module): 46 | def __init__(self, lambd): 47 | super(LambdaLayer, self).__init__() 48 | self.lambd = lambd 49 | 50 | def forward(self, x): 51 | return self.lambd(x) 52 | 53 | 54 | class BasicBlock_v2(nn.Module): 55 | expansion = 1 56 | 57 | def __init__(self, in_planes, planes, stride=1, option='A'): 58 | super(BasicBlock_v2, self).__init__() 59 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 60 | self.bn1 = nn.BatchNorm2d(planes) 61 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 62 | self.bn2 = nn.BatchNorm2d(planes) 63 | 64 | self.shortcut = nn.Sequential() 65 | if stride != 1 or in_planes != planes: 66 | if option == 'A': 67 | """ 68 | For CIFAR10, the ResNet paper uses option A.
69 | """ 70 | self.shortcut = LambdaLayer(lambda x: 71 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0)) 72 | elif option == 'B': 73 | self.shortcut = nn.Sequential( 74 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 75 | nn.BatchNorm2d(self.expansion * planes) 76 | ) 77 | 78 | def forward(self, x): 79 | out = F.relu(self.bn1(self.conv1(x))) 80 | out = self.bn2(self.conv2(out)) 81 | out += self.shortcut(x) 82 | out = F.relu(out) 83 | return out 84 | 85 | 86 | class ResNet_v2(nn.Module): 87 | def __init__(self, block, num_blocks, num_classes=9): 88 | super(ResNet_v2, self).__init__() 89 | self.in_planes = 16 90 | 91 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 92 | self.bn1 = nn.BatchNorm2d(16) 93 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 94 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 95 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 96 | self.linear = nn.Linear(64, num_classes) 97 | 98 | self.apply(_weights_init) 99 | 100 | def _make_layer(self, block, planes, num_blocks, stride): 101 | strides = [stride] + [1]*(num_blocks-1) 102 | layers = [] 103 | for stride in strides: 104 | layers.append(block(self.in_planes, planes, stride)) 105 | self.in_planes = planes * block.expansion 106 | 107 | return nn.Sequential(*layers) 108 | 109 | def forward(self, x): 110 | out = F.relu(self.bn1(self.conv1(x))) 111 | out = self.layer1(out) 112 | out = self.layer2(out) 113 | out = self.layer3(out) 114 | out = F.avg_pool2d(out, out.size()[3]) 115 | out = out.view(out.size(0), -1) 116 | logits = self.linear(out) 117 | probas = F.softmax(logits, dim=1) 118 | # probas = torch.sigmoid(logits) 119 | return logits, probas 120 | 121 | 122 | def resnet20(): 123 | return ResNet_v2(BasicBlock_v2, [3, 3, 3]) 124 | 125 | 126 | def resnet32(): 127 | return ResNet_v2(BasicBlock_v2, [5, 5, 5]) 128 | 129 | 130 | def resnet44(): 131 | return ResNet_v2(BasicBlock_v2, [7, 7, 7]) 132 | 133 | 134 | def resnet56(): 135 | return ResNet_v2(BasicBlock_v2, [9, 9, 9]) 136 | 137 | 138 | def resnet110(): 139 | return ResNet_v2(BasicBlock_v2, [18, 18, 18]) 140 | 141 | 142 | def resnet1202(): 143 | return ResNet_v2(BasicBlock_v2, [200, 200, 200]) 144 | 145 | 146 | def test(net): 147 | import numpy as np 148 | total_params = 0 149 | 150 | for x in filter(lambda p: p.requires_grad, net.parameters()): 151 | total_params += np.prod(x.data.numpy().shape) 152 | print("Total number of params", total_params) 153 | print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size())>1, net.parameters())))) 154 | 155 | 156 | if __name__ == "__main__": 157 | for net_name in __all__: 158 | if net_name.startswith('resnet'): 159 | print(net_name) 160 | test(globals()[net_name]()) 161 | print() -------------------------------------------------------------------------------- /src/flows/models/glow/glow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.distributions as D 4 | 5 | from models.glow.actnorm import ActNorm 6 | from models.glow.coupling import AffineCoupling 7 | from models.glow.invconv import Invertible1x1Conv 8 | from models.glow.layers import * 9 | 10 | class FlowSequential(nn.Sequential): 11 | """ Container for layers of a normalizing flow """ 12 | def __init__(self, *args, **kwargs): 13 | self.checkpoint_grads = kwargs.pop('checkpoint_grads', None) 14 
| super().__init__(*args, **kwargs) 15 | 16 | def forward(self, x): 17 | sum_logdets = 0. 18 | for module in self: 19 | x, logdet = module(x) if not self.checkpoint_grads else checkpoint(module, x) 20 | sum_logdets = sum_logdets + logdet 21 | return x, sum_logdets 22 | 23 | def inverse(self, z): 24 | sum_logdets = 0. 25 | for module in reversed(self): 26 | z, logdet = module.inverse(z) 27 | sum_logdets = sum_logdets + logdet 28 | return z, sum_logdets 29 | 30 | 31 | class FlowStep(FlowSequential): 32 | """ One step of Glow flow (ActNorm -> Invertible 1x1 conv -> Affine coupling); cf Glow Figure 2a """ 33 | def __init__(self, n_channels, width, lu_factorize=False): 34 | super().__init__(ActNorm(param_dim=(1,n_channels,1,1)), 35 | Invertible1x1Conv(n_channels, lu_factorize), 36 | AffineCoupling(n_channels, width)) 37 | 38 | 39 | class FlowLevel(nn.Module): 40 | """ One depth level of Glow flow (Squeeze -> FlowStep x K -> Split); cf Glow figure 2b """ 41 | def __init__(self, n_channels, width, depth, checkpoint_grads=False, lu_factorize=False): 42 | super().__init__() 43 | # network layers 44 | self.squeeze = Squeeze() 45 | self.flowsteps = FlowSequential(*[FlowStep(4*n_channels, width, lu_factorize) for _ in range(depth)], checkpoint_grads=checkpoint_grads) 46 | self.split = Split(4*n_channels) 47 | 48 | def forward(self, x): 49 | x = self.squeeze(x) 50 | x, logdet_flowsteps = self.flowsteps(x) 51 | x1, z2, logdet_split = self.split(x) 52 | logdet = logdet_flowsteps + logdet_split 53 | return x1, z2, logdet 54 | 55 | def inverse(self, x1, z2): 56 | x, logdet_split = self.split.inverse(x1, z2) 57 | x, logdet_flowsteps = self.flowsteps.inverse(x) 58 | x = self.squeeze.inverse(x) 59 | logdet = logdet_flowsteps + logdet_split 60 | return x, logdet 61 | 62 | class Glow(nn.Module): 63 | """ Glow multi-scale architecture with depth of flow K and number of levels L; cf Glow figure 2; section 3""" 64 | def __init__(self, width, depth, n_levels, input_dims=(3,32,32), checkpoint_grads=False, lu_factorize=False): 65 | super().__init__() 66 | # calculate output dims 67 | in_channels, H, W = input_dims 68 | out_channels = int(in_channels * 4**(n_levels+1) / 2**n_levels) # each Squeeze results in 4x in_channels (cf RealNVP section 3.6); each Split in 1/2x in_channels 69 | out_HW = int(H / 2**(n_levels+1)) # each Squeeze is 1/2x HW dim (cf RealNVP section 3.6) 70 | self.output_dims = out_channels, out_HW, out_HW 71 | 72 | # preprocess images 73 | self.preprocess = Preprocess() 74 | 75 | # network layers cf Glow figure 2b: (Squeeze -> FlowStep x depth -> Split) x n_levels -> Squeeze -> FlowStep x depth 76 | self.flowlevels = nn.ModuleList([FlowLevel(in_channels * 2**i, width, depth, checkpoint_grads, lu_factorize) for i in range(n_levels)]) 77 | self.squeeze = Squeeze() 78 | self.flowstep = FlowSequential(*[FlowStep(out_channels, width, lu_factorize) for _ in range(depth)], checkpoint_grads=checkpoint_grads) 79 | 80 | # gaussianize the final z output; initialize to identity 81 | self.gaussianize = Gaussianize(out_channels) 82 | 83 | # base distribution of the flow 84 | self.register_buffer('base_dist_mean', torch.zeros(1)) 85 | self.register_buffer('base_dist_var', torch.ones(1)) 86 | 87 | def forward(self, x): 88 | x, sum_logdets = self.preprocess(x) 89 | # pass through flow 90 | zs = [] 91 | for m in self.flowlevels: 92 | x, z, logdet = m(x) 93 | sum_logdets = sum_logdets + logdet 94 | zs.append(z) 95 | x = self.squeeze(x) 96 | z, logdet = self.flowstep(x) 97 | sum_logdets = sum_logdets + logdet 98 | 99 | 
# gaussianize the final z 100 | z, logdet = self.gaussianize(torch.zeros_like(z), z) 101 | sum_logdets = sum_logdets + logdet 102 | zs.append(z) 103 | return zs, sum_logdets 104 | 105 | def inverse(self, zs=None, batch_size=None, z_std=1.): 106 | if zs is None: # if no random numbers are passed, generate new from the base distribution 107 | assert batch_size is not None, 'Must either specify batch_size or pass a batch of z random numbers.' 108 | zs = [z_std * self.base_dist.sample((batch_size, *self.output_dims)).squeeze()] 109 | # pass through inverse flow 110 | z, sum_logdets = self.gaussianize.inverse(torch.zeros_like(zs[-1]), zs[-1]) 111 | x, logdet = self.flowstep.inverse(z) 112 | sum_logdets = sum_logdets + logdet 113 | x = self.squeeze.inverse(x) 114 | for i, m in enumerate(reversed(self.flowlevels)): 115 | z = z_std * (self.base_dist.sample(x.shape).squeeze() if len(zs)==1 else zs[-i-2]) # if no z's are passed, generate new random numbers from the base dist 116 | x, logdet = m.inverse(x, z) 117 | sum_logdets = sum_logdets + logdet 118 | # postprocess 119 | x, logdet = self.preprocess.inverse(x) 120 | sum_logdets = sum_logdets + logdet 121 | return x, sum_logdets 122 | 123 | @property 124 | def base_dist(self): 125 | return D.Normal(self.base_dist_mean, self.base_dist_var) 126 | 127 | def log_prob(self, x, bits_per_pixel=False): 128 | zs, logdet = self.forward(x) 129 | log_prob = sum(self.base_dist.log_prob(z).sum([1,2,3]) for z in zs) + logdet 130 | if bits_per_pixel: 131 | log_prob /= (math.log(2) * x[0].numel()) 132 | return log_prob 133 | -------------------------------------------------------------------------------- /src/flows/models/maf/layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import copy # needed for copy.deepcopy in LinearMaskedCoupling below 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class MaskedLinear(nn.Linear): 9 | """ MADE building block layer """ 10 | def __init__(self, input_size, n_outputs, mask, cond_label_size=None): 11 | super().__init__(input_size, n_outputs) 12 | 13 | self.register_buffer('mask', mask) 14 | 15 | self.cond_label_size = cond_label_size 16 | if cond_label_size is not None: 17 | self.cond_weight = nn.Parameter(torch.rand(n_outputs, cond_label_size) / math.sqrt(cond_label_size)) 18 | 19 | def forward(self, x, y=None): 20 | out = F.linear(x, self.weight * self.mask, self.bias) 21 | if y is not None: 22 | out = out + F.linear(y, self.cond_weight) 23 | return out 24 | 25 | def extra_repr(self): 26 | return 'in_features={}, out_features={}, bias={}'.format( 27 | self.in_features, self.out_features, self.bias is not None 28 | ) + (self.cond_label_size is not None) * ', cond_features={}'.format(self.cond_label_size) 29 | 30 | 31 | class LinearMaskedCoupling(nn.Module): 32 | """ Modified RealNVP Coupling Layers per the MAF paper """ 33 | def __init__(self, input_size, hidden_size, n_hidden, mask, cond_label_size=None): 34 | super().__init__() 35 | 36 | self.register_buffer('mask', mask) 37 | 38 | # scale function 39 | s_net = [nn.Linear(input_size + (cond_label_size if cond_label_size is not None else 0), hidden_size)] 40 | for _ in range(n_hidden): 41 | s_net += [nn.Tanh(), nn.Linear(hidden_size, hidden_size)] 42 | s_net += [nn.Tanh(), nn.Linear(hidden_size, input_size)] 43 | self.s_net = nn.Sequential(*s_net) 44 | 45 | # translation function 46 | self.t_net = copy.deepcopy(self.s_net) 47 | # replace Tanh with ReLU's per MAF paper 48 | for i in range(len(self.t_net)): 49 | if not isinstance(self.t_net[i],
nn.Linear): self.t_net[i] = nn.ReLU() 50 | 51 | def forward(self, x, y=None): 52 | # apply mask 53 | mx = x * self.mask 54 | 55 | # run through model 56 | s = self.s_net(mx if y is None else torch.cat([y, mx], dim=1)) 57 | t = self.t_net(mx if y is None else torch.cat([y, mx], dim=1)) 58 | u = mx + (1 - self.mask) * (x - t) * torch.exp(-s) # cf RealNVP eq 8 where u corresponds to x (here we're modeling u) 59 | 60 | log_abs_det_jacobian = - (1 - self.mask) * s # log det du/dx; cf RealNVP 8 and 6; note, sum over input_size done at model log_prob 61 | 62 | return u, log_abs_det_jacobian 63 | 64 | def inverse(self, u, y=None): 65 | # apply mask 66 | mu = u * self.mask 67 | 68 | # run through model 69 | s = self.s_net(mu if y is None else torch.cat([y, mu], dim=1)) 70 | t = self.t_net(mu if y is None else torch.cat([y, mu], dim=1)) 71 | x = mu + (1 - self.mask) * (u * s.exp() + t) # cf RealNVP eq 7 72 | 73 | log_abs_det_jacobian = (1 - self.mask) * s # log det dx/du 74 | 75 | return x, log_abs_det_jacobian 76 | 77 | 78 | class BatchNorm(nn.Module): 79 | """ RealNVP BatchNorm layer """ 80 | def __init__(self, input_size, momentum=0.9, eps=1e-5): 81 | super().__init__() 82 | self.momentum = momentum 83 | self.eps = eps 84 | 85 | self.log_gamma = nn.Parameter(torch.zeros(input_size)) 86 | self.beta = nn.Parameter(torch.zeros(input_size)) 87 | 88 | self.register_buffer('running_mean', torch.zeros(input_size)) 89 | self.register_buffer('running_var', torch.ones(input_size)) 90 | 91 | def forward(self, x, cond_y=None): 92 | if self.training: 93 | self.batch_mean = x.mean(0) 94 | self.batch_var = x.var(0) # note MAF paper uses biased variance estimate; ie x.var(0, unbiased=False) 95 | 96 | # update running mean 97 | self.running_mean.mul_(self.momentum).add_(self.batch_mean.data * (1 - self.momentum)) 98 | self.running_var.mul_(self.momentum).add_(self.batch_var.data * (1 - self.momentum)) 99 | 100 | mean = self.batch_mean 101 | var = self.batch_var 102 | else: 103 | mean = self.running_mean 104 | var = self.running_var 105 | 106 | # compute normalized input (cf original batch norm paper algo 1) 107 | x_hat = (x - mean) / torch.sqrt(var + self.eps) 108 | y = self.log_gamma.exp() * x_hat + self.beta 109 | 110 | # compute log_abs_det_jacobian (cf RealNVP paper) 111 | log_abs_det_jacobian = self.log_gamma - 0.5 * torch.log(var + self.eps) 112 | # print('in sum log var {:6.3f} ; out sum log var {:6.3f}; sum log det {:8.3f}; mean log_gamma {:5.3f}; mean beta {:5.3f}'.format( 113 | # (var + self.eps).log().sum().data.numpy(), y.var(0).log().sum().data.numpy(), log_abs_det_jacobian.mean(0).item(), self.log_gamma.mean(), self.beta.mean())) 114 | return y, log_abs_det_jacobian.expand_as(x) 115 | 116 | def inverse(self, y, cond_y=None): 117 | if self.training: 118 | mean = self.batch_mean 119 | var = self.batch_var 120 | else: 121 | mean = self.running_mean 122 | var = self.running_var 123 | 124 | x_hat = (y - self.beta) * torch.exp(-self.log_gamma) 125 | x = x_hat * torch.sqrt(var + self.eps) + mean 126 | 127 | log_abs_det_jacobian = 0.5 * torch.log(var + self.eps) - self.log_gamma 128 | 129 | return x, log_abs_det_jacobian.expand_as(x) 130 | 131 | 132 | class FlowSequential(nn.Sequential): 133 | """ Container for layers of a normalizing flow """ 134 | def forward(self, x, y): 135 | sum_log_abs_det_jacobians = 0 136 | for module in self: 137 | x, log_abs_det_jacobian = module(x, y) 138 | sum_log_abs_det_jacobians = sum_log_abs_det_jacobians + log_abs_det_jacobian 139 | return x, sum_log_abs_det_jacobians 140 | 
141 | def inverse(self, u, y): 142 | sum_log_abs_det_jacobians = 0 143 | for module in reversed(self): 144 | u, log_abs_det_jacobian = module.inverse(u, y) 145 | sum_log_abs_det_jacobians = sum_log_abs_det_jacobians + log_abs_det_jacobian 146 | return u, sum_log_abs_det_jacobians 147 | -------------------------------------------------------------------------------- /src/datasets/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import hashlib 4 | import errno 5 | from torch.utils.model_zoo import tqdm 6 | 7 | 8 | def gen_bar_updater(): 9 | pbar = tqdm(total=None) 10 | 11 | def bar_update(count, block_size, total_size): 12 | if pbar.total is None and total_size: 13 | pbar.total = total_size 14 | progress_bytes = count * block_size 15 | pbar.update(progress_bytes - pbar.n) 16 | 17 | return bar_update 18 | 19 | 20 | def check_integrity(fpath, md5=None): 21 | if md5 is None: 22 | return True 23 | if not os.path.isfile(fpath): 24 | return False 25 | md5o = hashlib.md5() 26 | with open(fpath, 'rb') as f: 27 | # read in 1MB chunks 28 | for chunk in iter(lambda: f.read(1024 * 1024), b''): 29 | md5o.update(chunk) 30 | md5c = md5o.hexdigest() 31 | if md5c != md5: 32 | return False 33 | return True 34 | 35 | 36 | def makedir_exist_ok(dirpath): 37 | """ 38 | Python2 support for os.makedirs(.., exist_ok=True) 39 | """ 40 | try: 41 | os.makedirs(dirpath) 42 | except OSError as e: 43 | if e.errno == errno.EEXIST: 44 | pass 45 | else: 46 | raise 47 | 48 | 49 | def download_url(url, root, filename=None, md5=None): 50 | """Download a file from a url and place it in root. 51 | 52 | Args: 53 | url (str): URL to download file from 54 | root (str): Directory to place downloaded file in 55 | filename (str, optional): Name to save the file under. If None, use the basename of the URL 56 | md5 (str, optional): MD5 checksum of the download. If None, do not check 57 | """ 58 | from six.moves import urllib 59 | 60 | root = os.path.expanduser(root) 61 | if not filename: 62 | filename = os.path.basename(url) 63 | fpath = os.path.join(root, filename) 64 | 65 | makedir_exist_ok(root) 66 | 67 | # downloads file 68 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 69 | print('Using downloaded and verified file: ' + fpath) 70 | else: 71 | try: 72 | print('Downloading ' + url + ' to ' + fpath) 73 | urllib.request.urlretrieve( 74 | url, fpath, 75 | reporthook=gen_bar_updater() 76 | ) 77 | except OSError: 78 | if url[:5] == 'https': 79 | url = url.replace('https:', 'http:') 80 | print('Failed download. Trying https -> http instead.' 
81 | ' Downloading ' + url + ' to ' + fpath) 82 | urllib.request.urlretrieve( 83 | url, fpath, 84 | reporthook=gen_bar_updater() 85 | ) 86 | 87 | 88 | def list_dir(root, prefix=False): 89 | """List all directories at a given root 90 | 91 | Args: 92 | root (str): Path to directory whose folders need to be listed 93 | prefix (bool, optional): If true, prepends the path to each result, otherwise 94 | only returns the name of the directories found 95 | """ 96 | root = os.path.expanduser(root) 97 | directories = list( 98 | filter( 99 | lambda p: os.path.isdir(os.path.join(root, p)), 100 | os.listdir(root) 101 | ) 102 | ) 103 | 104 | if prefix is True: 105 | directories = [os.path.join(root, d) for d in directories] 106 | 107 | return directories 108 | 109 | 110 | def list_files(root, suffix, prefix=False): 111 | """List all files ending with a suffix at a given root 112 | 113 | Args: 114 | root (str): Path to directory whose files need to be listed 115 | suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png'). 116 | It uses the Python "str.endswith" method and is passed directly 117 | prefix (bool, optional): If true, prepends the path to each result, otherwise 118 | only returns the name of the files found 119 | """ 120 | root = os.path.expanduser(root) 121 | files = list( 122 | filter( 123 | lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix), 124 | os.listdir(root) 125 | ) 126 | ) 127 | 128 | if prefix is True: 129 | files = [os.path.join(root, d) for d in files] 130 | 131 | return files 132 | 133 | 134 | def download_file_from_google_drive(file_id, root, filename=None, md5=None): 135 | """Download a file from Google Drive and place it in root. 136 | 137 | Args: 138 | file_id (str): id of file to be downloaded 139 | root (str): Directory to place downloaded file in 140 | filename (str, optional): Name to save the file under. If None, use the id of the file. 141 | md5 (str, optional): MD5 checksum of the download.
If None, do not check 142 | """ 143 | # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url 144 | import requests 145 | url = "https://docs.google.com/uc?export=download" 146 | 147 | root = os.path.expanduser(root) 148 | if not filename: 149 | filename = file_id 150 | fpath = os.path.join(root, filename) 151 | 152 | makedir_exist_ok(root) 153 | 154 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 155 | print('Using downloaded and verified file: ' + fpath) 156 | else: 157 | session = requests.Session() 158 | 159 | response = session.get(url, params={'id': file_id}, stream=True) 160 | token = _get_confirm_token(response) 161 | 162 | if token: 163 | params = {'id': file_id, 'confirm': token} 164 | response = session.get(url, params=params, stream=True) 165 | 166 | _save_response_content(response, fpath) 167 | 168 | 169 | def _get_confirm_token(response): 170 | for key, value in response.cookies.items(): 171 | if key.startswith('download_warning'): 172 | return value 173 | 174 | return None 175 | 176 | 177 | def _save_response_content(response, destination, chunk_size=32768): 178 | with open(destination, "wb") as f: 179 | pbar = tqdm(total=None) 180 | progress = 0 181 | for chunk in response.iter_content(chunk_size): 182 | if chunk: # filter out keep-alive new chunks 183 | f.write(chunk) 184 | progress += len(chunk) 185 | pbar.update(progress - pbar.n) 186 | pbar.close() 187 | -------------------------------------------------------------------------------- /src/datasets/omniglot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import Dataset 5 | from .vision import VisionDataset 6 | 7 | DATA_ROOT = '/atlas/u/kechoi/f-dre/data/datasets/omniglot/' 8 | 9 | 10 | def logit_transform(image, lam=1e-6): 11 | image = lam + (1 - 2 * lam) * image 12 | return torch.log(image) - torch.log1p(-image) 13 | 14 | 15 | class OmniglotMixture(VisionDataset): 16 | """ 17 | Dataset class exclusively for training the DRE classifier 18 | (will need a different dataset class for actual downstream classification) 19 | 20 | NOTE: assumes that DAGAN has already been used to generate samples! 21 | """ 22 | def __init__(self, root, 23 | config, 24 | split='train', 25 | target_type='attr', 26 | transform=None, target_transform=None, load_in_mem=False, 27 | download=True, **kwargs): 28 | super(OmniglotMixture, self).__init__(root) 29 | 30 | self.config = config 31 | self.split = split 32 | self.root = root 33 | self.perc = config.data.perc 34 | self.flow = True if config.model.name == 'maf' else False 35 | 36 | # paths to real omniglot data 37 | ref_data = np.load(os.path.join(DATA_ROOT, 'omniglot_data.npy')) 38 | ref_data = ref_data[:1200, :, :, :, :] 39 | # path to samples generated 40 | bias_data = np.load(os.path.join(DATA_ROOT, 'generated_omniglot.npy')).reshape(1200, 100, 28, 28, 1) 41 | bias_data = bias_data / 255. 
42 | 43 | self.ref_dataset, self.bias_dataset = self.initialize_data_splits(ref_data, bias_data) 44 | 45 | def initialize_data_splits(self, ref_data, bias_data): 46 | if self.split == 'train': 47 | ref = ref_data[:, 0:10, :, :, :] 48 | bias = bias_data[:, 0:50, :, :, :] 49 | elif self.split == 'val': 50 | ref = ref_data[:, 10:15, :, :, :] 51 | bias = bias_data[:, 50:55, :, :, :] 52 | else: 53 | ref = ref_data[:, 15:, :, :, :] 54 | bias = bias_data[:, 55:60, :, :, :] 55 | 56 | # reshape bc numpy 57 | bias = bias.reshape(-1, 28, 28, 1) 58 | ref = ref.reshape(-1, 28, 28, 1) 59 | 60 | # tensorize 61 | bias = torch.from_numpy(bias).permute((0, 3, 1, 2)).float() 62 | ref = torch.from_numpy(ref).permute((0, 3, 1, 2)).float() 63 | 64 | if self.flow: 65 | print('applying flow transforms in advance...') 66 | bias = self._data_transform(bias) 67 | ref = self._data_transform(ref) 68 | 69 | # pseudolabels 70 | bias_y = torch.zeros(len(bias)) 71 | ref_y = torch.ones(len(ref)) 72 | 73 | # construct dataloaders (data, biased/ref dataset) 74 | # NOTE: not saving actual data labels for now 75 | ref_dataset = torch.utils.data.TensorDataset(ref, ref_y) 76 | bias_dataset = torch.utils.data.TensorDataset(bias, bias_y) 77 | 78 | return ref_dataset, bias_dataset 79 | 80 | def __getitem__(self, index): 81 | """ 82 | Make sure dataset doesn't go out of bounds 83 | """ 84 | bias_item, _ = self.bias_dataset[index] 85 | if index >= len(self.ref_dataset): 86 | index = np.random.choice(len(self.ref_dataset)) 87 | ref_item, _ = self.ref_dataset[index] 88 | 89 | return ref_item, bias_item 90 | 91 | def __len__(self): 92 | # iterate through both at the same time 93 | return len(self.bias_dataset) 94 | 95 | def _data_transform(self, x): 96 | # data is originally between [0,1], so change it back 97 | x = (x * 255).byte() 98 | # performs dequantization, rescaling, then logit transform 99 | x = (x + torch.rand(x.size())) / 256. 100 | x = logit_transform(x) 101 | 102 | return x 103 | 104 | 105 | class Omniglot(VisionDataset): 106 | def __init__(self, root, 107 | config, 108 | split='train', 109 | target_type='attr', 110 | transform=None, target_transform=None, load_in_mem=False, 111 | download=True, synthetic=False, **kwargs): 112 | super(Omniglot, self).__init__(root) 113 | 114 | self.config = config 115 | self.split = split 116 | self.root = root 117 | self.perc = config.data.perc 118 | self.synthetic = synthetic 119 | self.augment = config.data.augment 120 | if self.augment: 121 | print('augmenting real data with synthetic data...') 122 | 123 | if not self.synthetic: 124 | true_data = np.load(os.path.join(DATA_ROOT, 'omniglot_data.npy')) 125 | true_data = true_data[0:1200, :, :, :, :] 126 | data = true_data 127 | else: 128 | # tensorflow shenanigans 129 | print('loading synthetic data for training...') 130 | data = np.load(os.path.join(DATA_ROOT, 131 | 'generated_omniglot.npy')).reshape(1200, 100, 28, 28, 1) 132 | data = data[0:1200, :, :, :, :] 133 | data = data / 255. 
134 | true_data = np.load(os.path.join(DATA_ROOT, 'omniglot_data.npy'))[0:1200, :, :, :, :] 135 | if self.split == 'train': 136 | if self.synthetic: 137 | data = data[:, 0:50, :, :, :] 138 | n_labels = 50 139 | else: 140 | data = data[:, 0:10, :, :, :] 141 | n_labels = 10 142 | aux_labels = np.ones(1200*n_labels) # fake is y=1 143 | if self.augment: 144 | real = true_data[:, 0:10, :, :, :] 145 | data = np.hstack([data, real]) 146 | n_labels += 10 147 | aux_labels = np.hstack([aux_labels, np.zeros(1200*10)]) # real is y = 0 148 | labels = np.repeat(np.arange(len(data)), n_labels) 149 | elif self.split == 'val': 150 | data = true_data[:, 10:15, :, :, :] 151 | labels = np.repeat(np.arange(len(data)), 5) 152 | aux_labels = np.zeros(1200*5) 153 | else: 154 | data = true_data[:, 15:, :, :, :] 155 | labels = np.repeat(np.arange(len(data)), 5) 156 | aux_labels = np.zeros(1200*5) 157 | data = data.reshape(-1, 28, 28, 1) 158 | labels = np.vstack([labels, aux_labels]) 159 | self.dataset = torch.from_numpy(data).permute((0, 3, 1, 2)).float() 160 | self.labels = torch.from_numpy(labels).float().permute(1,0) # (2, n_data) 161 | 162 | def __getitem__(self, index): 163 | """ 164 | Make sure dataset doesn't go out of bounds 165 | """ 166 | item = self.dataset[index] 167 | label = self.labels[index] 168 | 169 | return item, label 170 | 171 | def __len__(self): 172 | return len(self.dataset) -------------------------------------------------------------------------------- /src/classification/trainers/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/victoresque/pytorch-template/tree/master/base 3 | """ 4 | import os 5 | import sys 6 | sys.path.append('../') 7 | import torch 8 | 9 | 10 | class BaseTrainer: 11 | """ 12 | Base class for all trainers 13 | """ 14 | def __init__(self, model, criterion, metric_ftns, optimizer, config): 15 | self.config = config 16 | self.logger = config.get_logger('training', config['training']['verbosity']) 17 | 18 | # setup GPU device if available, move model into configured device 19 | self.device, device_ids = self._prepare_device(config['n_gpu']) 20 | self.model = model.to(self.device) 21 | if len(device_ids) > 1: 22 | self.model = torch.nn.DataParallel(model, device_ids=device_ids) 23 | 24 | self.criterion = criterion 25 | self.metric_ftns = metric_ftns 26 | self.optimizer = optimizer 27 | 28 | cfg_trainer = config['training'] 29 | self.n_epochs = cfg_trainer['epochs'] 30 | self.iter_save = cfg_trainer['iter_save'] 31 | self.checkpoint_dir = config['ckpt_dir'] 32 | 33 | if config.resume is not None: 34 | self._resume_checkpoint(config.resume) 35 | 36 | def train_epoch(self, epoch): 37 | """ 38 | Training logic for an epoch 39 | 40 | :param epoch: Current epoch number 41 | """ 42 | raise NotImplementedError 43 | 44 | def test(self): 45 | """ 46 | Testing logic for an epoch 47 | 48 | :param epoch: Current epoch number 49 | """ 50 | raise NotImplementedError 51 | 52 | def train(self): 53 | """ 54 | Full training logic 55 | """ 56 | # TODO: need to clean up this code -- lots of unnecessary components 57 | not_improved_count = 0 58 | for epoch in range(1, self.n_epochs + 1): 59 | result = self.train_epoch(epoch) 60 | 61 | # save logged informations into log dict 62 | log = {'epoch': epoch} 63 | log.update(result) 64 | 65 | # print logged informations to the screen 66 | for key, value in log.items(): 67 | self.logger.info(' {:15s}: {}'.format(str(key), value)) 68 | 69 | # evaluate model performance according to 
configured metric, save best checkpoint as model_best 70 | best = False 71 | if getattr(self, 'mnt_mode', 'off') != 'off': # mnt_mode/mnt_metric/mnt_best/early_stop are expected to be set by the subclass; default to no monitoring 72 | try: 73 | # check whether model performance improved or not, according to specified metric(mnt_metric) 74 | improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \ 75 | (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best) 76 | except KeyError: 77 | self.logger.warning("Warning: Metric '{}' is not found. " 78 | "Model performance monitoring is disabled.".format(self.mnt_metric)) 79 | self.mnt_mode = 'off' 80 | improved = False 81 | 82 | if improved: 83 | self.mnt_best = log[self.mnt_metric] 84 | not_improved_count = 0 85 | best = True 86 | else: 87 | not_improved_count += 1 88 | 89 | if not_improved_count > self.early_stop: 90 | self.logger.info("Validation performance didn't improve for {} epochs. " 91 | "Training stops.".format(self.early_stop)) 92 | break 93 | 94 | if epoch % self.iter_save == 0: 95 | self._save_checkpoint(epoch, save_best=best) 96 | 97 | def _prepare_device(self, n_gpu_use): 98 | """ 99 | setup GPU device if available, move model into configured device 100 | """ 101 | n_gpu = torch.cuda.device_count() 102 | if n_gpu_use > 0 and n_gpu == 0: 103 | self.logger.warning("Warning: There's no GPU available on this machine, training will be performed on CPU.") 104 | n_gpu_use = 0 105 | if n_gpu_use > n_gpu: 106 | self.logger.warning("Warning: The number of GPUs configured is {}, but only {} are available on this machine.".format(n_gpu_use, n_gpu)) 107 | n_gpu_use = n_gpu 108 | device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu') 109 | list_ids = list(range(n_gpu_use)) 110 | return device, list_ids 111 | 112 | def _save_checkpoint(self, epoch, save_best=False): 113 | """ 114 | Saving checkpoints 115 | 116 | :param epoch: current epoch number 117 | :param save_best: if True, rename the saved checkpoint to 'model_best.pth' 118 | 119 | """ 120 | state = { 121 | 'epoch': epoch, 122 | 'state_dict': self.model.state_dict(), 123 | 'optimizer': self.optimizer.state_dict(), 124 | 'config': self.config 125 | } 126 | if save_best: 127 | print("Saving current best...") 128 | filename = os.path.join(self.checkpoint_dir, 'model_best.pth') 129 | else: 130 | print('Saving checkpoint...') 131 | filename = os.path.join(self.checkpoint_dir, 'checkpoint.pth') 132 | torch.save(state, filename) 133 | 134 | def _resume_checkpoint(self, resume_path): 135 | """ 136 | Resume from saved checkpoints 137 | 138 | :param resume_path: Checkpoint path to be resumed 139 | """ 140 | resume_path = str(resume_path) 141 | if not os.path.exists(resume_path): 142 | print('Model checkpoint does not exist!') 143 | sys.exit(1) 144 | # self.logger.info("Loading checkpoint: {} ...".format(resume_path)) 145 | print("Loading checkpoint: {} ...".format(resume_path)) 146 | checkpoint = torch.load(resume_path) 147 | self.start_epoch = checkpoint['epoch'] 148 | 149 | # load architecture params from checkpoint. 150 | # if checkpoint['config']['arch'] != self.config['arch']: 151 | # self.logger.warning("Warning: Architecture configuration given in config file is different from that of checkpoint. This may yield an exception while state_dict is being loaded.") 152 | self.model.load_state_dict(checkpoint['state_dict']) 153 | 154 | # load optimizer state from checkpoint only when optimizer type is not changed.
155 | # if checkpoint['config'].optim.optimizer != self.config.optim.optimizer: 156 | # self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. Optimizer parameters not being resumed.") 157 | # else: 158 | self.optimizer.load_state_dict(checkpoint['optimizer']) 159 | 160 | print("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch)) --------------------------------------------------------------------------------
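A note on the config plumbing in src/classification/main.py above: dict2namespace recursively converts the nested dict parsed from a YAML file into argparse.Namespace objects, so nested options are read as attributes (new_config.training.out_dir) rather than dict lookups. Below is a minimal, self-contained sketch of that round trip (assuming PyYAML is installed; the keys shown are illustrative, not a complete config from configs/):

import argparse
import yaml

def dict2namespace(config):
    # same recursive conversion as in src/classification/main.py
    namespace = argparse.Namespace()
    for key, value in config.items():
        setattr(namespace, key, dict2namespace(value) if isinstance(value, dict) else value)
    return namespace

raw = yaml.safe_load('''
training:
  out_dir: ./results
  exp_id: demo
data:
  perc: 0.25
''')
config = dict2namespace(raw)
assert config.training.out_dir == './results'
assert config.data.perc == 0.25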
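The coupling layer in src/flows/models/maf/layers.py is invertible by construction: forward computes u = mx + (1 - mask) * (x - t) * exp(-s), and inverse recovers x = mu + (1 - mask) * (u * exp(s) + t), where s and t depend only on the masked coordinates (which pass through unchanged), so the two maps compose to the identity and their log-determinants cancel. A minimal standalone sketch of that property, using a toy two-dimensional coupling with stand-in linear s/t networks rather than the repo's classes:

import torch
import torch.nn as nn

torch.manual_seed(0)
mask = torch.tensor([1., 0.])  # first coordinate passes through unchanged
s_net = nn.Linear(2, 2)        # toy scale network
t_net = nn.Linear(2, 2)        # toy translation network

def forward(x):
    mx = x * mask
    s, t = s_net(mx), t_net(mx)
    u = mx + (1 - mask) * (x - t) * torch.exp(-s)
    return u, (-(1 - mask) * s).sum(-1)    # log|det du/dx|

def inverse(u):
    mu = u * mask                          # equals mx, so s and t match forward's
    s, t = s_net(mu), t_net(mu)
    x = mu + (1 - mask) * (u * torch.exp(s) + t)
    return x, ((1 - mask) * s).sum(-1)     # log|det dx/du|

x = torch.randn(4, 2)
u, ld_f = forward(x)
x_rec, ld_i = inverse(u)
assert torch.allclose(x, x_rec, atol=1e-5)                      # exact inverse
assert torch.allclose(ld_f + ld_i, torch.zeros(4), atol=1e-5)   # log-dets cancel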
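The output-shape arithmetic in Glow.__init__ (src/flows/models/glow/glow.py) follows directly from the comments there: each of the n_levels levels squeezes (4x channels, 1/2x height/width) and then splits off half the channels, and one final squeeze precedes the last block of flow steps, giving out_channels = C * 4**(n_levels+1) / 2**n_levels and out_HW = H / 2**(n_levels+1). A quick check that step-by-step bookkeeping matches the closed form used in the code, for the default input_dims=(3, 32, 32):

def glow_output_dims(in_channels, H, n_levels):
    C = in_channels
    for _ in range(n_levels):
        C, H = 4 * C, H // 2   # Squeeze: 4x channels, 1/2x spatial
        C //= 2                # Split: half the channels are factored out
    C, H = 4 * C, H // 2       # final Squeeze before the last FlowSteps
    return C, H

for L in range(1, 5):
    C, H = glow_output_dims(3, 32, L)
    assert C == int(3 * 4 ** (L + 1) / 2 ** L)  # closed form from Glow.__init__
    assert H == int(32 / 2 ** (L + 1))
    print(f'n_levels={L}: output dims ({C}, {H}, {H})')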
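The preprocessing in src/datasets/omniglot.py (_data_transform plus logit_transform) dequantizes pixels by adding uniform noise over the 256 quantization bins, rescales to (0, 1), and applies a logit so the flow sees unbounded values; a sigmoid undoes the logit up to the small lam offset. A small standalone sketch with toy tensors (not the dataset class itself):

import torch

def logit_transform(image, lam=1e-6):
    # same transform as in src/datasets/omniglot.py
    image = lam + (1 - 2 * lam) * image
    return torch.log(image) - torch.log1p(-image)

x = torch.rand(8, 1, 28, 28)                               # toy images in [0, 1)
x_deq = ((x * 255).byte() + torch.rand(x.size())) / 256.   # dequantize to (0, 1)
z = logit_transform(x_deq)
x_back = torch.sigmoid(z)                                  # inverts the logit (up to lam)
assert torch.allclose(x_back, x_deq, atol=1e-4)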