├── model
├── __init__.py
├── nn.py
├── gumbel_masks.py
└── ilcm_vae.py
├── scripts
├── __init__.py
├── produce_all_logs.py
├── compute_udr_npy.py
└── compute_udr_npy_rand_graphs.py
├── baseline_models
├── __init__.py
├── beta-tcvae
│ ├── __init__.py
│ ├── metric_helpers
│ │ ├── mi_metric.py
│ │ └── loader.py
│ ├── LICENSE
│ ├── README.md
│ ├── elbo_decomposition.py
│ └── disentanglement_metrics.py
├── icebeem
│ ├── __init__.py
│ ├── metrics
│ │ └── __init__.py
│ ├── models
│ │ ├── nflib
│ │ │ └── __init__.py
│ │ ├── tcl
│ │ │ ├── __init__.py
│ │ │ ├── tcl_preprocessing.py
│ │ │ ├── tcl_eval.py
│ │ │ └── tcl_wrapper_gpu.py
│ │ ├── __init__.py
│ │ ├── ivae
│ │ │ ├── __init__.py
│ │ │ └── ivae_wrapper.py
│ │ ├── icebeem_wrapper.py
│ │ ├── ebm.py
│ │ └── nets.py
│ ├── data
│ │ ├── __init__.py
│ │ └── utils.py
│ ├── losses
│ │ ├── __init__.py
│ │ └── dsm.py
│ ├── configs
│ │ ├── imca.yaml
│ │ ├── mnist_baseline.yaml
│ │ ├── cifar10.yaml
│ │ ├── fashionmnist_baseline.yaml
│ │ ├── cifar100.yaml
│ │ ├── cifar10_baseline.yaml
│ │ ├── mnist.yaml
│ │ ├── fashionmnist.yaml
│ │ └── cifar100_baseline.yaml
│ ├── LICENSE
│ └── train.py
└── slowvae_pcl
│ ├── scripts
│ ├── __init__.py
│ ├── evaluate_disentanglement.py
│ ├── model.py
│ └── data_analysis_utils.py
│ ├── mcc_metric
│ ├── __init__.py
│ └── metric.py
│ ├── data_generation
│ ├── coco
│ │ ├── __init__.py
│ │ └── mask.py
│ ├── gen_youtube_csv.py
│ └── gen_kitti_masks.py
│ ├── latent_factors.gif
│ ├── requirements.txt
│ ├── metric_configs
│ ├── dci.gin
│ ├── mcc.gin
│ ├── mig.gin
│ ├── sap_score.gin
│ ├── beta_vae_sklearn.gin
│ ├── modularity_explicitness.gin
│ └── factor_vae_metric.gin
│ ├── LICENSE
│ ├── README.md
│ └── train.py
├── universal_logger
├── __init__.py
└── logger.py
├── results
├── CLeaR2022
│ └── all_logs (8).npy
├── JMLR
│ ├── penalty
│ │ ├── all_logs_28jul2023_udr.npy
│ │ └── all_logs_12sep2023_rand_graphs_udr.npy
│ └── constraint
│ │ ├── graphVisualization
│ │ ├── TimeNonDiag
│ │ │ ├── fig.png
│ │ │ ├── fig.xcf
│ │ │ ├── gt_gz_0.png
│ │ │ ├── G^z_300000.png
│ │ │ ├── C_pattern_300000.png
│ │ │ ├── estimated_C_300000.png
│ │ │ ├── permuted_G_mask_300000.png
│ │ │ └── correlation_matrix_300000.png
│ │ ├── ActionNonDiag
│ │ │ ├── fig.png
│ │ │ ├── fig.xcf
│ │ │ ├── gt_ga_0.png
│ │ │ ├── G^a_300000.png
│ │ │ ├── C_pattern_300000.png
│ │ │ ├── estimated_C_300000.png
│ │ │ ├── correlation_matrix_300000.png
│ │ │ └── permuted_GC_mask_300000.png
│ │ ├── TimeBlockNonDiag
│ │ │ ├── fig.png
│ │ │ ├── fig.xcf
│ │ │ ├── gt_gz_0.png
│ │ │ ├── G^z_300000.png
│ │ │ ├── C_pattern_300000.png
│ │ │ ├── estimated_C_300000.png
│ │ │ ├── permuted_G_mask_300000.png
│ │ │ └── correlation_matrix_300000.png
│ │ └── ActionBlockNonDiag
│ │ │ ├── fig.png
│ │ │ ├── fig.xcf
│ │ │ ├── gt_ga_0.png
│ │ │ ├── G^a_300000.png
│ │ │ ├── C_pattern_300000.png
│ │ │ ├── estimated_C_300000.png
│ │ │ ├── correlation_matrix_300000.png
│ │ │ └── permuted_GC_mask_300000.png
│ │ ├── all_logs_16sep2023_JMLR_const_ws_udr.npy
│ │ ├── all_logs_23sep2023_JMLR_const_rand_g.npy
│ │ ├── all_logs_18sep2023_JMLR_const_clear_udr.npy
│ │ └── all_logs_19sep2023_JMLR_const_violation.npy
└── summary.txt
├── plot.py
├── requirements.txt
├── .gitignore
├── README.md
└── optimization.py
/model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/baseline_models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/universal_logger/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/baseline_models/beta-tcvae/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/baseline_models/icebeem/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/baseline_models/slowvae_pcl/scripts/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/baseline_models/slowvae_pcl/mcc_metric/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/baseline_models/slowvae_pcl/data_generation/coco/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/baseline_models/icebeem/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from .mcc import mean_corr_coef
2 | __all__ = ["mcc"]
--------------------------------------------------------------------------------
/baseline_models/icebeem/models/nflib/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = ["conditional_flows", "flows", "spline_flows"]
--------------------------------------------------------------------------------
/baseline_models/icebeem/models/tcl/__init__.py:
--------------------------------------------------------------------------------
1 | from .tcl_wrapper_gpu import train
2 | __all__ = ["tcl_wrapper_gpu", "tcl_core"]
--------------------------------------------------------------------------------
/baseline_models/icebeem/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .icebeem_wrapper import ICEBEEM_wrapper
2 | __all__ = ["icebeem_wrapper", "nets"]
3 |
--------------------------------------------------------------------------------
/baseline_models/icebeem/models/ivae/__init__.py:
--------------------------------------------------------------------------------
1 | from .ivae_wrapper import IVAE_wrapper
2 | __all__ = ["ivae_wrapper", "ivae_core"]
--------------------------------------------------------------------------------
/baseline_models/icebeem/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .imca import generate_synthetic_data
2 | from .utils import to_one_hot
3 |
4 | __all__ = ["imca", "utils"]
--------------------------------------------------------------------------------
/results/CLeaR2022/all_logs (8).npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/CLeaR2022/all_logs (8).npy
--------------------------------------------------------------------------------
/baseline_models/icebeem/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from .dsm import conditional_dsm, dsm_score_estimation, dsm, cdsm
2 | from .fce import ConditionalFCE
3 | __all__ = ["dsm", "fce"]
--------------------------------------------------------------------------------
/baseline_models/slowvae_pcl/latent_factors.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/baseline_models/slowvae_pcl/latent_factors.gif
--------------------------------------------------------------------------------
/results/JMLR/penalty/all_logs_28jul2023_udr.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/penalty/all_logs_28jul2023_udr.npy
--------------------------------------------------------------------------------
/results/JMLR/penalty/all_logs_12sep2023_rand_graphs_udr.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/penalty/all_logs_12sep2023_rand_graphs_udr.npy
--------------------------------------------------------------------------------
/results/JMLR/constraint/graphVisualization/TimeNonDiag/fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/TimeNonDiag/fig.png
--------------------------------------------------------------------------------
/results/JMLR/constraint/graphVisualization/TimeNonDiag/fig.xcf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/TimeNonDiag/fig.xcf
--------------------------------------------------------------------------------
/results/JMLR/constraint/all_logs_16sep2023_JMLR_const_ws_udr.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/all_logs_16sep2023_JMLR_const_ws_udr.npy
--------------------------------------------------------------------------------
/results/JMLR/constraint/all_logs_23sep2023_JMLR_const_rand_g.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/all_logs_23sep2023_JMLR_const_rand_g.npy
--------------------------------------------------------------------------------
/results/JMLR/constraint/graphVisualization/ActionNonDiag/fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/ActionNonDiag/fig.png
--------------------------------------------------------------------------------
/results/JMLR/constraint/graphVisualization/ActionNonDiag/fig.xcf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/ActionNonDiag/fig.xcf
--------------------------------------------------------------------------------
/results/JMLR/constraint/graphVisualization/TimeNonDiag/gt_gz_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/TimeNonDiag/gt_gz_0.png
--------------------------------------------------------------------------------
/results/JMLR/constraint/all_logs_18sep2023_JMLR_const_clear_udr.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/all_logs_18sep2023_JMLR_const_clear_udr.npy
--------------------------------------------------------------------------------
/results/JMLR/constraint/all_logs_19sep2023_JMLR_const_violation.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/all_logs_19sep2023_JMLR_const_violation.npy
--------------------------------------------------------------------------------
/results/JMLR/constraint/graphVisualization/ActionNonDiag/gt_ga_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/ActionNonDiag/gt_ga_0.png
--------------------------------------------------------------------------------
/results/JMLR/constraint/graphVisualization/TimeBlockNonDiag/fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/TimeBlockNonDiag/fig.png
--------------------------------------------------------------------------------
/results/JMLR/constraint/graphVisualization/TimeBlockNonDiag/fig.xcf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/TimeBlockNonDiag/fig.xcf
--------------------------------------------------------------------------------
/results/JMLR/constraint/graphVisualization/ActionBlockNonDiag/fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/ActionBlockNonDiag/fig.png
--------------------------------------------------------------------------------
/results/JMLR/constraint/graphVisualization/ActionBlockNonDiag/fig.xcf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/ActionBlockNonDiag/fig.xcf
--------------------------------------------------------------------------------
/results/JMLR/constraint/graphVisualization/ActionNonDiag/G^a_300000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/ActionNonDiag/G^a_300000.png
--------------------------------------------------------------------------------
/results/JMLR/constraint/graphVisualization/TimeBlockNonDiag/gt_gz_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/TimeBlockNonDiag/gt_gz_0.png
--------------------------------------------------------------------------------
/results/JMLR/constraint/graphVisualization/TimeNonDiag/G^z_300000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/TimeNonDiag/G^z_300000.png
--------------------------------------------------------------------------------
/results/JMLR/constraint/graphVisualization/ActionBlockNonDiag/gt_ga_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/ActionBlockNonDiag/gt_ga_0.png
--------------------------------------------------------------------------------
/results/JMLR/constraint/graphVisualization/ActionBlockNonDiag/G^a_300000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/ActionBlockNonDiag/G^a_300000.png
--------------------------------------------------------------------------------
/results/JMLR/constraint/graphVisualization/TimeBlockNonDiag/G^z_300000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/TimeBlockNonDiag/G^z_300000.png
--------------------------------------------------------------------------------
/results/JMLR/constraint/graphVisualization/TimeNonDiag/C_pattern_300000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/TimeNonDiag/C_pattern_300000.png
--------------------------------------------------------------------------------
/results/JMLR/constraint/graphVisualization/ActionNonDiag/C_pattern_300000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/ActionNonDiag/C_pattern_300000.png
--------------------------------------------------------------------------------
/results/JMLR/constraint/graphVisualization/TimeNonDiag/estimated_C_300000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/TimeNonDiag/estimated_C_300000.png
--------------------------------------------------------------------------------
/results/JMLR/constraint/graphVisualization/ActionNonDiag/estimated_C_300000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/ActionNonDiag/estimated_C_300000.png
--------------------------------------------------------------------------------
/results/JMLR/constraint/graphVisualization/TimeBlockNonDiag/C_pattern_300000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/TimeBlockNonDiag/C_pattern_300000.png
--------------------------------------------------------------------------------
/results/JMLR/constraint/graphVisualization/TimeNonDiag/permuted_G_mask_300000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/TimeNonDiag/permuted_G_mask_300000.png
--------------------------------------------------------------------------------
/results/JMLR/constraint/graphVisualization/ActionBlockNonDiag/C_pattern_300000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/ActionBlockNonDiag/C_pattern_300000.png
--------------------------------------------------------------------------------
/results/JMLR/constraint/graphVisualization/TimeBlockNonDiag/estimated_C_300000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/TimeBlockNonDiag/estimated_C_300000.png
--------------------------------------------------------------------------------
/results/JMLR/constraint/graphVisualization/ActionBlockNonDiag/estimated_C_300000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/ActionBlockNonDiag/estimated_C_300000.png
--------------------------------------------------------------------------------
/results/JMLR/constraint/graphVisualization/ActionNonDiag/correlation_matrix_300000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/ActionNonDiag/correlation_matrix_300000.png
--------------------------------------------------------------------------------
/results/JMLR/constraint/graphVisualization/ActionNonDiag/permuted_GC_mask_300000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/ActionNonDiag/permuted_GC_mask_300000.png
--------------------------------------------------------------------------------
/results/JMLR/constraint/graphVisualization/TimeBlockNonDiag/permuted_G_mask_300000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/TimeBlockNonDiag/permuted_G_mask_300000.png
--------------------------------------------------------------------------------
/results/JMLR/constraint/graphVisualization/TimeNonDiag/correlation_matrix_300000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/TimeNonDiag/correlation_matrix_300000.png
--------------------------------------------------------------------------------
/results/JMLR/constraint/graphVisualization/ActionBlockNonDiag/correlation_matrix_300000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/ActionBlockNonDiag/correlation_matrix_300000.png
--------------------------------------------------------------------------------
/results/JMLR/constraint/graphVisualization/ActionBlockNonDiag/permuted_GC_mask_300000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/ActionBlockNonDiag/permuted_GC_mask_300000.png
--------------------------------------------------------------------------------
/results/JMLR/constraint/graphVisualization/TimeBlockNonDiag/correlation_matrix_300000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/TimeBlockNonDiag/correlation_matrix_300000.png
--------------------------------------------------------------------------------
/baseline_models/slowvae_pcl/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.19.0
2 | torch==1.4.0
3 | torchvision==0.5.0
4 | imageio==2.8.0
5 | matplotlib==3.2.2
6 | scipy==1.1.0
7 | scikit-learn==0.23.1
8 | spriteworld==1.0.2
9 | pandas==1.0.5
10 | disentanglement-lib==1.4
11 | tensorflow==1.14
12 | tqdm==4.31.1
--------------------------------------------------------------------------------
/baseline_models/icebeem/configs/imca.yaml:
--------------------------------------------------------------------------------
1 | data_dim: 5
2 | n_segments: 20
3 | n_layers: [ 2,4 ] # [2,4]
4 | n_obs_per_seg: [100, 200, 500, 1000, 2000]
5 | data_seed: 1
6 |
7 | ivae:
8 | max_iter: 70000
9 | lr: 0.001
10 | cuda: False
11 |
12 | icebeem:
13 | lr_flow: 0.00001
14 | lr_ebm: 0.0003
15 | n_layers_flow: 10 # make 5 for L=2 IMCA best perf
16 | ebm_hidden_size: 32 # make 16 for L=2 IMCA best perf
17 |
--------------------------------------------------------------------------------
/baseline_models/icebeem/configs/mnist_baseline.yaml:
--------------------------------------------------------------------------------
1 | # config file for baseline (of no transfer) example on MNIST
2 |
3 | training:
4 | batch_size: 63
5 | n_epochs: 5
6 | n_iters: 200001
7 | ngpu: 1
8 | snapshot_freq: 50
9 | algo: 'dsm'
10 | anneal_power: 2.0
11 |
12 | data:
13 | dataset: "MNIST_transferBaseline"
14 | image_size: 28
15 | channels: 1
16 | logit_transform: false
17 | random_flip: false
18 | store_loss: true
19 |
20 | model:
21 | sigma_begin: 1
22 | sigma_end: 0.01
23 | num_classes: 10
24 | ngf: 64
25 | final_layer: true
26 | feature_size: 90
27 | augment: false
28 | positive: false
29 | architecture: 'ConvMLP'
30 |
31 | optim:
32 | weight_decay: 0.000
33 | optimizer: "Adam"
34 | lr: 0.001
35 | beta1: 0.9
36 | amsgrad: false
37 |
38 | n_labels: 8
39 |
--------------------------------------------------------------------------------
/baseline_models/icebeem/configs/cifar10.yaml:
--------------------------------------------------------------------------------
1 | # config file for transfer learning example on CIFAR10
2 |
3 | training:
4 | batch_size: 63
5 | n_epochs: 500000
6 | n_iters: 5001
7 | ngpu: 1
8 | snapshot_freq: 250
9 | algo: 'dsm'
10 | anneal_power: 2.0
11 |
12 | data:
13 | dataset: "CIFAR10"
14 | image_size: 32
15 | channels: 3
16 | logit_transform: false
17 | random_flip: true
18 | random_state: 0
19 | split_size: .33
20 |
21 | model:
22 | sigma_begin: 1
23 | sigma_end: 0.01
24 | num_classes: 10
25 | batch_norm: false
26 | ngf: 64
27 | final_layer: true
28 | feature_size: 200
29 | augment: false
30 | positive: false
31 | architecture: 'ConvMLP'
32 |
33 | optim:
34 | weight_decay: 0.000
35 | optimizer: "Adam"
36 | lr: 0.001
37 | beta1: 0.9
38 | amsgrad: false
39 |
40 | n_labels: 8
41 |
--------------------------------------------------------------------------------
/baseline_models/icebeem/configs/fashionmnist_baseline.yaml:
--------------------------------------------------------------------------------
1 | # config file for baseline (of no transfer) example on FashionMNIST
2 |
3 | training:
4 | batch_size: 63
5 | n_epochs: 15
6 | n_iters: 200001
7 | ngpu: 1
8 | snapshot_freq: 50
9 | algo: 'dsm'
10 | anneal_power: 2.0
11 |
12 | data:
13 | dataset: "FashionMNIST_transferBaseline"
14 | image_size: 28
15 | channels: 1
16 | logit_transform: false
17 | random_flip: false
18 |
19 | model:
20 | sigma_begin: 1
21 | sigma_end: 0.01
22 | num_classes: 10
23 | batch_norm: false
24 | ngf: 64
25 | final_layer: true
26 | feature_size: 200
27 | augment: false
28 | positive: false
29 | architecture: 'ConvMLP'
30 |
31 | optim:
32 | weight_decay: 0.000
33 | optimizer: "Adam"
34 | lr: 0.001
35 | beta1: 0.9
36 | amsgrad: false
37 |
38 | n_labels: 8
39 |
--------------------------------------------------------------------------------
/baseline_models/icebeem/configs/cifar100.yaml:
--------------------------------------------------------------------------------
1 | # config file for transfer learning example on CIFAR100
2 |
3 | training:
4 | batch_size: 63
5 | n_epochs: 500000
6 | n_iters: 5001
7 | ngpu: 1
8 | snapshot_freq: 250
9 | algo: 'dsm'
10 | anneal_power: 2.0
11 |
12 | data:
13 | dataset: "CIFAR100"
14 | image_size: 32
15 | channels: 3
16 | logit_transform: false
17 | random_flip: true
18 | random_state: 0
19 | split_size: .33
20 |
21 | model:
22 | sigma_begin: 1
23 | sigma_end: 0.01
24 | num_classes: 10
25 | batch_norm: false
26 | ngf: 64
27 | final_layer: true
28 | feature_size: 200
29 | augment: false
30 | positive: false
31 | architecture: 'ConvMLP'
32 |
33 | optim:
34 | weight_decay: 0.000
35 | optimizer: "Adam"
36 | lr: 0.001
37 | beta1: 0.9
38 | amsgrad: false
39 |
40 | n_labels: 85
41 |
--------------------------------------------------------------------------------
/baseline_models/icebeem/configs/cifar10_baseline.yaml:
--------------------------------------------------------------------------------
1 | # config file for baseline (of no transfer) example on CIFAR10
2 |
3 | training:
4 | batch_size: 63
5 | n_epochs: 15
6 | n_iters: 200001
7 | ngpu: 1
8 | snapshot_freq: 50
9 | algo: 'dsm'
10 | anneal_power: 2.0
11 |
12 | data:
13 | dataset: "CIFAR10_transferBaseline"
14 | image_size: 32
15 | channels: 3
16 | logit_transform: false
17 | random_flip: true
18 | store_loss: true
19 |
20 | model:
21 | sigma_begin: 1
22 | sigma_end: 0.01
23 | num_classes: 10
24 | batch_norm: false
25 | ngf: 64
26 | final_layer: true
27 | feature_size: 200
28 | augment: false
29 | positive: false
30 | architecture: 'ConvMLP'
31 |
32 | optim:
33 | weight_decay: 0.000
34 | optimizer: "Adam"
35 | lr: 0.001
36 | beta1: 0.9
37 | amsgrad: false
38 |
39 | n_labels: 8
40 |
--------------------------------------------------------------------------------
/baseline_models/icebeem/configs/mnist.yaml:
--------------------------------------------------------------------------------
1 | # config file for transfer learning example on MNIST
2 |
3 | training:
4 | batch_size: 63
5 | n_epochs: 500000
6 | n_iters: 5001
7 | ngpu: 1
8 | snapshot_freq: 250
9 | algo: 'dsm'
10 | anneal_power: 2.0
11 |
12 | data:
13 | dataset: "MNIST"
14 | image_size: 28
15 | channels: 1
16 | logit_transform: false
17 | random_flip: false
18 | random_state: 0
19 | split_size: .15
20 |
21 |
22 | model:
23 | sigma_begin: 1
24 | sigma_end: 0.01
25 | num_classes: 10
26 | batch_norm: false
27 | ngf: 64
28 | final_layer: true
29 | feature_size: 90
30 | augment: false
31 | positive: false
32 | architecture: 'ConvMLP'
33 |
34 | optim:
35 | weight_decay: 0.000
36 | optimizer: "Adam"
37 | lr: 0.001
38 | beta1: 0.9
39 | amsgrad: false
40 |
41 | n_labels: 8
42 |
--------------------------------------------------------------------------------
/results/summary.txt:
--------------------------------------------------------------------------------
1 | CLeaR 2022:
2 | CLeaR2022/all_logs (8).npy
3 | to generate figures: https://colab.research.google.com/drive/1-gUajJMcGEEcpQA9nm9XcYB4ZLpR-wVc#scrollTo=thick-range
4 |
5 | JMLR (penalty):
6 | JMLR/penalty/all_logs_28jul2023_udr.npy (workshop data)
7 | JMLR/penalty/all_logs_12sep2023_rand_graphs_udr.npy
8 | to generate figures: https://colab.research.google.com/drive/1godG3MbAlLhB-FK95HJCHVQu2IjE8mM7#scrollTo=V0oKAFtzMXmw
9 |
10 | JMLR (constraint):
11 | JMLR/constraint/all_logs_18sep2023_JMLR_const_clear_udr.npy
12 | JMLR/constraint/all_logs_16sep2023_JMLR_const_ws_udr.npy
13 | JMLR/constraint/all_logs_19sep2023_JMLR_const_violation.npy
14 | JMLR/constraint/all_logs_23sep2023_JMLR_const_rand_g.npy
15 | to generate figures: https://colab.research.google.com/drive/1lftjqLUeoIQrNlWyKRrtPyXRXwyQz3Y5
16 |
17 |
--------------------------------------------------------------------------------
/baseline_models/slowvae_pcl/metric_configs/dci.gin:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The DisentanglementLib Authors. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | evaluation.evaluation_fn = @dci
17 | evaluation.random_seed = 0
18 | dci.num_train=10000
19 | dci.num_test=5000
20 |
--------------------------------------------------------------------------------
/baseline_models/icebeem/configs/fashionmnist.yaml:
--------------------------------------------------------------------------------
1 | # config file for transfer learning example on FashionMNIST
2 |
3 | training:
4 | batch_size: 63
5 | n_epochs: 500000
6 | n_iters: 5001
7 | ngpu: 1
8 | snapshot_freq: 250
9 | algo: 'dsm'
10 | anneal_power: 2.0
11 |
12 | data:
13 | dataset: "FashionMNIST"
14 | image_size: 28
15 | channels: 1
16 | logit_transform: false
17 | random_flip: false
18 | random_state: 0
19 | split_size : .15
20 |
21 | model:
22 | sigma_begin: 1
23 | sigma_end: 0.01
24 | num_classes: 10
25 | batch_norm: false
26 | ngf: 64
27 | final_layer: true
28 | feature_size: 200
29 | augment: false
30 | positive: false
31 | architecture: 'ConvMLP'
32 |
33 | optim:
34 | weight_decay: 0.000
35 | optimizer: "Adam"
36 | lr: 0.001
37 | beta1: 0.9
38 | amsgrad: false
39 |
40 | n_labels: 8
41 |
--------------------------------------------------------------------------------
/baseline_models/icebeem/configs/cifar100_baseline.yaml:
--------------------------------------------------------------------------------
1 | # config file for transfer learning example on CIFAR100
2 |
3 | training:
4 | batch_size: 63
5 | n_epochs: 15
6 | n_iters: 200001
7 | ngpu: 1
8 | snapshot_freq: 50
9 | algo: 'dsm'
10 | anneal_power: 2.0
11 |
12 | data:
13 | dataset: "CIFAR100_transferBaseline"
14 | image_size: 32
15 | channels: 3
16 | logit_transform: false
17 | random_flip: true
18 | random_state: 0
19 | split_size: .33
20 |
21 | model:
22 | sigma_begin: 1
23 | sigma_end: 0.01
24 | num_classes: 10
25 | batch_norm: false
26 | ngf: 64
27 | final_layer: true
28 | feature_size: 200
29 | augment: false
30 | positive: false
31 | architecture: 'ConvMLP'
32 |
33 | optim:
34 | weight_decay: 0.000
35 | optimizer: "Adam"
36 | lr: 0.001
37 | beta1: 0.9
38 | amsgrad: false
39 |
40 | n_labels: 85
41 |
--------------------------------------------------------------------------------
/baseline_models/slowvae_pcl/metric_configs/mcc.gin:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The DisentanglementLib Authors. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | evaluation.evaluation_fn = @mcc
17 | evaluation.random_seed = 0
18 | mcc.num_train=10000
19 | mcc.correlation_fn = "Spearman"
20 |
--------------------------------------------------------------------------------
/baseline_models/slowvae_pcl/metric_configs/mig.gin:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The DisentanglementLib Authors. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | evaluation.evaluation_fn = @mig
17 | evaluation.random_seed = 0
18 | mig.num_train=100000
19 | discretizer.discretizer_fn = @histogram_discretizer
20 | discretizer.num_bins = 20
21 |
--------------------------------------------------------------------------------
/baseline_models/slowvae_pcl/metric_configs/sap_score.gin:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The DisentanglementLib Authors. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | evaluation.evaluation_fn = @sap_score
17 | evaluation.random_seed = 0
18 | sap_score.num_train=10000
19 | sap_score.num_test=5000
20 | sap_score.continuous_factors=False
21 |
--------------------------------------------------------------------------------
/baseline_models/slowvae_pcl/metric_configs/beta_vae_sklearn.gin:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The DisentanglementLib Authors. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | evaluation.evaluation_fn = @beta_vae_sklearn
17 | evaluation.random_seed = 0
18 | beta_vae_sklearn.batch_size=64
19 | beta_vae_sklearn.num_train=10000
20 | beta_vae_sklearn.num_eval=5000
21 |
--------------------------------------------------------------------------------
/baseline_models/slowvae_pcl/metric_configs/modularity_explicitness.gin:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The DisentanglementLib Authors. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | evaluation.evaluation_fn = @modularity_explicitness
17 | evaluation.random_seed = 0
18 | modularity_explicitness.num_train=100000
19 | modularity_explicitness.num_test=5000
20 | discretizer.discretizer_fn = @histogram_discretizer
21 | discretizer.num_bins = 20
22 |
--------------------------------------------------------------------------------
/baseline_models/slowvae_pcl/metric_configs/factor_vae_metric.gin:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The DisentanglementLib Authors. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | evaluation.evaluation_fn = @factor_vae_score
17 | evaluation.random_seed = 0
18 | factor_vae_score.num_variance_estimate=10000
19 | factor_vae_score.num_train=10000
20 | factor_vae_score.num_eval=5000
21 | factor_vae_score.batch_size=64
22 | prune_dims.threshold = 0.05
23 |
24 |
--------------------------------------------------------------------------------
/baseline_models/beta-tcvae/metric_helpers/mi_metric.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
metric_name = 'MIG'


def MIG(mi_normed):
    """Return the mean gap between the first two columns of ``mi_normed``.

    Args:
        mi_normed: 2-D tensor indexed as ``[:, 0]`` and ``[:, 1]``. The
            callers in this module pass rows sorted in descending order, so
            the gap is "largest minus second largest" per row.

    Returns:
        A 0-dim tensor holding the mean of the per-row gaps.
    """
    return torch.mean(mi_normed[:, 0] - mi_normed[:, 1])


def _compute_metric(marginal_entropies, cond_entropies, factor_entropies):
    """Shared implementation behind the dataset-specific metric functions.

    Computes ``marginal_entropies[None] - cond_entropies``, sorts each row in
    descending order, clamps negatives to zero, divides every row by the log
    of the matching entry of ``factor_entropies`` and applies the metric
    function named by the module-level ``metric_name``.
    """
    mutual_infos = marginal_entropies[None] - cond_entropies
    mutual_infos = torch.sort(mutual_infos, dim=1, descending=True)[0].clamp(min=0)
    # Build the normaliser on the same device as the data so that the
    # division also works for tensors that do not live on the CPU.
    normaliser = torch.tensor(
        factor_entropies, dtype=torch.float32, device=mutual_infos.device).log()
    mi_normed = mutual_infos / normaliser[:, None]
    # Look the metric up by name at call time (honours a reassigned
    # ``metric_name``) without evaluating the string as code.
    metric_fn = globals()[metric_name]
    return metric_fn(mi_normed)


def compute_metric_shapes(marginal_entropies, cond_entropies):
    """Compute the metric using the four hard-coded per-factor normalisers.

    Args:
        marginal_entropies: 1-D tensor; broadcast against ``cond_entropies``
            through a leading ``None`` axis.
        cond_entropies: 2-D tensor whose first dimension must have 4 rows to
            match the normalisers (presumably one row per ground-truth factor
            of the shapes dataset -- confirm against the caller).

    Returns:
        The value returned by the metric function named ``metric_name``.
    """
    # The values are passed through log(), so despite the name they look like
    # factor cardinalities rather than entropies -- TODO confirm.
    factor_entropies = [6, 40, 32, 32]
    return _compute_metric(marginal_entropies, cond_entropies, factor_entropies)


def compute_metric_faces(marginal_entropies, cond_entropies):
    """Compute the metric using the three hard-coded per-factor normalisers.

    Same contract as ``compute_metric_shapes`` except that ``cond_entropies``
    must have 3 rows.
    """
    factor_entropies = [21, 11, 11]
    return _compute_metric(marginal_entropies, cond_entropies, factor_entropies)
26 |
27 |
--------------------------------------------------------------------------------
/baseline_models/icebeem/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Ilyes Khemakhem
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/baseline_models/slowvae_pcl/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Bethge Lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/baseline_models/beta-tcvae/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Ricky Tian Qi Chen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/baseline_models/beta-tcvae/README.md:
--------------------------------------------------------------------------------
1 | # beta-TCVAE
2 |
3 | This repository contains cleaned-up code for reproducing the quantitative experiments in Isolating Sources of Disentanglement in Variational Autoencoders \[[arxiv](https://arxiv.org/abs/1802.04942)\].
4 |
5 | ## Usage
6 |
7 | To train a model:
8 |
9 | ```
10 | python vae_quant.py --dataset [shapes/faces] --beta 6 --tcvae
11 | ```
12 | Specify `--conv` to use the convolutional VAE. We used an MLP for dSprites and a convolutional network for 3D faces. To see all options, use the `-h` flag.
13 |
14 | The main computational difference between beta-VAE and beta-TCVAE is summarized in [these lines](vae_quant.py#L220-L228).
15 |
16 | To evaluate the MIG of a model:
17 | ```
18 | python disentanglement_metrics.py --checkpt [checkpt]
19 | ```
20 | To see all options, use the `-h` flag.
21 |
22 | ## Datasets
23 |
24 | ### dSprites
25 | Download the npz file from [here](https://github.com/deepmind/dsprites-dataset) and place it into `data/`.
26 |
27 | ### 3D faces
28 | We cannot publicly distribute this due to the [license](https://faces.dmi.unibas.ch/bfm/main.php?nav=1-2&id=downloads). Please contact me for the data.
29 |
30 | ## Contact
31 | Email rtqichen@cs.toronto.edu if you have questions about the code/data.
32 |
33 | ## Bibtex
34 | ```
35 | @inproceedings{chen2018isolating,
36 | title={Isolating Sources of Disentanglement in Variational Autoencoders},
37 | author={Chen, Ricky T. Q. and Li, Xuechen and Grosse, Roger and Duvenaud, David},
38 | booktitle = {Advances in Neural Information Processing Systems},
39 | year={2018}
40 | }
41 | ```
42 |
--------------------------------------------------------------------------------
/baseline_models/beta-tcvae/metric_helpers/loader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import lib.dist as dist
3 | import lib.flows as flows
4 | import vae_quant
5 |
6 |
def load_model_and_dataset(checkpt_filename):
    """Rebuild a VAE and its dataset from a saved checkpoint.

    The checkpoint is expected to be a mapping with an ``'args'`` entry (an
    attribute-style namespace) and a ``'state_dict'`` entry.

    Args:
        checkpt_filename: path handed to ``torch.load``.

    Returns:
        A tuple ``(vae, dataset, args)`` where ``dataset`` is the ``.dataset``
        attribute of the loader built by ``vae_quant.setup_data_loaders``.

    Raises:
        ValueError: if ``args.dist`` is set to something other than
            ``'normal'``, ``'laplace'`` or ``'flow'``.
    """
    print('Loading model and dataset.')
    # The map_location callable returns each storage unchanged, i.e. it does
    # not move it to the device recorded in the checkpoint.
    checkpt = torch.load(checkpt_filename, map_location=lambda storage, loc: storage)
    args = checkpt['args']
    state_dict = checkpt['state_dict']

    # backwards compatibility: older checkpoints have no ``conv`` flag
    if not hasattr(args, 'conv'):
        args.conv = False

    # Checkpoints without a ``dist`` attribute are treated as 'normal'.
    dist_name = getattr(args, 'dist', 'normal')
    if dist_name == 'normal':
        prior_dist = dist.Normal()
        q_dist = dist.Normal()
    elif dist_name == 'laplace':
        prior_dist = dist.Laplace()
        q_dist = dist.Laplace()
    elif dist_name == 'flow':
        prior_dist = flows.FactorialNormalizingFlow(dim=args.latent_dim, nsteps=32)
        q_dist = dist.Normal()
    else:
        # Previously this fell through and failed later with an
        # UnboundLocalError on ``prior_dist``.
        raise ValueError(
            "unsupported distribution {!r}; expected 'normal', 'laplace' or 'flow'".format(dist_name))

    # model
    if hasattr(args, 'ncon'):
        # InfoGAN
        # NOTE(review): ``infogan`` is not among this module's visible imports,
        # so this branch raises NameError unless the name is provided some
        # other way -- confirm and add the import if the branch is still used.
        model = infogan.Model(
            args.latent_dim, n_con=args.ncon, n_cat=args.ncat, cat_dim=args.cat_dim, use_cuda=True, conv=args.conv)
        model.load_state_dict(state_dict, strict=False)
        vae = vae_quant.VAE(
            z_dim=args.ncon, use_cuda=True, prior_dist=prior_dist, q_dist=q_dist, conv=args.conv)
        # Reuse the InfoGAN networks inside the VAE wrapper.
        vae.encoder = model.encoder
        vae.decoder = model.decoder
    else:
        vae = vae_quant.VAE(
            z_dim=args.latent_dim, use_cuda=True, prior_dist=prior_dist, q_dist=q_dist, conv=args.conv)
        vae.load_state_dict(state_dict, strict=False)

    # dataset loader
    loader = vae_quant.setup_data_loaders(args)
    return vae, loader.dataset, args
45 |
--------------------------------------------------------------------------------
/plot.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 |
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import torch
7 |
8 |
def plot_matrix(matrix, title="", row_label="", col_label="", col_to_mark=[], row_to_mark=[], kl_values=None, row_names=[], col_names=[], vmin=0., vmax=1.):
    """Render ``matrix`` as an annotated heat map and return the figure.

    Every cell is labelled with its value to two decimals. One ``*`` is
    appended when the cell's row index is in ``row_to_mark`` and another when
    its column index is in ``col_to_mark``. When ``kl_values`` is given, a
    second subplot below the matrix shows those values as a single row.

    Args:
        matrix: 2-D array-like supporting ``.shape`` and ``matrix[i, j]``.
        title, row_label, col_label: title and axis labels of the main plot.
        col_to_mark, row_to_mark: containers of indices to flag (read only).
        kl_values: optional array supporting ``.reshape(1, -1)``.
        row_names, col_names: lists of tick labels; an empty string is
            prepended to each before it is handed to matplotlib.
        vmin, vmax: colour-scale limits of the main plot.

    Returns:
        The created matplotlib figure.
    """
    fig = plt.figure()
    n_panels = 1 if kl_values is None else 2
    ax = fig.add_subplot(n_panels, 1, 1)
    ax_kl = fig.add_subplot(2, 1, 2) if n_panels == 2 else None

    ax.matshow(matrix, vmin=vmin, vmax=vmax)
    ax.set_title(title)
    ax.set_xlabel(col_label)
    ax.set_ylabel(row_label)
    ax.set_xticklabels([''] + col_names)
    ax.set_yticklabels([''] + row_names)

    # Annotate every cell, row by row.
    n_rows, n_cols = matrix.shape[0], matrix.shape[1]
    for i in range(n_rows):
        row_flag = "*" if i in row_to_mark else ""
        for j in range(n_cols):
            col_flag = "*" if j in col_to_mark else ""
            label = "{:.2f}{}{}".format(matrix[i, j], row_flag, col_flag)
            ax.text(j, i, label, ha="center", va="center", color="w")

    if ax_kl is not None:
        ax_kl.matshow(kl_values.reshape(1, -1), vmin=0)
        ax_kl.set_xlabel("KL values")

    return fig
38 |
def plot_weighted_adjacency_vs_steps(weighted_adjacency, gt_adjacency, iterations=None):
    """Plot the trajectory of every edge weight and return the figure.

    One line is drawn per ``(i, j)`` entry: green when
    ``gt_adjacency[i, j] == 1`` (edge present in the ground truth), red
    otherwise.

    Args:
        weighted_adjacency: 3-D array indexed as ``[step, i, j]``.
        gt_adjacency: 2-D array indexed as ``[i, j]``.
        iterations: x-axis values, one per step; defaults to
            ``range(number_of_steps)``.

    Returns:
        The created matplotlib figure.
    """
    n_steps = weighted_adjacency.shape[0]
    num_row = weighted_adjacency.shape[1]
    num_col = weighted_adjacency.shape[2]
    if iterations is None:
        iterations = range(n_steps)
    assert n_steps == len(iterations)

    fig, ax1 = plt.subplots()

    # One curve per edge, coloured by agreement with the ground truth.
    for i in range(num_row):
        for j in range(num_col):
            color = 'g' if gt_adjacency[i, j] == 1 else 'r'
            ax1.plot(iterations, weighted_adjacency[:, i, j], color, linewidth=1)

    fig.tight_layout()
    return fig
60 |
61 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | argon2-cffi==20.1.0
2 | async-generator==1.10
3 | attrs==20.3.0
4 | backcall==0.2.0
5 | bleach==3.3.0
6 | cached-property==1.5.2
7 | certifi==2020.12.5
8 | cffi==1.14.4
9 | chardet==4.0.0
10 | comet-ml==3.3.2
11 | configobj==5.0.6
12 | cycler==0.10.0
13 | decorator==4.4.2
14 | defusedxml==0.6.0
15 | dulwich==0.20.18
16 | entrypoints==0.3
17 | everett==1.0.3
18 | h5py==3.2.1
19 | idna==2.10
20 | importlib-metadata==3.4.0
21 | ipdb==0.13.4
22 | ipykernel==5.4.3
23 | ipython==7.20.0
24 | ipython-genutils==0.2.0
25 | ipywidgets==7.6.3
26 | jedi==0.18.0
27 | Jinja2==2.11.3
28 | joblib==1.0.1
29 | jsonschema==3.2.0
30 | jupyter==1.0.0
31 | jupyter-client==6.1.11
32 | jupyter-console==6.2.0
33 | jupyter-core==4.7.1
34 | jupyter-http-over-ws==0.0.8
35 | jupyterlab-pygments==0.1.2
36 | jupyterlab-widgets==1.0.0
37 | kiwisolver==1.3.1
38 | MarkupSafe==1.1.1
39 | matplotlib==3.3.4
40 | mistune==0.8.4
41 | nbclient==0.5.2
42 | nbconvert==6.0.7
43 | nbformat==5.1.2
44 | nest-asyncio==1.5.1
45 | netifaces==0.10.9
46 | notebook==6.2.0
47 | numpy==1.20.1
48 | nvidia-ml-py3==7.352.0
49 | packaging==20.9
50 | pandas==1.2.2
51 | pandocfilters==1.4.3
52 | parso==0.8.1
53 | pexpect==4.8.0
54 | pickleshare==0.7.5
55 | Pillow==8.1.0
56 | prometheus-client==0.9.0
57 | prompt-toolkit==3.0.15
58 | ptyprocess==0.7.0
59 | pycparser==2.20
60 | Pygments==2.7.4
61 | pynvml==8.0.4
62 | pyparsing==2.4.7
63 | pyrsistent==0.17.3
64 | python-dateutil==2.8.1
65 | pytorch-ignite==0.4.3
66 | pytz==2021.1
67 | pyzmq==22.0.2
68 | qtconsole==5.0.2
69 | QtPy==1.9.0
70 | requests==2.25.1
71 | requests-toolbelt==0.9.1
72 | scikit-learn==0.24.1
73 | scipy==1.6.0
74 | seaborn==0.11.1
75 | Send2Trash==1.5.0
76 | six==1.15.0
77 | synbols==1.0.1
78 | terminado==0.9.2
79 | testpath==0.4.4
80 | threadpoolctl==2.1.0
81 | torch==1.7.1
82 | torchvision==0.8.2
83 | tornado==6.1
84 | tqdm==4.56.2
85 | traitlets==5.0.5
86 | typing-extensions==3.7.4.3
87 | urllib3==1.26.3
88 | wcwidth==0.2.5
89 | webencodings==0.5.1
90 | websocket-client==0.57.0
91 | widgetsnbextension==3.5.1
92 | wrapt==1.12.1
93 | wurlitzer==2.0.1
94 | zipp==3.4.0
95 | qj==0.2.0
96 | git+https://github.com/cooper-org/cooper.git
97 |
--------------------------------------------------------------------------------
/baseline_models/icebeem/data/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 |
5 |
def to_one_hot(x, m=None):
    """One-hot encode an array of labels, or a list of such arrays.

    Args:
        x: a NumPy array of non-negative labels, or a list of them. A single
            array is wrapped in a one-element list.
        m: number of classes (width of the encoding). When ``None`` it is
            the largest label found across all arrays, plus one.

    Returns:
        A list with one ``(xi.size, m)`` array per input array. The arrays
        use the dtype of the first input array.
    """
    if not isinstance(x, list):
        x = [x]
    if m is None:
        m = max(xi.max() + 1 for xi in x)
    n_classes = int(m)
    dtp = x[0].dtype
    xoh = []
    for xi in x:
        encoded = np.zeros((xi.size, n_classes), dtype=dtp)
        # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the
        # same dtype. Flattening keeps column-shaped label arrays from being
        # broadcast against the row index.
        encoded[np.arange(xi.size), xi.astype(int).ravel()] = 1
        xoh.append(encoded)
    return xoh
21 |
22 |
def one_hot_encode(labels, n_labels=10):
    """Return a float32 one-hot matrix for an array of integer labels.

    Every label must lie in ``0 .. n_labels - 1``; this is checked with an
    ``assert``. The result has shape ``(labels.size, n_labels)``.
    """
    smallest, largest = np.min(labels), np.max(labels)
    assert smallest >= 0 and largest < n_labels

    n_samples = labels.size
    encoded = np.zeros((n_samples, n_labels), dtype=np.float32)
    encoded[np.arange(n_samples), labels] = 1
    return encoded
34 |
35 |
def single_one_hot_encode(label, n_labels=10):
    """Return a float32 one-hot vector of length ``n_labels`` for one label.

    ``label`` must lie in ``0 .. n_labels - 1``; this is checked with an
    ``assert``.
    """
    assert 0 <= label < n_labels

    encoded = np.zeros(n_labels, dtype=np.float32)
    encoded[label] = 1
    return encoded
47 |
48 |
def single_one_hot_encode_rev(label, n_labels=10, start_label=0):
    """Return a float32 one-hot vector for a label counted from ``start_label``.

    ``label`` must lie in ``start_label .. n_labels - 1`` (checked with an
    ``assert``). The vector has length ``n_labels - start_label`` and the
    position set to 1 is ``label - start_label``.
    """
    assert start_label <= label < n_labels
    width = n_labels - start_label
    encoded = np.zeros(width, dtype=np.float32)
    encoded[label - start_label] = 1
    return encoded
57 |
58 |
def mnist_one_hot_transform(label):
    """One-hot encode a single label over 10 classes."""
    return single_one_hot_encode(label, n_labels=10)


def contrastive_one_hot_transform(label):
    """One-hot encode a single label over 2 classes."""
    return single_one_hot_encode(label, n_labels=2)
61 |
62 |
def make_dir(dir_name):
    """Ensure ``dir_name`` exists as a directory and return it with a trailing '/'.

    Args:
        dir_name: non-empty directory path; a ``'/'`` is appended when it is
            missing.

    Returns:
        The (possibly extended) path string.
    """
    if dir_name[-1] != '/':
        dir_name += '/'
    # exist_ok avoids the check-then-create race when several processes set
    # up the same output directory at once.
    os.makedirs(dir_name, exist_ok=True)
    return dir_name
69 |
70 |
def make_file(file_name):
    """Create an empty file at ``file_name`` unless the path already exists.

    An existing path (file or directory) is left untouched. Returns
    ``file_name`` unchanged.
    """
    already_present = os.path.exists(file_name)
    if not already_present:
        # Append mode creates the file without truncating anything.
        with open(file_name, 'a'):
            pass
    return file_name
75 |
--------------------------------------------------------------------------------
/baseline_models/icebeem/models/tcl/tcl_preprocessing.py:
--------------------------------------------------------------------------------
1 | """Preprocessing"""
2 |
3 | import numpy as np
4 |
5 |
6 | # ============================================================
7 | # ============================================================
def pca(x, num_comp=None, params=None, zerotolerance=1e-7):
    """Apply PCA whitening to data.

    Args:
        x: data. 2D ndarray [num_dim, num_data] (variables in rows, samples
            in columns).
        num_comp: number of components to keep; defaults to ``x.shape[0]``.
            Ignored when ``params`` is given.
        params: (option) dictionary of PCA parameters
            {'mean': ?, 'W': ?, 'A': ?}. If given, it is applied to the data
            instead of fitting a new transform.
        zerotolerance: (option) smallest allowed ratio between a kept
            eigenvalue and the largest eigenvalue.

    Returns:
        x: whitened data
        params: parameters of PCA (the object passed in, when one was given)
            mean: subtracted mean
            W: whitening matrix
            A: de-whitening (mixing) matrix

    Raises:
        ValueError: when fitting and the covariance has no positive
            eigenvalue, or one of the first ``num_comp`` eigenvalues is
            (relatively) zero.
    """
    # Dimension
    if num_comp is None:
        num_comp = x.shape[0]

    # From learned parameters --------------------------------
    if params is not None:
        # Use previously-trained model
        print(" use learned value")
        x = np.dot(params['W'], x - params['mean'])
        return x, params

    # Learn from data ----------------------------------------
    # Zero mean
    xmean = np.mean(x, 1).reshape([-1, 1])
    x = x - xmean

    # Eigenvalue decomposition
    xcov = np.cov(x)
    d, V = np.linalg.eigh(xcov)  # Ascending order
    # Convert to descending order
    d = d[::-1]
    V = V[:, ::-1]

    # Do not allow zero eigenvalues: they would turn the whitening matrix
    # into inf/nan. ``not d[0] > 0`` also catches a NaN leading eigenvalue.
    if not d[0] > 0:
        raise ValueError(
            "cannot whiten: the data covariance has no positive eigenvalue")
    zeroeigval = np.sum((d[:num_comp] / d[0]) < zerotolerance)
    if zeroeigval > 0:
        raise ValueError(
            "cannot whiten: {} of the first {} eigenvalues are below "
            "zerotolerance={} relative to the largest one".format(
                int(zeroeigval), num_comp, zerotolerance))

    # Construct whitening and dewhitening matrices
    dsqrt = np.sqrt(d[:num_comp])
    dsqrtinv = 1 / dsqrt
    V = V[:, :num_comp]
    W = np.dot(np.diag(dsqrtinv), V.transpose())  # whitening matrix
    A = np.dot(V, np.diag(dsqrt))  # de-whitening matrix
    x = np.dot(W, x)

    params = {'mean': xmean, 'W': W, 'A': A}

    return x, params
72 |
--------------------------------------------------------------------------------
/baseline_models/icebeem/losses/dsm.py:
--------------------------------------------------------------------------------
1 | ### conditional dsm objective
2 | #
3 | # this code is adapted from: https://github.com/ermongroup/ncsn/
4 | #
5 |
6 | import torch
7 | import torch.autograd as autograd
8 |
9 |
def dsm(energy_net, samples, sigma=1):
    """Denoising score matching loss for an unconditional energy network.

    Gaussian noise scaled by ``sigma`` is added to ``samples``. The gradient
    of ``-energy_net(noisy)`` with respect to the noisy inputs, scaled by
    ``sigma ** 2``, is added to the noise. The loss is half the mean, over all
    leading dimensions, of the squared L2 norm of that residual along the
    last dimension.

    Note: ``samples`` is switched to ``requires_grad=True`` in place.

    :param energy_net: callable mapping the perturbed inputs to energies
    :param samples: floating-point tensor of inputs
    :param sigma: noise level
    :return: scalar loss tensor (graph kept for back-propagation)
    """
    samples.requires_grad_(True)
    noise = torch.randn_like(samples) * sigma
    noisy = samples + noise
    log_density = -energy_net(noisy)
    # create_graph=True so the loss can be differentiated w.r.t. the network.
    score = autograd.grad(log_density.sum(), noisy, create_graph=True)[0]
    residual = sigma ** 2 * score + noise
    squared_norm = torch.norm(residual, dim=-1) ** 2
    return squared_norm.mean() / 2.
21 |
22 |
def cdsm(energy_net, samples, conditions, sigma=1.):
    """
    Conditional denoising score matching

    Note: ``samples`` is switched to ``requires_grad=True`` in place.

    :param energy_net: an energy network that takes x and y as input and outputs energy of shape (batch_size,)
    :param samples: values of dependent variable x
    :param conditions: values of conditioning variable y
    :param sigma: noise level for dsm
    :return: scalar cdsm loss (half the batch mean of the squared residual norm)
    """
    samples.requires_grad_(True)
    noise = torch.randn_like(samples) * sigma
    noisy = samples + noise
    log_density = -energy_net(noisy, conditions)
    # The network must return exactly one energy per sample.
    assert log_density.ndim == 1
    score = autograd.grad(log_density.sum(), noisy, create_graph=True)[0]
    residual = sigma ** 2 * score + noise
    squared_norm = torch.norm(residual, dim=-1) ** 2
    return squared_norm.mean() / 2.
42 |
43 |
def conditional_dsm(energy_net, samples, segLabels, energy_net_final_layer, sigma=1):
    """Denoising score matching loss with a segment-selecting final layer.

    The negated network output is reshaped to ``(-1, d * d)`` with
    ``d = samples.shape[-1]``, multiplied by ``energy_net_final_layer`` and
    indexed with ``segLabels``. The selected values act as the log-density
    whose gradient enters the usual DSM residual.

    Note: ``samples`` is switched to ``requires_grad=True`` in place.

    :param energy_net: callable applied to the perturbed inputs
    :param samples: floating-point tensor of inputs
    :param segLabels: index or mask applied to the matrix product
        (presumably one-hot segment labels -- confirm against the caller)
    :param energy_net_final_layer: 2-D tensor with ``d * d`` rows
    :param sigma: noise level
    :return: scalar loss tensor
    """
    samples.requires_grad_(True)
    noise = torch.randn_like(samples) * sigma
    noisy = samples + noise

    dim = samples.shape[-1]

    # apply conditioning, then keep only the entries picked by segLabels
    features = -energy_net(noisy).view(-1, dim * dim)
    log_density = torch.mm(features, energy_net_final_layer)[segLabels]

    score = autograd.grad(log_density.sum(), noisy, create_graph=True)[0]
    residual = sigma ** 2 * score + noise
    squared_norm = torch.norm(residual, dim=-1) ** 2
    return squared_norm.mean() / 2.
63 |
64 |
def dsm_score_estimation(scorenet, samples, sigma=0.01):
    """Denoising score matching loss for a network that outputs scores.

    The regression target is ``-(perturbed - samples) / sigma ** 2``. Target
    and network output are flattened per sample, and the loss is half the
    batch mean of the summed squared differences.

    :param scorenet: callable mapping perturbed samples to scores
    :param samples: tensor of inputs with the batch in the first dimension
    :param sigma: noise level
    :return: scalar loss tensor
    """
    noise = torch.randn_like(samples) * sigma
    perturbed_samples = samples + noise
    scale = - 1 / (sigma ** 2)
    target = scale * (perturbed_samples - samples)
    scores = scorenet(perturbed_samples)

    batch = target.shape[0]
    flat_target = target.view(batch, -1)
    flat_scores = scores.view(scores.shape[0], -1)
    squared_error = ((flat_scores - flat_target) ** 2).sum(dim=-1)
    return 1 / 2. * squared_error.mean(dim=0)
74 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # IDE files
2 | .idea
3 |
4 | # Byte-compiled / optimized / DLL files
5 | *__pycache__
6 |
7 | # Docker
8 | Dockerfile
9 |
10 | # roms
11 | roms/*
12 | Roms.rar
13 |
14 | __pycache__/
15 | *.py[cod]
16 | *$py.class
17 |
18 | # C extensions
19 | *.so
20 |
21 | # Distribution / packaging
22 | .Python
23 | build/
24 | develop-eggs/
25 | dist/
26 | downloads/
27 | eggs/
28 | .eggs/
29 | lib/
30 | lib64/
31 | parts/
32 | sdist/
33 | var/
34 | wheels/
35 | share/python-wheels/
36 | *.egg-info/
37 | .installed.cfg
38 | *.egg
39 | MANIFEST
40 |
41 | # PyInstaller
42 | # Usually these files are written by a python script from a template
43 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
44 | *.manifest
45 | *.spec
46 |
47 | # Installer logs
48 | pip-log.txt
49 | pip-delete-this-directory.txt
50 |
51 | # Unit test / coverage reports
52 | htmlcov/
53 | .tox/
54 | .nox/
55 | .coverage
56 | .coverage.*
57 | .cache
58 | nosetests.xml
59 | coverage.xml
60 | *.cover
61 | *.py,cover
62 | .hypothesis/
63 | .pytest_cache/
64 | cover/
65 |
66 | # Translations
67 | *.mo
68 | *.pot
69 |
70 | # Django stuff:
71 | *.log
72 | local_settings.py
73 | db.sqlite3
74 | db.sqlite3-journal
75 |
76 | # Flask stuff:
77 | instance/
78 | .webassets-cache
79 |
80 | # Scrapy stuff:
81 | .scrapy
82 |
83 | # Sphinx documentation
84 | docs/_build/
85 |
86 | # PyBuilder
87 | .pybuilder/
88 | target/
89 |
90 | # Jupyter Notebook
91 | .ipynb_checkpoints
92 |
93 | # IPython
94 | profile_default/
95 | ipython_config.py
96 |
97 | # pyenv
98 | # For a library or package, you might want to ignore these files since the code is
99 | # intended to run in multiple environments; otherwise, check them in:
100 | .python-version
101 |
102 | # pipenv
103 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
104 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
105 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
106 | # install all needed dependencies.
107 | #Pipfile.lock
108 |
109 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
110 | __pypackages__/
111 |
112 | # Celery stuff
113 | celerybeat-schedule
114 | celerybeat.pid
115 |
116 | # SageMath parsed files
117 | *.sage.py
118 |
119 | # Environments
120 | .env
121 | .venv
122 | env/
123 | venv/
124 | ENV/
125 | env.bak/
126 | venv.bak/
127 |
128 | # Spyder project settings
129 | .spyderproject
130 | .spyproject
131 |
132 | # Rope project settings
133 | .ropeproject
134 |
135 | # mkdocs documentation
136 | /site
137 |
138 | # mypy
139 | .mypy_cache/
140 | .dmypy.json
141 | dmypy.json
142 |
143 | # Pyre type checker
144 | .pyre/
145 |
146 | # pytype static type analyzer
147 | .pytype/
148 |
149 | # Cython debug symbols
150 | cython_debug/
151 |
152 | # pytorch files
153 | *.pt
154 |
--------------------------------------------------------------------------------
/baseline_models/icebeem/models/ivae/ivae_wrapper.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 |
4 | import numpy as np
5 | import torch
6 | from torch import optim
7 | from torch.utils.data import DataLoader
8 |
9 | from data.imca import ConditionalDataset
10 | from .ivae_core import iVAE
11 |
12 |
13 |
14 |
def IVAE_wrapper(X, U, latent_dim, batch_size=256, max_iter=7e4, seed=0, n_layers=3, hidden_dim=20, lr=1e-3, cuda=True,
                 ckpt_folder='ivae.pt', architecture="ivae", logger=None, time_limit=None, learn_decoder_var=False):
    """Train an iVAE on observations ``X`` with auxiliary variables ``U``.

    Args:
        X: observations; cast with ``astype(np.float32)`` before use.
        U: auxiliary variables; cast the same way.
        latent_dim: latent dimension handed to ``iVAE``.
        batch_size: mini-batch size of the training ``DataLoader``.
        max_iter: maximum number of optimizer steps.
        seed: seed for ``torch.manual_seed`` and ``np.random.seed``.
        n_layers, hidden_dim, architecture, learn_decoder_var: forwarded
            unchanged to the ``iVAE`` constructor.
        lr: Adam learning rate.
        cuda: train on ``cuda:0`` when true, otherwise on CPU.
        ckpt_folder: directory that receives the final ``ivae.pt`` state
            dict (despite its default value it is used as a folder). It is
            created when missing.
        logger: optional object exposing ``log_metrics(step=..., metrics=...)``;
            called every 100 iterations with the current training loss.
        time_limit: optional wall-clock budget in hours.

    Returns:
        The trained model (still in training mode).
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    device = torch.device('cuda:0' if cuda else 'cpu')
    print('training on {}'.format(torch.cuda.get_device_name(device) if cuda else 'cpu'))

    # load data
    dset = ConditionalDataset(X.astype(np.float32), U.astype(np.float32), device)
    loader_params = {'num_workers': 1, 'pin_memory': True, "generator": torch.Generator(device=device)} if cuda else {}
    train_loader = DataLoader(dset, shuffle=True, batch_size=batch_size, **loader_params)
    data_dim, _, aux_dim = dset.get_dims()

    # define model and optimizer
    model = iVAE(latent_dim, data_dim, aux_dim, activation='lrelu', device=device,
                 n_layers=n_layers, hidden_dim=hidden_dim, architecture=architecture, learn_decoder_var=learn_decoder_var)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=20, verbose=True)

    # training loop
    print("Training..")

    # timer
    if time_limit is not None:
        time_limit = time_limit * 60 * 60  # convert to seconds
        t0 = time.time()
    else:
        # No budget: with t0 = 0 the elapsed time is always below infinity.
        time_limit = np.inf
        t0 = 0

    it = 0
    model.train()
    while it < max_iter and time.time() - t0 < time_limit:
        elbo_train = 0
        n_batches = 0
        for x, u in train_loader:
            it += 1
            n_batches += 1
            optimizer.zero_grad()
            x, u = x.to(device), u.to(device)
            elbo, _ = model.elbo(x, u)
            # The optimizer minimises, so back-propagate the negative ELBO.
            elbo.mul(-1).backward()
            optimizer.step()
            loss_value = -elbo.item()
            elbo_train += loss_value
            if logger is not None and it % 100 == 0:
                logger.log_metrics(step=it, metrics={"loss_train": loss_value})
            # ">=" (not ">") so that exactly max_iter steps are taken,
            # matching the outer loop condition.
            if it >= max_iter or time.time() - t0 > time_limit:
                break
        # Average over the batches actually processed; the last epoch may
        # have been cut short by the iteration or time budget.
        elbo_train /= n_batches
        scheduler.step(elbo_train)

    # save model checkpoint after training
    if ckpt_folder:
        os.makedirs(ckpt_folder, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(ckpt_folder, 'ivae.pt'))

    return model
76 |
--------------------------------------------------------------------------------
/baseline_models/slowvae_pcl/README.md:
--------------------------------------------------------------------------------
1 | # Towards Nonlinear Disentanglement in Natural Data with Temporal Sparse Coding
2 |
3 | This repository contains the code release for:
4 |
5 | **Towards Nonlinear Disentanglement in Natural Data with Temporal Sparse Coding.**
6 | David Klindt*, Lukas Schott*, Yash Sharma*, Ivan Ustyuzhaninov, Wieland Brendel, Matthias Bethge†, Dylan Paiton†
7 | https://arxiv.org/abs/2007.10930
8 |
9 | An example latent traversal using our learned model:
10 | 
11 |
12 |
13 | **Abstract:** We construct an unsupervised learning model that achieves nonlinear disentanglement of underlying factors of variation in naturalistic videos. Previous work suggests that representations can be disentangled if all but a few factors in the environment stay constant at any point in time. As a result, algorithms proposed for this problem have only been tested on carefully constructed datasets with this exact property, leaving it unclear whether they will transfer to natural scenes. Here we provide evidence that objects in segmented natural movies undergo transitions that are typically small in magnitude with occasional large jumps, which is characteristic of a temporally sparse distribution. We leverage this finding and present SlowVAE, a model for unsupervised representation learning that uses a sparse prior on temporally adjacent observations to disentangle generative factors without any assumptions on the number of changing factors. We provide a proof of identifiability and show that the model reliably learns disentangled representations on several established benchmark datasets, often surpassing the current state-of-the-art. We additionally demonstrate transferability towards video datasets with natural dynamics, Natural Sprites and KITTI Masks, which we contribute as benchmarks for guiding disentanglement research towards more natural data domains.
14 |
15 | ### Cite
16 | If you make use of this code in your own work, please cite our paper:
17 | ```
18 | @inproceedings{klindt2021towards,
19 | title={Towards Nonlinear Disentanglement in Natural Data with Temporal Sparse Coding},
20 | author={David A. Klindt and Lukas Schott and Yash Sharma and Ivan Ustyuzhaninov and Wieland Brendel and Matthias Bethge and Dylan Paiton},
21 | booktitle={International Conference on Learning Representations},
22 | year={2021},
23 | url={https://openreview.net/forum?id=EbIDjBynYJ8}
24 | }
25 | ```
26 |
27 | ### Datasets
28 | Our work also contributes two new datasets.
29 | The Natural Sprites dataset can be downloaded here: https://zenodo.org/record/3948069
30 | The KITTI Masks dataset can be downloaded here: https://zenodo.org/record/3931823
31 |
32 |
33 | ### Acknowledgements
34 |
35 | The repository is based on the following [Beta-VAE reproduction](https://github.com/1Konny/Beta-VAE). The MCC metric was adopted from the [Time-Contrastive Learning release](https://github.com/hirosm/TCL).
36 |
37 | ### Contact
38 |
39 | - Maintainers: [David Klindt](https://github.com/david-klindt) & [Lukas Schott](https://github.com/lukas-schott) & [Yash Sharma](https://github.com/ysharma1126)
40 |
--------------------------------------------------------------------------------
/baseline_models/slowvae_pcl/data_generation/gen_youtube_csv.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import os
3 | from coco.coco import COCO
4 | import coco.mask as maskUtils
5 | import numpy as np
6 | from scipy import ndimage
7 | from scipy.misc import imresize
8 | from tqdm import tqdm
9 |
def myRange(start, end, step):
    """Yield start, start+step, ... while the value is below ``end``, then ``end`` itself.

    ``end`` is always the final value produced, even when it is not reachable
    by whole steps from ``start``.
    """
    current = start
    while True:
        if not current < end:
            break
        yield current
        current += step
    yield end
16 |
def main(args, data_dir):
    """Write one CSV row per annotation with per-frame mask statistics.

    For every non-empty segmentation of every annotation the mask is decoded,
    optionally resized (``args.downscale``) and optionally cut into 64-wide
    windows (``args.keepaspect`` with ``args.stride``); for each resulting
    mask the centre of mass ``(y, x)`` and the mask area are recorded.
    Missing or empty segmentations/masks are recorded as ``None``.

    The output file name is derived from the flags and written to the
    current working directory.

    Args:
        args: parsed command-line namespace providing ``downscale``,
            ``keepaspect``, ``stride`` and ``verbose``.
        data_dir: unused here; annotations are read from the module-level
            ``coco`` object created by the script entry point.
    """
    downscale = args.downscale
    keepaspectratio = args.keepaspect
    stride = args.stride
    annIds = coco.getAnnIds()
    anns = coco.loadAnns(annIds)
    max_len = np.max([len(x['segmentations']) for x in anns])
    file_name = '{}{}{}'.format('downscale_' if args.downscale else '',
                                'keepaspect_' if args.keepaspect else '',
                                'stride_{}_'.format(args.stride) if args.stride != 32 else '')
    # newline='' is what the csv module expects from the file object; without
    # it some platforms emit an extra blank line after every row.
    with open('{}.csv'.format(file_name.rstrip('_')), mode='w', newline='') as movie_file:
        movie_writer = csv.writer(movie_file, delimiter=',')
        movie_writer.writerow(['id', 'cat_id'] + ['t_{}'.format(i) for i in range(max_len)])
        for ann in tqdm(anns) if args.verbose else anns:
            vals = []
            for seg in ann['segmentations']:
                if not seg:
                    vals.append(None)
                    continue
                rle = maskUtils.frPyObjects([seg], ann['height'], ann['width'])
                mask = np.squeeze(maskUtils.decode(rle))
                if not mask.any():
                    vals.append(None)
                    continue
                if downscale:
                    if keepaspectratio:
                        # Resize to 64x128, then slide a 64-wide window
                        # across the width at the requested stride.
                        tr_mask = imresize(mask, (64, 128))
                        tr_masks = [tr_mask[:, left:left + 64]
                                    for left in myRange(0, 64, stride)]
                    else:
                        tr_masks = [imresize(mask, (64, 64))]
                else:
                    tr_masks = [mask]
                temp = []
                for tr_mask in tr_masks:
                    if tr_mask.any():
                        # Builtin float/int replace the np.float/np.int
                        # aliases, which no longer exist in recent NumPy.
                        com_val = np.array(ndimage.center_of_mass(tr_mask)).astype(float).tolist()
                        y_val = com_val[0]
                        x_val = com_val[1]
                        rle = maskUtils.encode(np.asfortranarray(tr_mask))
                        area_val = maskUtils.area(rle).astype(int)
                        temp.append((y_val, x_val, area_val))
                    else:
                        temp.append(None)
                vals.append(temp)
            movie_writer.writerow([ann['id'], ann['category_id']] + vals)
65 |
66 |
# Script entry point: parse the command-line flags, load the annotation file
# and generate the CSV.
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('--stride', type=int, default=32)
    parser.add_argument('--downscale', action='store_true')
    parser.add_argument('--keepaspect', action='store_true')
    args = parser.parse_args()
    #download data from https://competitions.codalab.org/competitions/20127#participate-get-data
    data_dir = './data/youtube_voc'
    annFile= os.path.join(data_dir, 'train.json')
    # initialize COCO api for instance annotations
    # NOTE: `coco` must stay a module-level name; main() reads it as a global.
    coco=COCO(annFile)
    main(args, data_dir)
--------------------------------------------------------------------------------
/scripts/produce_all_logs.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import json
4 | import math
5 | import os
6 | import random
7 | import shutil
8 | import sys
9 | # sys.path.insert(0, os.path.abspath('..'))
10 | from copy import deepcopy
11 |
12 | import numpy as np
13 |
14 | import time
15 |
def load_hparams(folder_path):
    """Return the parsed content of ``hparams.json`` located in ``folder_path``."""
    hparams_file = os.path.join(folder_path, 'hparams.json')
    with open(hparams_file, 'r') as handle:
        return json.loads(handle.read())
20 | # class Bunch:
21 | # def __init__(self, opt):
22 | # self.__dict__.update(opt)
23 | #return Bunch(opt)
24 |
def load_metrics(folder_path):
    """Return the last record of ``log.ndjson`` found in ``folder_path``.

    The file holds one JSON document per line; only the final non-blank line
    is parsed, so trailing empty lines do not break loading.

    Raises:
        ValueError: if the file contains no record at all, or (as
            ``json.JSONDecodeError``) if the last record is not valid JSON.
    """
    last_record = None
    with open(os.path.join(folder_path, 'log.ndjson'), 'r') as infile:
        # Stream the file and remember only the latest non-blank line rather
        # than materialising every line of a potentially long log.
        for line in infile:
            if line.strip():
                last_record = line
    if last_record is None:
        raise ValueError("No metrics found in {}".format(os.path.join(folder_path, 'log.ndjson')))
    return json.loads(last_record)
29 |
def main(args=None):
    """Gather hyper-parameters and final metrics of all completed runs.

    Every sub-directory of each meta-experiment listed in ``META_EXPS`` is
    treated as one run. A run counts as completed when it contains
    ``z_hat_final.npy``; for those, ``hparams.json`` and the last entry of
    ``log.ndjson`` are merged into one dict. The list of dicts is stored with
    ``np.save`` under ``ROOT``.

    ``args`` is accepted but not used.
    """
    ROOT = '/network/scratch/l/lachaseb/identifiable_latent_causal_model/exp/'
    # Identifiers of earlier meta-experiments, kept for reference:
    #META_EXPS = {"temporal_ilcm": "e0b63788-0eeb-11ee-8f41-424954050100",
    #             "temporal_random": "0b4e90b8-0f81-11ee-b8a0-424954050300",
    #             "temporal_supervised": "e3baec86-0f80-11ee-9b91-424954050300",
    #             "temporal_tcvae": "25e9291a-0f95-11ee-9cf4-424954050100",
    #             "temporal_pcl": "81168a28-106f-11ee-bb87-424954050200",
    #             "temporal_slowvae": "ee3dd866-1071-11ee-ab0d-424954050200",
    #             "action_ilcm": "4472ac22-1124-11ee-80d1-48df37d42c20",
    #             "action_random": "f668798e-1124-11ee-86e1-48df37d42c20",
    #             "action_supervised": "f81f68f4-1125-11ee-9e63-48df37d42c20",
    #             "action_tcvae": "fc4313ee-1504-11ee-bca8-d8c497b83240",
    #             "action_ivae": "5d35b610-1119-11ee-bc8f-424954050100",
    #             }
    #META_EXPS = {"temporal_ilcm": "71fee648-4e81-11ee-9556-424954050200",
    #             "action_ilcm": "3645f1aa-4e81-11ee-b87d-424954050200",
    #             }
    #META_EXPS = {"temporal_ilcm": "462a70b4-5318-11ee-8099-424954050100",
    #             "action_ilcm": "66143ee0-5332-11ee-b033-424954050100",
    #             }
    #META_EXPS = {"temporal_ilcm": "95b2eb58-53ee-11ee-929b-424954050300",
    #             "action_ilcm": "3d959e64-5404-11ee-8e54-424954050300",
    #             }
    #META_EXPS = {"temporal_ilcm": "4d5d4c0a-54b0-11ee-be29-424954050100",
    #             "action_ilcm": "f854193c-54af-11ee-be2a-424954050100",
    #             }
    META_EXPS = {"temporal_ilcm": "2cdcae28-58b9-11ee-befb-424954050100",
                 "action_ilcm": "77382470-58b9-11ee-a8ff-424954050100",
                 }

    all_logs = []
    for name, meta_exp in META_EXPS.items():
        meta_dir = os.path.join(ROOT, meta_exp)
        completed = 0
        for entry in os.listdir(meta_dir):
            run_dir = os.path.join(meta_dir, entry)
            if os.path.isdir(run_dir):
                # The final latent dump is only written once a run finishes.
                if os.path.exists(os.path.join(run_dir, "z_hat_final.npy")):
                    completed += 1
                    record = load_hparams(run_dir)
                    record.update(load_metrics(run_dir))
                    all_logs.append(record)
                else:
                    print(f"{name}: Run {meta_exp}/{entry} is not completed thus excluded from all_logs file")
        print(f"{name}: Done with {completed} experiments.")

    save_path = os.path.join(ROOT, "all_logs_23sep2023_JMLR_const_rand_g.npy")
    np.save(save_path, all_logs)
    print("Saved to:", save_path)
90 |
# Run the aggregation only when executed as a script, not on import.
if __name__ == "__main__":
    main()
93 |
--------------------------------------------------------------------------------
/baseline_models/slowvae_pcl/mcc_metric/metric.py:
--------------------------------------------------------------------------------
1 | """Mean Correlation Coefficient from Hyvarinen & Morioka
2 | """
3 | from absl import logging
4 | from disentanglement_lib.evaluation.metrics import utils
5 | import numpy as np
6 | import gin.tf
7 | import scipy as sp
8 | from mcc_metric.munkres import Munkres
9 |
def correlation(x, y, method='Pearson'):
    """Match the rows of ``x`` to the rows of ``y`` and report their correlation.

    Rows are treated as variables and columns as observations. The rows of
    ``x`` are permuted with the Munkres assignment algorithm so that the
    total absolute correlation with the rows of ``y`` is maximal. The
    correlation block is sliced using ``x.shape[0]``, so ``y`` is expected to
    have the same number of rows as ``x``.

    Args:
        x: data to be sorted.
        y: target data.
        method: correlation method, ``'Pearson'`` or ``'Spearman'``.

    Returns:
        corr_sort: correlation matrix between y and x (after sorting).
        sort_idx: sorting index; entry ``i`` is the row of ``x`` matched to
            row ``i`` of ``y``.
        x_sort: x after sorting.

    Raises:
        ValueError: if ``method`` is not one of the supported names.
    """
    # Fail early with a clear message instead of hitting an unbound local
    # further down when the method name is misspelled.
    if method not in ('Pearson', 'Spearman'):
        raise ValueError(
            "Unknown correlation method {!r}; expected 'Pearson' or 'Spearman'".format(method))

    print("Calculating correlation...")

    x = x.copy()
    y = y.copy()
    dim = x.shape[0]

    def _cross_correlation(candidate):
        """Correlation of every row of y with every row of ``candidate``."""
        if method == 'Pearson':
            full = np.corrcoef(y, candidate)
        else:
            # Import the submodule explicitly: a bare `import scipy` is not
            # guaranteed to expose `scipy.stats` on every SciPy version.
            from scipy import stats as sp_stats
            full, _ = sp_stats.spearmanr(y.T, candidate.T)
        # Keep only the (y rows) x (candidate rows) block.
        return full[0:dim, dim:]

    # Calculate correlation -----------------------------------
    corr = _cross_correlation(x)

    # Sort ----------------------------------------------------
    # Munkres minimises cost, hence the negated absolute correlation.
    munk = Munkres()
    indexes = munk.compute(-np.absolute(corr))

    sort_idx = np.zeros(dim)
    x_sort = np.zeros(x.shape)
    for i in range(dim):
        matched_row = indexes[i][1]
        sort_idx[i] = matched_row
        x_sort[i, :] = x[matched_row, :]

    # Re-calculate correlation --------------------------------
    corr_sort = _cross_correlation(x_sort)

    return corr_sort, sort_idx, x_sort
55 |
56 |
# Registered with gin under the name "mcc"; the blacklisted arguments are
# supplied by the caller rather than by a gin configuration.
@gin.configurable(
    "mcc",
    blacklist=["ground_truth_data", "representation_function", "random_state",
               "artifact_dir"])
def compute_mcc(ground_truth_data,
                representation_function,
                random_state,
                artifact_dir=None,
                num_train=gin.REQUIRED,
                correlation_fn=gin.REQUIRED,
                batch_size=16):
    """Computes the mean correlation coefficient.

    Args:
      ground_truth_data: GroundTruthData to be sampled from.
      representation_function: Function that takes observations as input and
        outputs a dim_representation sized representation for each observation.
      random_state: Numpy random state used for randomness.
      artifact_dir: Optional path to directory where artifacts can be saved.
        Unused; it is deleted immediately.
      num_train: Number of points used for training.
      correlation_fn: Name of the correlation method; forwarded to
        `_compute_mcc`, which passes it on as the `method` argument of
        `correlation`.
      batch_size: Batch size for sampling.

    Returns:
      Dict with mcc stats
    """
    del artifact_dir
    logging.info("Generating training set.")
    mus_train, ys_train = utils.generate_batch_factor_code(
        ground_truth_data, representation_function, num_train,
        random_state, batch_size)
    # The sampled codes are expected to have one column per sampled point.
    assert mus_train.shape[1] == num_train
    return _compute_mcc(mus_train, ys_train, correlation_fn, random_state)
89 |
90 |
def _compute_mcc(mus_train, ys_train, correlation_fn, random_state):
    """Build the MCC score dictionary from codes and ground-truth factors.

    The factors are copied into an array shaped like the codes; when there
    are more code rows than factor rows, the extra rows are filled with
    Gaussian noise drawn from ``random_state``. The codes are then matched
    against these targets with ``correlation``.

    Returns:
        Dict with ``meanabscorr`` (mean absolute diagonal correlation over the
        real factors), every entry of the sorted correlation matrix and the
        sorting indices.
    """
    n_factors = ys_train.shape[0]
    n_points = ys_train.shape[1]

    targets = np.zeros(mus_train.shape)
    targets[:n_factors, :n_points] = ys_train
    # Pad the missing factor rows with noise, one row per draw.
    n_padding = len(mus_train) - len(ys_train)
    for row in range(n_factors, n_factors + n_padding):
        targets[row, :] = random_state.normal(size=n_points)

    corr_sorted, sort_idx, _ = correlation(mus_train, targets, method=correlation_fn)

    diagonal = np.diag(corr_sorted)[:len(ys_train)]
    scores = {"meanabscorr": np.mean(np.abs(diagonal))}
    for i, corr_row in enumerate(corr_sorted):
        for j, value in enumerate(corr_row):
            scores["corr_sorted_{}{}".format(i, j)] = value
    for i, matched in enumerate(sort_idx):
        scores["sort_idx_{}".format(i)] = matched
    return scores
106 |
107 |
--------------------------------------------------------------------------------
/model/nn.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | import numpy as np
6 | from torch.nn.utils import spectral_norm as sn
7 |
8 |
class MLP(torch.nn.Module):
    """Feed-forward network with ``nlayers`` hidden blocks and a linear head.

    Each hidden block is Linear -> (optional BatchNorm1d) -> LeakyReLU(0.2).
    Hidden linear layers carry a bias only when batch norm is disabled and
    can optionally be wrapped in spectral normalisation. With ``nlayers == 0``
    the model reduces to a single linear map from ``ni`` to ``no``.
    """

    def __init__(self, ni, no, nhidden, nlayers, spectral_norm=False, batch_norm=True):
        super().__init__()
        self.nlayers = nlayers
        self.batch_norm = batch_norm
        use_bias = not batch_norm
        for idx in range(nlayers):
            # Only the first hidden layer sees the raw input width.
            width_in = ni if idx == 0 else nhidden
            layer = torch.nn.Linear(width_in, nhidden, bias=use_bias)
            if spectral_norm:
                layer = sn(layer)
            setattr(self, "linear%d" % idx, layer)
            if batch_norm:
                setattr(self, "bn%d" % idx, torch.nn.BatchNorm1d(nhidden))
        head_in = ni if nlayers == 0 else nhidden
        self.linear_out = torch.nn.Linear(head_in, no)

    def forward(self, x):
        for idx in range(self.nlayers):
            x = getattr(self, "linear%d" % idx)(x)
            if self.batch_norm:
                x = getattr(self, "bn%d" % idx)(x)
            x = F.leaky_relu(x, negative_slope=0.2, inplace=True)
        return self.linear_out(x)
41 |
42 |
class ParallelMLP(torch.nn.Module):
    """``nMLPs`` independent MLPs evaluated in one pass.

    The input has shape (batch, nMLPs, ni); slice ``k`` along the second axis
    is processed by the k-th MLP. Hidden layers are bias-free
    ``ParallelLinear`` modules followed by an optional BatchNorm1d over the
    flattened (nMLPs * nhidden) features and a LeakyReLU(0.2).
    """

    def __init__(self, ni, no, nhidden, nlayers, nMLPs, bn=True):
        super().__init__()
        self.nlayers = nlayers
        self.nMLPs = nMLPs
        self.bn = bn
        for idx in range(nlayers):
            width_in = ni if idx == 0 else nhidden
            # NOTE(review): hidden layers have no bias even when bn is False.
            setattr(self, "linear%d" % idx, ParallelLinear(width_in, nhidden, nMLPs, bias=False))
            if self.bn:
                setattr(self, "bn%d" % idx, torch.nn.BatchNorm1d(nhidden * nMLPs))
        head_in = ni if nlayers == 0 else nhidden
        self.linear_out = ParallelLinear(head_in, no, nMLPs)

    def forward(self, x):
        assert self.nMLPs == x.shape[1]
        batch, n_mlps, _ = x.shape
        for idx in range(self.nlayers):
            x = getattr(self, "linear%d" % idx)(x)
            if self.bn:
                norm = getattr(self, "bn%d" % idx)
                # Batch norm works on 2-D input: flatten the per-MLP features,
                # normalise, then restore the (batch, nMLPs, features) layout.
                # `reshape` (not `view`) because x may be non-contiguous here.
                flat = x.reshape(batch, -1)
                x = norm(flat).view(batch, n_mlps, -1)
            x = F.leaky_relu(x, negative_slope=0.2, inplace=True)
        return self.linear_out(x)
72 |
73 |
class ParallelLinear(torch.nn.Module):
    """A stack of ``num_linears`` independent linear maps applied in parallel.

    The weight has shape (num_linears, out_features, in_features) and the
    optional bias has shape (num_linears, out_features).
    """

    def __init__(self, in_features, out_features, num_linears, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_linears = num_linears
        self.weight = torch.nn.Parameter(torch.Tensor(num_linears, out_features, in_features))
        bias_param = torch.nn.Parameter(torch.Tensor(num_linears, out_features)) if bias else None
        self.register_parameter('bias', bias_param)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise every weight slice like ``torch.nn.Linear`` does, then the bias."""
        slope = math.sqrt(5)
        for linear in range(self.num_linears):
            torch.nn.init.kaiming_uniform_(self.weight[linear], a=slope)
        if self.bias is None:
            return
        # All slices share one shape, so the fan-in of the last slice is used
        # for the whole bias tensor.
        fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight[linear])
        bound = 1 / math.sqrt(fan_in)
        torch.nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        # input: (bs, num_linears, in_features) -> (bs, num_linears, out_features)
        projected = torch.einsum("bli,lji->blj", input, self.weight)
        if self.bias is None:
            return projected
        return projected + self.bias
101 |
102 |
--------------------------------------------------------------------------------
/baseline_models/slowvae_pcl/scripts/evaluate_disentanglement.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import sys
4 | import numpy as np
5 | import gin.tf
6 | gin.enter_interactive_mode()
7 | import time
8 | import os
9 | import json
10 | import traceback
11 | from scripts.model import reparametrize
12 | from scripts.model import BetaVAE_H as BetaVAE
13 | from disentanglement_lib.utils import results
14 |
15 | # needed later:
16 | from disentanglement_lib.evaluation.metrics import beta_vae # pylint: disable=unused-import
17 | from disentanglement_lib.evaluation.metrics import dci # pylint: disable=unused-import
18 | from disentanglement_lib.evaluation.metrics import downstream_task # pylint: disable=unused-import
19 | from disentanglement_lib.evaluation.metrics import factor_vae # pylint: disable=unused-import
20 | from disentanglement_lib.evaluation.metrics import fairness # pylint: disable=unused-import
21 | from disentanglement_lib.evaluation.metrics import irs # pylint: disable=unused-import
22 | from disentanglement_lib.evaluation.metrics import mig # pylint: disable=unused-import
23 | from disentanglement_lib.evaluation.metrics import modularity_explicitness # pylint: disable=unused-import
24 | from disentanglement_lib.evaluation.metrics import reduced_downstream_task # pylint: disable=unused-import
25 | from disentanglement_lib.evaluation.metrics import sap_score # pylint: disable=unused-import
26 | from disentanglement_lib.evaluation.metrics import unsupervised_metrics # pylint: disable=unused-import
27 |
28 | import mcc_metric.metric as mcc # pylint: disable=unused-import
29 |
30 |
31 |
def main(args, dataset):
    """Evaluate a trained BetaVAE checkpoint with the gin-configured metrics.

    The network is rebuilt from ``args``, its weights are loaded from
    ``<args.ckpt_dir>/last`` and every ``.gin`` file in ``metric_configs`` is
    run through ``evaluate``; results go to
    ``<args.output_dir>/evaluation/<args.ckpt_name>/<post>/<metric_name>``.

    Args:
        args: namespace providing ``cuda``, ``z_dim``, ``num_channel``,
            ``pcl``, ``ckpt_dir``, ``ckpt_name``, ``output_dir``, ``dataset``,
            ``natural_discrete``, ``specify`` and ``verbose``.
        dataset: passed as first argument to each metric's evaluation
            function.
    """
    device = torch.device('cuda' if torch.cuda.is_available() and args.cuda else 'cpu')
    net = BetaVAE
    net = net(args.z_dim, args.num_channel, args.pcl).to(device)
    file_path = os.path.join(args.ckpt_dir, 'last')
    # NOTE(review): no map_location is given, so a checkpoint saved from GPU
    # tensors may fail to load on a CPU-only machine -- confirm.
    checkpoint = torch.load(file_path)
    net.load_state_dict(checkpoint['model_states']['net'])

    def mean_rep(x):
        """Encode numpy batch ``x`` and return the first ``z_dim`` encoder outputs (mu)."""
        distributions = net._encode(
            torch.from_numpy(x).float().to(device))
        mu = distributions[:, :net.z_dim]
        logvar = distributions[:, net.z_dim:]  # unused for the mean representation
        return np.array(mu.detach().cpu())

    def sample_rep(x):
        """Encode numpy batch ``x`` and return a sample from ``reparametrize(mu, logvar)``."""
        distributions = net._encode(
            torch.from_numpy(x).float().to(device))
        mu = distributions[:, :net.z_dim]
        logvar = distributions[:, net.z_dim:]
        return np.array(reparametrize(mu, logvar).detach().cpu())

    # `evaluation_fn`, `random_seed` and `name` are injected by gin.
    @gin.configurable("evaluation")
    def evaluate(post,
                 output_dir,
                 evaluation_fn=gin.REQUIRED,
                 random_seed=gin.REQUIRED,
                 name=""):
        """Run one metric on the chosen representation and store its results."""
        experiment_timer = time.time()
        assert post == 'mean' or post == 'sampled'
        results_dict = evaluation_fn(
            dataset,
            mean_rep if post == 'mean' else sample_rep,
            random_state=np.random.RandomState(random_seed))
        results_dict["elapsed_time"] = time.time() - experiment_timer
        results.update_result_directory(output_dir, "evaluation", results_dict)

    # Fixed seed: the per-metric seeds drawn below are reproducible.
    random_state = np.random.RandomState(0)
    config_dir = 'metric_configs'
    # Skip hidden files and any config whose name contains 'others'.
    eval_config_files = [f for f in os.listdir(config_dir) if not (f.startswith('.') or 'others' in f)]
    t0 = time.time()
    posts = ['mean']
    for post in posts:
        for eval_config in eval_config_files:
            metric_name = os.path.basename(eval_config).replace(".gin", "")
            continuous = False
            if args.dataset == 'kittimasks' or (
                    args.dataset == 'natural' and not args.natural_discrete):
                continuous = True
            # For these datasets only the 'mcc' metric is computed.
            if continuous:
                if metric_name != 'mcc':
                    continue
            contains = True
            # `args.specify` is an underscore-separated list of substrings; a
            # metric runs only if its name contains at least one of them.
            if args.specify:
                contains = False
                for specific in args.specify.split('_'):
                    if specific in metric_name:
                        contains = True
                        break
            if contains:
                if args.verbose:
                    print("Computing metric '{}' on '{}'...".format(metric_name, post))
                eval_bindings = [
                    "evaluation.random_seed = {}".format(random_state.randint(2 ** 32)),
                    "evaluation.name = '{}'".format(metric_name)
                ]
                gin.parse_config_files_and_bindings(
                    [os.path.join(config_dir, eval_config)], eval_bindings)
                output_dir = os.path.join(
                    args.output_dir, 'evaluation', args.ckpt_name, post, metric_name)
                evaluate(post, output_dir)
                # Reset gin so the next metric starts from a clean config.
                gin.clear_config()
                if args.verbose:
                    print('took', time.time() - t0, 's')
                t0 = time.time()
107 |
--------------------------------------------------------------------------------
/baseline_models/icebeem/models/icebeem_wrapper.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import os
3 | import sys
4 | import pathlib
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 | from sklearn.decomposition import FastICA
10 | from torch.distributions import Uniform, TransformedDistribution, SigmoidTransform
11 |
12 | from losses.fce import ConditionalFCE
13 | from .nets import MLP
14 | from .nflib.flows import NormalizingFlowModel, Invertible1x1Conv, ActNorm
15 | from .nflib.spline_flows import NSF_AR
16 |
17 | sys.path.insert(0, str(pathlib.Path(__file__).parent.parent.parent.parent.parent))
18 | from disentanglement_via_mechanism_sparsity.model.nn import MLP as MLP_ilcm
19 |
20 |
def ICEBEEM_wrapper(X, Y, ebm_hidden_size, n_layers_ebm, n_layers_flow, lr_flow, lr_ebm, seed,
                    ckpt_file='icebeem.pt', test=False):
    """Train (or reload) an ICE-BeeM model and return the recovered sources.

    An energy-based model and a normalizing flow are trained by flow
    contrastive estimation: first an initial stage, then three rounds that
    alternate between updating the flow and the energy model. After every
    stage the data are unmixed with the energy model and passed through
    FastICA.

    Args:
        X: observations; ``X.shape[1]`` is used as the data dimension.
        Y: segment information handed to ``ConditionalFCE`` as ``segments``.
        ebm_hidden_size: width of each hidden layer of the energy MLP.
        n_layers_ebm: number of layers of the energy MLP.
        n_layers_flow: number of (ActNorm, 1x1 conv, spline flow) triples.
        lr_flow: learning rate for the flow updates in the alternating rounds.
        lr_ebm: learning rate for the energy-model updates in those rounds.
        seed: seed for numpy and torch.
        ckpt_file: checkpoint path template; stage ``k`` (0..3) is stored as
            ``<stem>_<k><ext>``.
        test: when true nothing is trained; each stage is restored from its
            checkpoint instead.

    Returns:
        List with one FastICA-transformed estimate per stage (four entries).
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    data_dim = X.shape[1]

    model_ebm = MLP(input_size=data_dim, hidden_size=[ebm_hidden_size] * n_layers_ebm,
                    n_layers=n_layers_ebm, output_size=data_dim, use_bn=True,
                    activation_function=F.leaky_relu)

    prior = TransformedDistribution(Uniform(torch.zeros(data_dim), torch.ones(data_dim)),
                                    SigmoidTransform().inv)
    nfs_flow = NSF_AR
    flows = [nfs_flow(dim=data_dim, K=8, B=3, hidden_dim=16) for _ in range(n_layers_flow)]
    convs = [Invertible1x1Conv(dim=data_dim) for _ in flows]
    norms = [ActNorm(dim=data_dim) for _ in flows]
    # Interleave the layers as norm, conv, flow, norm, conv, flow, ...
    flows = list(itertools.chain(*zip(norms, convs, flows)))
    # construct the model
    model_flow = NormalizingFlowModel(prior, flows)

    pretrain_flow = True
    augment_ebm = True

    # instantiate ebmFCE object
    fce_ = ConditionalFCE(data=X.astype(np.float32), segments=Y.astype(np.float32),
                          energy_MLP=model_ebm, flow_model=model_flow, verbose=False)

    ckpt_stem, ckpt_ext = os.path.splitext(ckpt_file)

    def _save_checkpoint(path):
        """Store the energy MLP, its final layer and the flow in ``path``."""
        torch.save({'ebm_mlp': fce_.energy_MLP.state_dict(),
                    'ebm_finalLayer': fce_.ebm_finalLayer,
                    'flow': fce_.flow_model.state_dict()}, path)

    def _load_checkpoint(path):
        """Restore what ``_save_checkpoint`` wrote."""
        state = torch.load(path, map_location=fce_.device)
        fce_.energy_MLP.load_state_dict(state['ebm_mlp'])
        fce_.ebm_finalLayer = state['ebm_finalLayer']
        # Counterpart of the `state_dict()` call used when saving (the former
        # `load_stat_dict` spelling does not exist and raised AttributeError).
        fce_.flow_model.load_state_dict(state['flow'])

    def _recover_sources():
        """Unmix X with the energy model and run FastICA on the result."""
        recov = fce_.unmixSamples(X, modelChoice='ebm')
        return FastICA().fit_transform(recov)

    init_ckpt_file = ckpt_stem + '_0' + ckpt_ext
    if not test:
        if pretrain_flow:
            fce_.pretrain_flow_model(epochs=1, lr=1e-4)

        # first we pretrain the final layer of EBM model (this is g(y) as it depends on segments)
        fce_.train_ebm_fce(epochs=15, augment=augment_ebm, finalLayerOnly=True, cutoff=.5)

        # then train full EBM via NCE with flow contrastive noise:
        fce_.train_ebm_fce(epochs=50, augment=augment_ebm, cutoff=.5, useVAT=False)

        _save_checkpoint(init_ckpt_file)
    else:
        _load_checkpoint(init_ckpt_file)

    # evaluate recovery of latents
    recov_sources = [_recover_sources()]

    # iterate between updating noise and tuning the EBM
    eps = .025
    for iter_ in range(3):
        mid_ckpt_file = ckpt_stem + '_' + str(iter_ + 1) + ckpt_ext
        if not test:
            # update flow model:
            fce_.train_flow_fce(epochs=5, objConstant=-1., cutoff=.5 - eps, lr=lr_flow)
            # update energy based model:
            fce_.train_ebm_fce(epochs=50, augment=augment_ebm, cutoff=.5 + eps, lr=lr_ebm, useVAT=False)

            _save_checkpoint(mid_ckpt_file)
        else:
            _load_checkpoint(mid_ckpt_file)

        # evaluate recovery of latents
        recov_sources.append(_recover_sources())

    return recov_sources
100 |
--------------------------------------------------------------------------------
/baseline_models/icebeem/models/tcl/tcl_eval.py:
--------------------------------------------------------------------------------
1 | """ Functions for evaluation
2 | This software includes the work that is distributed in the Apache License 2.0
3 | """
4 |
5 | import sys
6 |
7 | import numpy as np
8 | import tensorflow as tf
9 | from sklearn.metrics import confusion_matrix
10 |
11 |
12 | # =============================================================
13 | # =============================================================
def get_tensor(x, vars, sess, data_holder, batch=256):
    """Evaluate tensors on ``x`` batch by batch and collect the values.

    Args:
        x: input data [Ndim, Ndata]
        vars: tensor, tensor name, or list of those; names are resolved
            through the default graph (a list argument is updated in place)
        sess: session used to run the tensors
        data_holder: placeholder that receives each [Nbatch, Ndim] batch
        batch: batch size; ``None`` processes all data at once
    Returns:
        y: dict mapping the position of each tensor in ``vars`` to a float32
           array whose first axis has length Ndata
    """
    n_samples = x.shape[1]
    chunk = n_samples if batch is None else batch
    n_chunks = int(np.ceil(n_samples / chunk))

    if not isinstance(vars, list):
        vars = [vars]

    # Convert names to tensors (if necessary) -----------------
    for pos, item in enumerate(vars):
        if not tf.is_numeric_tensor(item) and isinstance(item, str):
            vars[pos] = tf.get_default_graph().get_tensor_by_name(item)

    # Start batch-inputs --------------------------------------
    collected = {}
    for step in range(n_chunks):

        sys.stdout.write('\r>> Getting tensors... %d/%d' % (step + 1, n_chunks))
        sys.stdout.flush()

        # Columns belonging to this batch; the last batch may be shorter.
        first = chunk * step
        last = np.minimum(chunk * (step + 1), n_samples)
        batchidx = np.arange(first, last)
        xbatch = x[:, batchidx].T

        fetched = sess.run(vars, feed_dict={data_holder: xbatch})

        for tn, values in enumerate(fetched):
            # Allocate the output on the first batch, once the trailing
            # shape of each tensor is known.
            if step == 0:
                collected[tn] = np.zeros([n_samples] + list(values.shape[1:]), dtype=np.float32)
            collected[tn][batchidx,] = values

    sys.stdout.write('\r\n')

    return collected
67 |
68 |
69 | # =============================================================
70 | # =============================================================
def calc_accuracy(pred, label, normalize_confmat=True):
    """ Calculate accuracy and confusion matrix
    Args:
        pred: predicted class of every sample (flattened before comparison)
        label: true class of every sample (same number of elements as pred)
        normalize_confmat: if True, divide every row of the confusion
                           matrix by its sum
    Returns:
        accuracy: fraction of entries of pred that equal label
        conf: float32 confusion matrix (rows: label, columns: pred)
    """

    # Accuracy ------------------------------------------------
    accuracy = np.mean(pred.reshape(-1) == label.reshape(-1))

    # Confusion matrix ----------------------------------------
    conf = confusion_matrix(label, pred).astype(np.float32)

    # Normalization
    if normalize_confmat:
        row_sums = conf.sum(axis=1, keepdims=True)
        # A class that only shows up in `pred` has an all-zero row; dividing
        # by 1 keeps that row at zero instead of turning it into NaN (0 / 0).
        row_sums[row_sums == 0] = 1
        conf = conf / row_sums

    return accuracy, conf
95 |
96 | # =============================================================
97 | # =============================================================
98 | # def correlation(x, y, method='Pearson'):
99 | # """Evaluate correlation
100 | # Args:
101 | # x: data to be sorted
102 | # y: target data
103 | # Returns:
104 | # corr_sort: correlation matrix between x and y (after sorting)
105 | # sort_idx: sorting index
106 | # x_sort: x after sorting
107 | # """
108 | #
109 | # print("Calculating correlation...")
110 | #
111 | # x = x.copy()
112 | # y = y.copy()
113 | # dim = x.shape[0]
114 | #
115 | # # Calculate correlation -----------------------------------
116 | # if method=='Pearson':
117 | # corr = np.corrcoef(y, x)
118 | # corr = corr[0:dim,dim:]
119 | # elif method=='Spearman':
120 | # corr, pvalue = sp.stats.spearmanr(y.T, x.T)
121 | # corr = corr[0:dim, dim:]
122 | #
123 | # # Sort ----------------------------------------------------
124 | # munk = Munkres()
125 | # indexes = munk.compute(-np.absolute(corr))
126 | #
127 | # sort_idx = np.zeros(dim)
128 | # x_sort = np.zeros(x.shape)
129 | # for i in range(dim):
130 | # sort_idx[i] = indexes[i][1]
131 | # x_sort[i,:] = x[indexes[i][1],:]
132 | #
133 | # # Re-calculate correlation --------------------------------
134 | # if method=='Pearson':
135 | # corr_sort = np.corrcoef(y, x_sort)
136 | # corr_sort = corr_sort[0:dim,dim:]
137 | # elif method=='Spearman':
138 | # corr_sort, pvalue = sp.stats.spearmanr(y.T, x_sort.T)
139 | # corr_sort = corr_sort[0:dim, dim:]
140 | #
141 | # return corr_sort, sort_idx, x_sort
142 |
--------------------------------------------------------------------------------
/baseline_models/slowvae_pcl/data_generation/coco/mask.py:
--------------------------------------------------------------------------------
1 | __author__ = 'tsungyi'
2 |
3 | import _mask as _mask
4 |
5 | # Interface for manipulating masks stored in RLE format.
6 | #
7 | # RLE is a simple yet efficient format for storing binary masks. RLE
8 | # first divides a vector (or vectorized image) into a series of piecewise
9 | # constant regions and then for each piece simply stores the length of
10 | # that piece. For example, given M=[0 0 1 1 1 0 1] the RLE counts would
11 | # be [2 3 1 1], or for M=[1 1 1 1 1 1 0] the counts would be [0 6 1]
12 | # (note that the odd counts are always the numbers of zeros). Instead of
13 | # storing the counts directly, additional compression is achieved with a
14 | # variable bitrate representation based on a common scheme called LEB128.
15 | #
16 | # Compression is greatest given large piecewise constant regions.
17 | # Specifically, the size of the RLE is proportional to the number of
18 | # *boundaries* in M (or for an image the number of boundaries in the y
19 | # direction). Assuming fairly simple shapes, the RLE representation is
20 | # O(sqrt(n)) where n is number of pixels in the object. Hence space usage
21 | # is substantially lower, especially for large simple objects (large n).
22 | #
23 | # Many common operations on masks can be computed directly using the RLE
24 | # (without need for decoding). This includes computations such as area,
25 | # union, intersection, etc. All of these operations are linear in the
26 | # size of the RLE, in other words they are O(sqrt(n)) where n is the area
27 | # of the object. Computing these operations on the original mask is O(n).
28 | # Thus, using the RLE can result in substantial computational savings.
29 | #
30 | # The following API functions are defined:
31 | # encode - Encode binary masks using RLE.
32 | # decode - Decode binary masks encoded via RLE.
33 | # merge - Compute union or intersection of encoded masks.
34 | # iou - Compute intersection over union between masks.
35 | # area - Compute area of encoded masks.
36 | # toBbox - Get bounding boxes surrounding encoded masks.
37 | # frPyObjects - Convert polygon, bbox, and uncompressed RLE to encoded RLE mask.
38 | #
39 | # Usage:
40 | # Rs = encode( masks )
41 | # masks = decode( Rs )
42 | # R = merge( Rs, intersect=false )
43 | # o = iou( dt, gt, iscrowd )
44 | # a = area( Rs )
45 | # bbs = toBbox( Rs )
46 | # Rs = frPyObjects( [pyObjects], h, w )
47 | #
48 | # In the API the following formats are used:
49 | # Rs - [dict] Run-length encoding of binary masks
50 | # R - dict Run-length encoding of binary mask
51 | # masks - [hxwxn] Binary mask(s) (must have type np.ndarray(dtype=uint8) in column-major order)
52 | # iscrowd - [nx1] list of np.ndarray. 1 indicates corresponding gt image has crowd region to ignore
53 | # bbs - [nx4] Bounding box(es) stored as [x y w h]
54 | # poly - Polygon stored as [[x1 y1 x2 y2...],[x1 y1 ...],...] (2D list)
55 | # dt,gt - May be either bounding boxes or encoded masks
56 | # Both poly and bbs are 0-indexed (bbox=[0 0 1 1] encloses first pixel).
57 | #
58 | # Finally, a note about the intersection over union (iou) computation.
59 | # The standard iou of a ground truth (gt) and detected (dt) object is
60 | # iou(gt,dt) = area(intersect(gt,dt)) / area(union(gt,dt))
61 | # For "crowd" regions, we use a modified criteria. If a gt object is
62 | # marked as "iscrowd", we allow a dt to match any subregion of the gt.
63 | # Choosing gt' in the crowd gt that best matches the dt can be done using
64 | # gt'=intersect(dt,gt). Since by definition union(gt',dt)=dt, computing
65 | # iou(gt,dt,iscrowd) = iou(gt',dt) = area(intersect(gt,dt)) / area(dt)
66 | # For crowd gt regions we use this modified criteria above for the iou.
67 | #
68 | # To compile run "python setup.py build_ext --inplace"
69 | # Please do not contact us for help with compiling.
70 | #
71 | # Microsoft COCO Toolbox. version 2.0
72 | # Data, paper, and tutorials available at: http://mscoco.org/
73 | # Code written by Piotr Dollar and Tsung-Yi Lin, 2015.
74 | # Licensed under the Simplified BSD License [see coco/license.txt]
75 |
# Entry points of the `_mask` extension module that are exposed unchanged;
# they are described in the API summary at the top of this file.
iou = _mask.iou
merge = _mask.merge
frPyObjects = _mask.frPyObjects
79 |
def encode(bimask):
    """Encode binary mask(s) using RLE.

    A 3-D array is passed to `_mask.encode` as is and the list of RLEs is
    returned. A 2-D array is reshaped to (h, w, 1) in column-major order and
    the single resulting RLE is returned. Any other rank yields None.
    """
    rank = len(bimask.shape)
    if rank == 2:
        h, w = bimask.shape
        stacked = bimask.reshape((h, w, 1), order='F')
        return _mask.encode(stacked)[0]
    if rank == 3:
        return _mask.encode(bimask)
    return None
86 |
def decode(rleObjs):
    """Decode binary mask(s) encoded via RLE.

    A list of RLEs (including list subclasses) is decoded as a batch and the
    stacked result of `_mask.decode` is returned; a single RLE is wrapped in
    a list and only its mask (`[:, :, 0]`) is returned.
    """
    # isinstance (rather than an exact type comparison) so that list
    # subclasses are treated as a batch instead of being wrapped once more.
    if isinstance(rleObjs, list):
        return _mask.decode(rleObjs)
    return _mask.decode([rleObjs])[:, :, 0]
92 |
def area(rleObjs):
    """Compute the area of encoded mask(s).

    A list of RLEs (including list subclasses) yields the result of
    `_mask.area` for the whole batch; a single RLE yields its own entry.
    """
    # isinstance (rather than an exact type comparison) so that list
    # subclasses are treated as a batch instead of being wrapped once more.
    if isinstance(rleObjs, list):
        return _mask.area(rleObjs)
    return _mask.area([rleObjs])[0]
98 |
def toBbox(rleObjs):
    """Get the bounding box(es) surrounding encoded mask(s).

    A list of RLEs (including list subclasses) yields the result of
    `_mask.toBbox` for the whole batch; a single RLE yields its own entry.
    """
    # isinstance (rather than an exact type comparison) so that list
    # subclasses are treated as a batch instead of being wrapped once more.
    if isinstance(rleObjs, list):
        return _mask.toBbox(rleObjs)
    return _mask.toBbox([rleObjs])[0]
--------------------------------------------------------------------------------
/baseline_models/icebeem/models/ebm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 |
5 | from .nets import CleanMLP
6 |
7 |
class UnnormalizedConditialEBM(nn.Module):
    """Unnormalized conditional energy built from two networks.

    `f` (a CleanMLP) embeds the observation x and `g` (a bias-free linear
    layer) embeds the condition y, both into `output_size` dimensions. The
    energy of a pair is the inner product of the two embeddings, plus the
    inner product of their element-wise squares when `augment` is set.
    With `positive`, both embeddings go through a ReLU first.
    """

    def __init__(self, input_size, hidden_size, n_hidden, output_size, condition_size, activation='lrelu',
                 augment=False, positive=False):
        super().__init__()

        # Hyper-parameters, kept as plain attributes
        self.input_size, self.output_size = input_size, output_size
        self.hidden_size, self.n_hidden = hidden_size, n_hidden
        self.cond_size = condition_size
        self.activation = activation
        self.augment, self.positive = augment, positive

        # Feature extractors (f is created before g)
        self.f = CleanMLP(input_size, hidden_size, n_hidden, output_size, activation=activation)
        self.g = nn.Linear(condition_size, output_size, bias=False)

    def forward(self, x, y):
        """Return one energy value per row of the batch."""
        features = self.f(x).view(-1, self.output_size)
        condition = self.g(y)

        if self.positive:
            features, condition = F.relu(features), F.relu(condition)

        energy = torch.einsum('bi,bi->b', [features, condition])
        if self.augment:
            # second-order term: inner product of the squared embeddings
            energy = energy + torch.einsum('bi,bi->b', [features.pow(2), condition.pow(2)])
        return energy
37 |
38 |
class ModularUnnormalizedConditionalEBM(nn.Module):
    """Unnormalized conditional energy built from user-supplied networks.

    `f_net` embeds the observation and `g_net` embeds the condition; both
    must expose `input_size` / `output_size` attributes and share the same
    `output_size`. The energy of a pair is the inner product of the two
    embeddings, plus the inner product of their element-wise squares when
    `augment` is set. With `positive`, both embeddings go through a ReLU.
    """

    def __init__(self, f_net, g_net, augment=False, positive=False):
        super().__init__()

        assert f_net.output_size == g_net.output_size

        # Sizes are read off the supplied networks
        self.input_size, self.output_size = f_net.input_size, f_net.output_size
        self.cond_size = g_net.input_size
        self.augment, self.positive = augment, positive

        self.f = f_net
        self.g = g_net

    def forward(self, x, y):
        """Return one energy value per row of the batch."""
        features = self.f(x).view(-1, self.output_size)
        condition = self.g(y)

        if self.positive:
            features, condition = F.relu(features), F.relu(condition)

        energy = torch.einsum('bi,bi->b', [features, condition])
        if self.augment:
            # second-order term: inner product of the squared embeddings
            energy = energy + torch.einsum('bi,bi->b', [features.pow(2), condition.pow(2)])
        return energy
66 |
67 |
class ConditionalEBM(UnnormalizedConditialEBM):
    """Conditional EBM with a learnable log-normalizer added to the energy.

    The `augment` / `positive` switches are chosen per call here; the
    `self.augment` / `self.positive` attributes inherited from the parent
    are not consulted by this `forward`.
    """

    def __init__(self, input_size, hidden_size, n_hidden, output_size, condition_size, activation='lrelu'):
        super().__init__(input_size, hidden_size, n_hidden, output_size, condition_size, activation)

        self.log_norm = nn.Parameter(torch.randn(1) - 5, requires_grad=True)

    def forward(self, x, y, augment=True, positive=False):
        """Return the energy of every (x, y) pair plus `log_norm`.

        Args:
            x: batch of observations fed to `self.f`
            y: batch of conditions fed to `self.g`
            augment: add the inner product of the squared embeddings
            positive: apply a ReLU to both embeddings first
        """
        # The parent's forward only accepts (x, y), so the per-call flags
        # cannot be forwarded to it; the energy is computed here instead.
        fx, gy = self.f(x).view(-1, self.output_size), self.g(y)

        if positive:
            fx, gy = F.relu(fx), F.relu(gy)

        energy = torch.einsum('bi,bi->b', [fx, gy])
        if augment:
            energy = energy + torch.einsum('bi,bi->b', [fx.pow(2), gy.pow(2)])

        return energy + self.log_norm
76 |
77 |
class ModularConditionalEBM(ModularUnnormalizedConditionalEBM):
    """Modular conditional EBM with a learnable log-normalizer.

    The `augment` / `positive` switches are chosen per call here; the
    `self.augment` / `self.positive` attributes inherited from the parent
    are not consulted by this `forward`.
    """

    def __init__(self, f_net, g_net):
        super().__init__(f_net, g_net)

        self.log_norm = nn.Parameter(torch.randn(1) - 5, requires_grad=True)

    def forward(self, x, y, augment=True, positive=False):
        """Return the energy of every (x, y) pair plus `log_norm`.

        Args:
            x: batch of observations fed to `self.f`
            y: batch of conditions fed to `self.g`
            augment: add the inner product of the squared embeddings
            positive: apply a ReLU to both embeddings first
        """
        # The parent's forward only accepts (x, y), so the per-call flags
        # cannot be forwarded to it; the energy is computed here instead.
        fx, gy = self.f(x).view(-1, self.output_size), self.g(y)

        if positive:
            fx, gy = F.relu(fx), F.relu(gy)

        energy = torch.einsum('bi,bi->b', [fx, gy])
        if augment:
            energy = energy + torch.einsum('bi,bi->b', [fx.pow(2), gy.pow(2)])

        return energy + self.log_norm
86 |
87 |
class UnnormalizedEBM(nn.Module):
    """Unnormalized (unconditional) energy: the sum of the features f(x).

    `f` is a CleanMLP mapping the input to `output_size` features; `g` is a
    constant vector of ones, so the energy is the plain sum of the features.
    """

    def __init__(self, input_size, hidden_size, n_hidden, output_size, activation='lrelu'):
        super().__init__()

        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.n_hidden = n_hidden
        self.activation = activation

        self.f = CleanMLP(input_size, hidden_size, n_hidden, output_size, activation=activation)
        # Registered as a buffer so that .to() / .cuda() / .double() move it
        # together with the parameters; non-persistent so that the set of
        # state_dict keys is the same as with a plain tensor attribute.
        self.register_buffer("g", torch.ones(output_size), persistent=False)

    def forward(self, x, y=None):
        """Return one energy value per row of `x`; `y` is accepted but unused."""
        fx = self.f(x).view(-1, self.output_size)
        return torch.einsum('bi,i->b', [fx, self.g])
104 |
105 |
class ModularUnnormalizedEBM(nn.Module):
    """Unnormalized (unconditional) energy on top of a user-supplied network.

    `f_net` must expose `input_size` / `output_size` attributes. `g` is a
    constant vector of ones, so the energy is the plain sum of the features
    produced by `f_net`.
    """

    def __init__(self, f_net):
        super().__init__()

        self.input_size = f_net.input_size
        self.output_size = f_net.output_size

        self.f = f_net
        # Registered as a buffer so that .to() / .cuda() / .double() move it
        # together with the parameters; non-persistent so that the set of
        # state_dict keys is the same as with a plain tensor attribute.
        self.register_buffer("g", torch.ones(self.output_size), persistent=False)

    def forward(self, x, y=None):
        """Return one energy value per row of `x`; `y` is accepted but unused."""
        fx = self.f(x).view(-1, self.output_size)
        return torch.einsum('bi,i->b', [fx, self.g])
119 |
120 |
class EBM(UnnormalizedEBM):
    """UnnormalizedEBM with a learnable scalar `log_norm` added to the energy."""

    def __init__(self, input_size, hidden_size, n_hidden, output_size, activation='lrelu'):
        super().__init__(input_size, hidden_size, n_hidden, output_size, activation)

        # Drawn from a standard normal shifted by -5
        initial_log_norm = torch.randn(1) - 5
        self.log_norm = nn.Parameter(initial_log_norm, requires_grad=True)

    def forward(self, x, y=None):
        """Return the parent's energy shifted by `log_norm`."""
        unnormalized = super().forward(x, y)
        return unnormalized + self.log_norm
129 |
130 |
class ModularEBM(ModularUnnormalizedEBM):
    """ModularUnnormalizedEBM with a learnable scalar `log_norm` added to the energy."""

    def __init__(self, f_net):
        super().__init__(f_net)

        # Drawn from a standard normal shifted by -5
        initial_log_norm = torch.randn(1) - 5
        self.log_norm = nn.Parameter(initial_log_norm, requires_grad=True)

    def forward(self, x, y=None):
        """Return the parent's energy shifted by `log_norm`."""
        unnormalized = super().forward(x, y)
        return unnormalized + self.log_norm
139 |
--------------------------------------------------------------------------------
/baseline_models/icebeem/models/tcl/tcl_wrapper_gpu.py:
--------------------------------------------------------------------------------
1 | ### Wrapper function for TCL
2 | #
3 | # this code is adapted from: https://github.com/hirosm/TCL
4 | #
5 | #
6 | import os
7 |
8 | import numpy as np
9 | import tensorflow as tf
10 | from sklearn.decomposition import FastICA
11 |
12 | from .tcl_core import inference
13 | from .tcl_core import train_gpu as train
14 | from .tcl_eval import get_tensor, calc_accuracy
15 | from .tcl_preprocessing import pca
16 |
17 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
18 |
19 |
def TCL_wrapper(sensor, label, list_hidden_nodes, random_seed=0, max_steps=int(7e4), max_steps_init=int(7e4),
                ckpt_dir='./', test=False):
    """Train a TCL model (unless `test`), then evaluate the saved checkpoint.

    Pipeline: PCA on `sensor` -> train with MLP_trainable=False, saved as
    'model_init.ckpt' -> train again starting from that file -> restore the
    moving-averaged variables from the checkpoint in `ckpt_dir` -> compute
    predictions and features -> accuracy and FastICA of the features.

    Args:
        sensor: observed data; presumably [Ndim, Ndata], as expected by
                get_tensor (TODO confirm against pca / train)
        label: label of every sample; the number of classes is the number
               of unique values in it
        list_hidden_nodes: forwarded to train() and inference()
        random_seed: forwarded to train() and used as FastICA random_state
        max_steps: max_steps of the second training run
        max_steps_init: max_steps of the first (MLP_trainable=False) run
        ckpt_dir: directory the checkpoints are written to and read from
        test: if True, both training runs are skipped.
              NOTE(review): a checkpoint must then already exist in
              ckpt_dir, since ckpt.model_checkpoint_path is read below.
    Returns:
        feat_val: features returned by get_tensor, transposed
        feat_val_ica: FastICA transform of the features, transposed
        accuracy: accuracy returned by calc_accuracy
    """
    # Training ----------------------------------------------------
    initial_learning_rate = 0.01  # initial learning rate
    momentum = 0.9  # momentum parameter of SGD
    # max_steps = int(7e4) # number of iterations (mini-batches)
    decay_steps = int(5e4)  # decay steps (tf.train.exponential_decay)
    decay_factor = 0.1  # decay factor (tf.train.exponential_decay)
    batch_size = 512  # mini-batch size
    moving_average_decay = 0.9999  # moving average decay of variables to be saved
    checkpoint_steps = 1e5  # interval to save checkpoint
    num_comp = sensor.shape[0]  # keep as many PCA components as input rows

    # for MLR initialization
    decay_steps_init = int(5e4)  # decay steps for initializing only MLR

    # Other -------------------------------------------------------
    train_dir = ckpt_dir  # save directory

    num_segment = len(np.unique(label))

    # Preprocessing -----------------------------------------------
    sensor, pca_parm = pca(sensor, num_comp=num_comp)

    if not test:
        # Train model (only MLR) --------------------------------------
        train(sensor,
              label,
              num_class=len(np.unique(label)),  # num_segment,
              list_hidden_nodes=list_hidden_nodes,
              initial_learning_rate=initial_learning_rate,
              momentum=momentum,
              max_steps=max_steps_init,  # For init
              decay_steps=decay_steps_init,  # For init
              decay_factor=decay_factor,
              batch_size=batch_size,
              train_dir=train_dir,
              checkpoint_steps=checkpoint_steps,
              moving_average_decay=moving_average_decay,
              MLP_trainable=False,  # For init
              save_file='model_init.ckpt',  # For init
              random_seed=random_seed)

        # File written by the first run; the second run starts from it
        init_model_path = os.path.join(train_dir, 'model_init.ckpt')

        # Train model -------------------------------------------------
        train(sensor,
              label,
              num_class=len(np.unique(label)),  # num_segment,
              list_hidden_nodes=list_hidden_nodes,
              initial_learning_rate=initial_learning_rate,
              momentum=momentum,
              max_steps=max_steps,
              decay_steps=decay_steps,
              decay_factor=decay_factor,
              batch_size=batch_size,
              train_dir=train_dir,
              checkpoint_steps=checkpoint_steps,
              moving_average_decay=moving_average_decay,
              load_file=init_model_path,
              random_seed=random_seed)

    # now that we have trained everything, we can evaluate results:
    eval_dir = ckpt_dir
    ckpt = tf.train.get_checkpoint_state(eval_dir)

    # Evaluation uses a fresh graph, separate from the default one
    with tf.Graph().as_default():
        data_holder = tf.placeholder(tf.float32, shape=[None, sensor.shape[0]], name='data')

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits, feats = inference(data_holder, list_hidden_nodes, num_class=num_segment)

        # Calculate predictions.
        top_value, preds = tf.nn.top_k(logits, k=1, name='preds')

        # Restore the moving averaged version of the learned variables for eval.
        variable_averages = tf.train.ExponentialMovingAverage(moving_average_decay)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        with tf.Session() as sess:
            saver.restore(sess, ckpt.model_checkpoint_path)

            # tensor_val[0]: predictions, tensor_val[1]: features
            tensor_val = get_tensor(sensor, [preds, feats], sess, data_holder, batch=256)
            pred_val = tensor_val[0].reshape(-1)
            feat_val = tensor_val[1]

    # Calculate accuracy ------------------------------------------
    accuracy, confmat = calc_accuracy(pred_val, label)

    # Apply fastICA -----------------------------------------------
    ica = FastICA(random_state=random_seed)
    feat_val_ica = ica.fit_transform(feat_val)

    # Both feature arrays are returned transposed
    feat_val_ica = feat_val_ica.T  # Estimated feature
    feat_val = feat_val.T

    return feat_val, feat_val_ica, accuracy
119 |
--------------------------------------------------------------------------------
/model/gumbel_masks.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 |
4 |
class GumbelSigmoid(torch.nn.Module):
    """Learnable binary mask sampled through a relaxed (Gumbel) sigmoid.

    Every entry of the mask has a logit `log_alpha`. A sample is
    sigmoid((log_alpha + logistic noise) / tau), optionally rounded to
    {0, 1} in the forward pass while keeping the gradient of the soft
    sample. Once frozen, the stored `fixed_mask` is returned instead.

    Args:
        shape: shape of one mask (any sequence of ints: tuple, list, torch.Size)
        freeze: if True, `forward` returns `fixed_mask` instead of sampling
        drawhard: if True, samples are thresholded at 0.5 in the forward pass
        tau: temperature dividing the noisy logits
        one_gumbel_sample: if True, a single noise draw is shared by the whole batch
    """

    def __init__(self, shape, freeze=False, drawhard=True, tau=1, one_gumbel_sample=False):
        super(GumbelSigmoid, self).__init__()
        # Stored as a tuple so that `(bs,) + self.shape` also works when the
        # caller passes a list or a torch.Size.
        self.shape = tuple(shape)
        self.freeze = freeze
        self.drawhard = drawhard
        self.log_alpha = torch.nn.Parameter(torch.zeros(self.shape))
        self.tau = tau
        self.one_sample_per_batch = one_gumbel_sample
        self.uniform = torch.distributions.uniform.Uniform(0, 1)
        # a buffer, so that the mask follows the module across devices
        self.register_buffer("fixed_mask", torch.ones(self.shape))
        self.reset_parameters()

    def forward(self, bs):
        """Return `bs` masks, i.e. a tensor of shape (bs,) + shape."""
        batch_shape = (bs,) + self.shape

        if self.freeze:
            return self.fixed_mask.unsqueeze(0).expand(batch_shape)

        # The noise is drawn by `self.uniform`, then cast/moved to match log_alpha
        logistic_noise = self.sample_logistic(batch_shape).to(device=self.log_alpha.device,
                                                              dtype=self.log_alpha.dtype)
        y_soft = torch.sigmoid((self.log_alpha + logistic_noise) / self.tau)

        if not self.drawhard:
            return y_soft

        y_hard = (y_soft > 0.5).to(y_soft.dtype)

        # This weird line does two things:
        #   1) at forward, we get a hard sample.
        #   2) at backward, we differentiate the gumbel sigmoid
        return y_hard.detach() - y_soft.detach() + y_soft

    def get_proba(self):
        """Returns probability of getting one"""
        if self.freeze:
            return self.fixed_mask
        return torch.sigmoid(self.log_alpha)

    def reset_parameters(self):
        torch.nn.init.constant_(self.log_alpha, 5)  # will yield a probability ~0.99. Inspired by DCDI

    def sample_logistic(self, shape):
        """Return logistic noise log(u) - log(1 - u), u ~ U(0, 1), of the given shape."""
        if self.one_sample_per_batch:
            # One draw of mask shape, repeated (as a view) along the batch axis
            bs = shape[0]
            u = self.uniform.sample((1,) + self.shape).expand((bs,) + self.shape)
        else:
            u = self.uniform.sample(shape)
        return torch.log(u) - torch.log(1 - u)

    def threshold(self):
        """Store the mask thresholded at probability 0.5 and freeze the module."""
        proba = self.get_proba()
        self.fixed_mask.copy_((proba > 0.5).type(proba.type()))
        self.freeze = True
64 |
65 |
class LouizosGumbelSigmoid(torch.nn.Module):
    """My implementation of https://openreview.net/pdf?id=H1Y8hhg0b

    A relaxed sigmoid sample is stretched to the interval (gamma, zeta) and
    then clipped to [0, 1], so that exact zeros and ones have non-zero
    probability. Once frozen, the stored `fixed_mask` is returned instead.

    Args:
        shape: shape of one mask (any sequence of ints: tuple, list, torch.Size)
        freeze: if True, `forward` returns `fixed_mask` instead of sampling
        tau: temperature dividing the noisy logits
        gamma: lower end of the stretched interval (must be < 0)
        zeta: upper end of the stretched interval (must be > 1)
    """

    def __init__(self, shape, freeze=False, tau=1, gamma=-0.1, zeta=1.1):
        super(LouizosGumbelSigmoid, self).__init__()
        # Stored as a tuple so that `(bs,) + self.shape` also works when the
        # caller passes a list or a torch.Size.
        self.shape = tuple(shape)
        self.freeze = freeze
        self.log_alpha = torch.nn.Parameter(torch.zeros(self.shape))
        self.tau = tau
        assert gamma < 0 and zeta > 1
        self.gamma = gamma
        self.zeta = zeta
        self.uniform = torch.distributions.uniform.Uniform(0, 1)
        # a buffer, so that the mask follows the module across devices
        self.register_buffer("fixed_mask", torch.ones(self.shape))
        self.reset_parameters()

    def forward(self, bs):
        """Return `bs` masks, i.e. a tensor of shape (bs,) + shape with values in [0, 1]."""
        batch_shape = (bs,) + self.shape

        if self.freeze:
            return self.fixed_mask.unsqueeze(0).expand(batch_shape)

        device = self.log_alpha.device
        logistic_noise = self.sample_logistic(batch_shape).to(device=device, dtype=self.log_alpha.dtype)
        y_soft = torch.sigmoid((self.log_alpha + logistic_noise) / self.tau)
        # stretch to (gamma, zeta) ...
        y_soft = y_soft * (self.zeta - self.gamma) + self.gamma
        # ... then clip to [0, 1]
        one = torch.ones((1,)).to(device)
        zero = torch.zeros((1,)).to(device)
        return torch.minimum(one, torch.maximum(zero, y_soft))

    def get_proba(self):
        """Returns probability of mask being > 0"""
        if self.freeze:
            return self.fixed_mask
        return torch.sigmoid(self.log_alpha - self.tau * (math.log(-self.gamma) - math.log(self.zeta)))

    def reset_parameters(self):
        # at init, half the samples will be exactly one.
        # general formula is p(M = 1) = sigmoid(log_alpha - tau * (log(1 - gamma) - log(zeta - 1)))
        init_value = self.tau * (math.log(1 - self.gamma) - math.log(self.zeta - 1))
        torch.nn.init.constant_(self.log_alpha, init_value)
        print(f"initialized so that P(mask != 0) = {self.get_proba().view(-1)[0]}")

    def sample_logistic(self, shape):
        """Return logistic noise log(u) - log(1 - u), u ~ U(0, 1), of the given shape."""
        u = self.uniform.sample(shape)
        return torch.log(u) - torch.log(1 - u)

    def threshold(self):
        """Store the mask thresholded at probability 0.5 and freeze the module."""
        proba = self.get_proba()
        self.fixed_mask.copy_((proba > 0.5).type(proba.type()))
        self.freeze = True
118 |
--------------------------------------------------------------------------------
/baseline_models/slowvae_pcl/scripts/model.py:
--------------------------------------------------------------------------------
1 | """model.py"""
2 | import sys
3 | import os
4 | import pathlib
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.init as init
9 | from torch.autograd import Variable
10 | import json
11 | import numpy as np
12 |
13 | sys.path.insert(0, str(pathlib.Path(__file__).parent.parent.parent.parent))
14 | from model.nn import MLP as MLP_ilcm
15 |
16 |
def reparametrize(mu, logvar):
    """Draw mu + std * eps with eps ~ N(0, 1), where std = exp(logvar / 2) + 1e-6.

    The noise has the same size as `logvar` and carries no gradient, so
    gradients reach `mu` and `logvar` only through the returned expression.
    """
    sigma = logvar.div(2).exp() + 1e-6
    noise = sigma.data.new(sigma.size()).normal_()
    return mu + sigma * noise
21 |
def compute_kl(z_1, z_2, logvar_1, logvar_2):
    """Element-wise Gaussian KL-style term between (z_1, logvar_1) and (z_2, logvar_2).

    Returns var_1 / var_2 + (z_2 - z_1)^2 / var_2 - 1 + logvar_2 - logvar_1,
    with var_i = exp(logvar_i) + 1e-6. No factor 1/2 and no reduction over
    any dimension is applied here.
    """
    var_1 = logvar_1.exp() + 1e-6
    var_2 = logvar_2.exp() + 1e-6
    variance_ratio = var_1 / var_2
    scaled_sq_diff = ((z_2 - z_1) ** 2) / var_2
    return variance_ratio + scaled_sq_diff - 1 + logvar_2 - logvar_1
26 |
27 |
class View(nn.Module):
    """Module wrapper around `tensor.view(size)`, usable inside nn.Sequential."""

    def __init__(self, size):
        super().__init__()
        # target shape handed to `view` on every call
        self.size = size

    def forward(self, tensor):
        target_shape = self.size
        return tensor.view(target_shape)
35 |
36 |
class BetaVAE_H(nn.Module):
    """Model proposed in original beta-VAE paper(Higgins et al, ICLR, 2017).

    Two architectures are available:
      * "standard_conv": convolutional encoder/decoder for (nc, 64, 64) inputs,
        re-initialised with `kaiming_init`;
      * "ilcm_tabular": MLP encoder/decoder for 1-D inputs, plus a learnable
        `x_logsigma` parameter.
    With `pcl`, the encoder outputs `z_dim` values and `forward` only
    encodes; otherwise it outputs `2 * z_dim` values (mean and log-variance).
    """

    def __init__(self, z_dim=10, nc=3, pcl=False, architecture="standard_conv", image_shape=None):
        super().__init__()
        self.pcl = pcl
        self.z_dim = z_dim
        self.nc = nc
        self.architecture = architecture
        self.image_shape = image_shape

        # width of the encoder output: z only (pcl) or mean + log-variance
        code_dim = z_dim if pcl else z_dim * 2

        if self.architecture == "standard_conv":
            assert self.image_shape == (nc, 64, 64)
            encoder_layers = [
                nn.Conv2d(nc, 32, 4, 2, 1),  # B,  32, 32, 32
                nn.ReLU(True),
                nn.Conv2d(32, 32, 4, 2, 1),  # B,  32, 16, 16
                nn.ReLU(True),
                nn.Conv2d(32, 64, 4, 2, 1),  # B,  64,  8,  8
                nn.ReLU(True),
                nn.Conv2d(64, 64, 4, 2, 1),  # B,  64,  4,  4
                nn.ReLU(True),
                nn.Conv2d(64, 256, 4, 1),  # B, 256,  1,  1
                nn.ReLU(True),
                View((-1, 256*1*1)),  # B, 256
                nn.Linear(256, code_dim),  # B, z_dim*2
            ]
            self.encoder = nn.Sequential(*encoder_layers)

            decoder_layers = [
                nn.Linear(z_dim, 256),  # B, 256
                View((-1, 256, 1, 1)),  # B, 256,  1,  1
                nn.ReLU(True),
                nn.ConvTranspose2d(256, 64, 4),  # B,  64,  4,  4
                nn.ReLU(True),
                nn.ConvTranspose2d(64, 64, 4, 2, 1),  # B,  64,  8,  8
                nn.ReLU(True),
                nn.ConvTranspose2d(64, 32, 4, 2, 1),  # B,  32, 16, 16
                nn.ReLU(True),
                nn.ConvTranspose2d(32, 32, 4, 2, 1),  # B,  32, 32, 32
                nn.ReLU(True),
                nn.ConvTranspose2d(32, nc, 4, 2, 1),  # B, nc, 64, 64
            ]
            self.decoder = nn.Sequential(*decoder_layers)

            self.weight_init()
        elif self.architecture == "ilcm_tabular":
            assert len(self.image_shape) == 1
            x_dim = image_shape[0]
            self.encoder = MLP_ilcm(x_dim, code_dim, 512, 6, spectral_norm=False, batch_norm=False)
            self.decoder = MLP_ilcm(z_dim, x_dim, 512, 6, spectral_norm=False, batch_norm=False)
            self.x_logsigma = torch.nn.Parameter(-5 * torch.ones((1,)))

    def weight_init(self):
        """Apply `kaiming_init` to every layer of every registered sub-module."""
        for container in self._modules.values():
            for layer in container:
                kaiming_init(layer)

    def forward(self, x, return_z=False):
        """Encode `x` and, unless `pcl`, sample a latent and decode it.

        Returns:
            pcl: (None, encoder output, None)
            otherwise: (x_recon, mu, logvar), plus z when `return_z`. For 1-D
            `image_shape`, x_recon is the pair
            (decoder output, softplus(x_logsigma) + 1e-6).
        """
        distributions = self._encode(x)
        if self.pcl:
            return None, distributions, None

        mu, logvar = distributions[:, :self.z_dim], distributions[:, self.z_dim:]
        z = reparametrize(mu, logvar)
        x_recon = self._decode(z)

        if len(self.image_shape) == 1:
            x_recon = (x_recon, torch.nn.functional.softplus(self.x_logsigma) + 1e-6)

        return (x_recon, mu, logvar, z) if return_z else (x_recon, mu, logvar)

    def _encode(self, x):
        return self.encoder(x)

    def _decode(self, z):
        return self.decoder(z)
113 |
def kaiming_init(m):
    """Re-initialise a single layer in place.

    Linear / Conv2d: Kaiming-normal weights and zero bias.
    BatchNorm1d / BatchNorm2d: weight 1 and bias 0.
    Parameters that are absent (bias=False, affine=False) are skipped and
    every other module type is left untouched.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        # in-place initialiser; the un-suffixed `kaiming_normal` is deprecated
        init.kaiming_normal_(m.weight)
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        # weight and bias are both None for affine=False
        if m.weight is not None:
            m.weight.data.fill_(1)
        if m.bias is not None:
            m.bias.data.fill_(0)
123 |
124 |
def normal_init(m, mean, std):
    """Re-initialise a single layer in place.

    Linear / Conv2d: weights drawn from N(mean, std) and zero bias.
    BatchNorm1d / BatchNorm2d: weight 1 and bias 0.
    Parameters that are absent (bias=False, affine=False) are skipped and
    every other module type is left untouched.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        m.weight.data.normal_(mean, std)
        # Test the parameter itself: `m.bias.data` raises when bias is None
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
        # weight and bias are both None for affine=False
        if m.weight is not None:
            m.weight.data.fill_(1)
        if m.bias is not None:
            m.bias.data.zero_()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Disentanglement via Mechanism Sparsity
2 |
3 | This repository contains the code used to run the experiments in the papers:
4 | [1] [Disentanglement via Mechanism Sparsity Regularization: A New Principle for Nonlinear ICA](https://arxiv.org/abs/2107.10098) (CLeaR2022)
5 | [2] [Nonparametric Partial Disentanglement via Mechanism Sparsity: Sparse Actions, Interventions and Sparse Temporal Dependencies](https://arxiv.org/abs/2401.04890) (Preprint)
6 | By Sébastien Lachapelle, Pau Rodríguez López, Yash Sharma, Katie Everett, Rémi Le Priol, Alexandre Lacoste, Simon Lacoste-Julien
7 |
8 | ### Environment:
9 |
10 | Tested on python 3.7.
11 |
12 | See `requirements.txt`.
13 |
14 | ### Action-sparsity experiment
15 | ```
16 | OUTPUT_DIR=
17 | DATAROOT=
18 | DATASET=
19 | python disentanglement_via_mechanism_sparsity/train.py --dataroot $DATAROOT --output_dir $OUTPUT_DIR --mode vae --dataset toy-nn/action_sparsity_non_trivial --freeze_g --freeze_gc --z_dim 10 --gt_z_dim 10 --gt_x_dim 20 --n_lag 0 --time_limit 3
20 | ```
21 |
22 | ### Time-sparsity experiment
23 | ```
24 | OUTPUT_DIR=
25 | DATAROOT=
26 | python disentanglement_via_mechanism_sparsity/train.py --dataroot $DATAROOT --output_dir $OUTPUT_DIR --mode vae --dataset toy-nn/temporal_sparsity_non_trivial --freeze_g --freeze_gc --z_dim 10 --gt_z_dim 10 --gt_x_dim 20 --n_lag 1 --full_seq --time_limit 3
27 | ```
28 |
29 | ### Adding penalty regularization
30 | In the minimal commands provided above, all regularizations are deactivated (via the `--freeze_g` and `--freeze_gc` flags).
To activate the penalty regularization for, say, the temporal mask G^z, replace `--freeze_g` with `--g_reg_coeff COEFF_VALUE`.
The same syntax also works for the action mask G^a (named `gc` in the code). Here's the correspondence between the mask names in the code (left) and in the paper (right):
33 |
34 | `g` = Mask G^z (Time sparsity)
35 |
36 | `gc` = Mask G^a (Action sparsity)
37 |
38 | ### Adding constraint regularization
To activate the constraint regularization for, say, the temporal mask G^z, replace `--freeze_g` with `--g_constraint UPPER_BOUND`.
The same syntax also works for the action mask G^a (named `gc` in the code).
41 | The experiments were all performed with `--constraint_scedule 150000` and `--dual_restarts`.
42 | The option `--set_constraint_to_gt` will automatically set the upper bound of the constraint to the optimal value for the ground-truth graph.
43 |
### Synthetic datasets (section numbers refer to [2])
45 | Here's a list of the synthetic datasets used. The data is generated before training, so no need to download anything.
46 |
47 | #### Section 8.1 (Used in both [1] and [2])
48 |
49 | - `--dataset toy-nn/action_sparsity_trivial`
50 | - `--dataset toy-nn/action_sparsity_non_trivial`
51 | - `--dataset toy-nn/action_sparsity_non_trivial_no_suff_var`
52 | - `--dataset toy-nn/action_sparsity_non_trivial_k=2`
53 | - `--dataset toy-nn/temporal_sparsity_trivial`
54 | - `--dataset toy-nn/temporal_sparsity_non_trivial`
55 | - `--dataset toy-nn/temporal_sparsity_non_trivial_no_suff_var`
56 | - `--dataset toy-nn/temporal_sparsity_non_trivial_k=2`
57 |
58 | #### Section 8.2 (Used only in [2])
59 | - `--dataset toy-nn/action_sparsity_non_trivial --graph_name graph_action_3_easy`
60 | - `--dataset toy-nn/action_sparsity_non_trivial --graph_name graph_action_3_hard`
61 | - `--dataset toy-nn/temporal_sparsity_non_trivial --graph_name graph_temporal_3_easy`
62 | - `--dataset toy-nn/temporal_sparsity_non_trivial --graph_name graph_temporal_3_hard`
63 | - `--dataset toy-nn/action_sparsity_non_trivial --rand_g_density PROBA_OF_EDGE`
64 | - `--dataset toy-nn/temporal_sparsity_non_trivial --rand_g_density PROBA_OF_EDGE`
65 |
66 |
67 | #### Datasets Used only in [1]
68 | - `--dataset toy-nn/action_sparsity_non_trivial_no_graph_crit`
69 | - `--dataset toy-nn/temporal_sparsity_non_trivial_no_graph_crit`
70 |
71 | ### Baselines
72 | #### TCVAE
73 | Code adapted from: https://github.com/rtqichen/beta-tcvae
74 | ```
75 | OUTPUT_DIR=
76 | DATAROOT=
77 | python disentanglement_via_mechanism_sparsity/baseline_models/beta-tcvae/train.py --dataroot $DATAROOT --output_dir $OUTPUT_DIR --dataset toy-nn/action_sparsity_non_trivial --tcvae --beta 1 --gt_z_dim 10 --gt_x_dim 20 --time_limit 3
78 | ```
79 |
80 | #### iVAE
81 | Code adapted from: https://github.com/ilkhem/icebeem
82 | ```
83 | OUTPUT_DIR=
84 | DATAROOT=
85 | python disentanglement_via_mechanism_sparsity/baseline_models/icebeem/train.py --dataroot $DATAROOT --output_dir $OUTPUT_DIR --dataset toy-nn/action_sparsity_non_trivial --method ivae --gt_z_dim 10 --gt_x_dim 20 --time_limit 3
86 | ```
87 |
88 | #### SlowVAE
89 | Code adapted from: https://github.com/bethgelab/slow_disentanglement
90 | ```
91 | OUTPUT_DIR=
92 | DATAROOT=
93 | python disentanglement_via_mechanism_sparsity/baseline_models/slowvae_pcl/train.py --dataroot $DATAROOT --output_dir $OUTPUT_DIR --dataset toy-nn/temporal_sparsity_non_trivial --gt_z_dim 10 --gt_x_dim 20 --time_limit 3
94 | ```
95 |
96 | #### PCL
97 | Code adapted from: https://github.com/bethgelab/slow_disentanglement/tree/baselines
98 | ```
99 | OUTPUT_DIR=
100 | DATAROOT=
101 | python disentanglement_via_mechanism_sparsity/baseline_models/slowvae_pcl/train.py --dataroot $DATAROOT --output_dir $OUTPUT_DIR --dataset toy-nn/temporal_sparsity_non_trivial --pcl --r_func mlp --gt_z_dim 10 --gt_x_dim 20 --time_limit 3
102 | ```
103 |
104 |
105 |
--------------------------------------------------------------------------------
/optimization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import cooper
3 |
4 |
def compute_nll(log_likelihood, valid, opt):
    """Return the negative mean log-likelihood as a scalar tensor.

    With `opt.include_invalid`, the mean runs over every entry of
    `log_likelihood`. Otherwise each entry is weighted by `valid` (cast to
    the type of `log_likelihood`) and the sum is divided by `valid.sum()`.
    """
    weights = valid.type(log_likelihood.type())
    if not opt.include_invalid:
        return - torch.dot(log_likelihood, weights) / torch.sum(weights)
    return - torch.mean(log_likelihood)
13 |
14 |
class CustomCMP(cooper.ConstrainedMinimizationProblem):
    """Cooper problem combining the model loss with the ``g`` / ``gc`` regularizers.

    The problem is *constrained* as soon as ``g_constraint`` or ``gc_constraint``
    is strictly positive; in that case the regularizers are reported as
    inequality defects (``regularizer - current bound``). Otherwise they are
    added to the loss as weighted penalties.

    Args:
        g_reg_coeff, gc_reg_coeff: penalty switches for the unconstrained mode
            (a penalty is added only when the corresponding value is > 0).
        g_constraint, gc_constraint: target upper bounds for the constrained
            mode; a value <= 0 disables the corresponding constraint.
        g_scaling, gc_scaling: multipliers applied to the penalties.
        schedule: when true, the bounds start at ``max_g`` / ``max_gc`` and
            are expected to be tightened by ``update_constraint*``.
        max_g, max_gc: initial (loosest) bounds used by the schedules.
    """

    def __init__(self, g_reg_coeff=0., gc_reg_coeff=0., g_constraint=0, gc_constraint=0., g_scaling=1., gc_scaling=1.,
                 schedule=False, max_g=0, max_gc=0):
        self.is_constrained = (g_constraint > 0. or gc_constraint > 0.)
        self.g_reg_coeff = g_reg_coeff
        self.gc_reg_coeff = gc_reg_coeff
        self.g_constraint = g_constraint
        self.gc_constraint = gc_constraint

        # With a schedule the bound starts loose and is tightened later on.
        if schedule:
            self.current_g_constraint = max_g
            self.current_gc_constraint = max_gc
        else:
            self.current_g_constraint = g_constraint
            self.current_gc_constraint = gc_constraint

        self.max_g, self.max_gc = max_g, max_gc

        self.g_scaling = g_scaling
        self.gc_scaling = gc_scaling
        super().__init__(is_constrained=self.is_constrained)

    def closure(self, model, obs, cont_c, disc_c, valid, other, opt):
        """Compute the loss (and constraint defects) for one batch.

        ``opt.mode`` selects the objective: ``"vae"`` (negative ELBO),
        ``"supervised_vae"`` (MSE between the encoded mean of the last frame
        and ``other``) or ``"latent_transition_only"`` (negative log-likelihood
        of ``other`` under the latent model).

        Returns:
            A ``cooper.CMPState`` whose ``misc`` dict holds ``nll``,
            ``reconstruction_loss``, ``kl``, ``g_reg`` and ``gc_reg``.

        Raises:
            NotImplementedError: for any other ``opt.mode``.
        """
        misc = {}
        if opt.mode == "vae":
            elbo, reconstruction_loss, kl, _ = model.elbo(obs, cont_c, disc_c)
            loss = compute_nll(elbo, valid, opt)
        elif opt.mode == "supervised_vae":
            # Only the last time step is used for the supervised objective.
            obs = obs[:, -1]
            other = other[:, -1]
            z_hat = model.latent_model.mean(model.latent_model.transform_q_params(model.encode(obs)))
            loss = torch.mean((z_hat.view(z_hat.shape[0], -1) - other) ** 2)
            reconstruction_loss, kl = 0, 0
        elif opt.mode == "latent_transition_only":
            ll = model.log_likelihood(other, cont_c, disc_c)
            loss = compute_nll(ll, valid, opt)
            reconstruction_loss, kl = 0, 0
        else:
            raise NotImplementedError(f"--mode {opt.mode} is not implemented.")

        misc["nll"] = loss.item()
        misc["reconstruction_loss"] = reconstruction_loss
        misc["kl"] = kl

        # regularization/constraint
        g_reg = model.latent_model.g_regularizer()
        gc_reg = model.latent_model.gc_regularizer()

        misc["g_reg"], misc["gc_reg"] = g_reg.item(), gc_reg.item()

        if not self.is_constrained:
            # NOTE(review): the switch reads self.*_reg_coeff while the weight
            # reads opt.*_reg_coeff; presumably both hold the same value.
            if self.g_reg_coeff > 0:
                loss += opt.g_reg_coeff * g_reg * self.g_scaling
            if self.gc_reg_coeff > 0:
                loss += opt.gc_reg_coeff * gc_reg * self.gc_scaling

            return cooper.CMPState(loss=loss, ineq_defect=None, eq_defect=None, misc=misc)
        else:
            # One defect per active constraint: regularizer minus its current bound.
            defects = []
            if self.g_constraint > 0:
                defects.append(g_reg - self.current_g_constraint)
            if self.gc_constraint > 0:
                defects.append(gc_reg - self.current_gc_constraint)

            defects = torch.stack(defects)

            return cooper.CMPState(loss=loss, ineq_defect=defects, eq_defect=None, misc=misc)

    def update_constraint(self, iter, total_iter, no_update_period=0):
        """Linearly tighten the bounds from ``max_*`` down to ``*_constraint``.

        The bounds stay at ``max_g`` / ``max_gc`` for the first
        ``no_update_period`` iterations, then decay linearly over the next
        ``total_iter`` iterations, and finally stay at the target values.
        Only constraints with a strictly positive target are touched.

        Returns:
            ``(current_g_constraint, current_gc_constraint)``.
        """
        if iter <= no_update_period:
            g_bound, gc_bound = self.max_g, self.max_gc
        elif iter <= no_update_period + total_iter:
            # Progress is measured from the END of the warm-up. Using the raw
            # iteration count made the bound jump at the start of the decay
            # window and undershoot the target whenever no_update_period > 0.
            remaining = 1 - (iter - no_update_period) / total_iter
            g_bound = (self.max_g - self.g_constraint) * remaining + self.g_constraint
            gc_bound = (self.max_gc - self.gc_constraint) * remaining + self.gc_constraint
        else:
            g_bound, gc_bound = self.g_constraint, self.gc_constraint

        if self.g_constraint > 0:
            self.current_g_constraint = g_bound
        if self.gc_constraint > 0:
            self.current_gc_constraint = gc_bound

        return self.current_g_constraint, self.current_gc_constraint

    def update_constraint_adaptive(self, iter, decrease_rate=0.0005, no_update_period=0):
        """Tighten the bounds by ``decrease_rate`` while the defect is small.

        During the first ``no_update_period`` iterations the bounds are held at
        ``max_g`` / ``max_gc``. Afterwards each active bound is lowered by
        ``decrease_rate`` (never below its target) whenever the summed
        inequality defect stored in ``self.state`` is at most 0.1.

        Returns:
            ``(current_g_constraint, current_gc_constraint)``.
        """
        if iter <= no_update_period:
            if self.g_constraint > 0:
                self.current_g_constraint = self.max_g
            if self.gc_constraint > 0:
                self.current_gc_constraint = self.max_gc
        else:
            # decrease constraint only when defect is smaller than 0.1, otherwise do not change constraint.
            # NOTE(review): self.state is not assigned in this class; presumably
            # cooper stores the last CMPState there - confirm. The same summed
            # defect gates both bounds.
            if self.g_constraint > 0 and self.state.ineq_defect.sum() <= 0.1:
                self.current_g_constraint = max(self.current_g_constraint - decrease_rate, self.g_constraint)
            if self.gc_constraint > 0 and self.state.ineq_defect.sum() <= 0.1:
                self.current_gc_constraint = max(self.current_gc_constraint - decrease_rate, self.gc_constraint)

        return self.current_g_constraint, self.current_gc_constraint

        # An earlier variant lowered each bound by 1 whenever the summed defect was <= 0.
121 |
122 |
--------------------------------------------------------------------------------
/baseline_models/slowvae_pcl/data_generation/gen_kitti_masks.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import numpy as np
3 | import os
4 | import pickle
5 | import scipy.ndimage as ndi
6 | import torch
7 | import PIL.Image as Image
8 | from scipy.ndimage.measurements import center_of_mass
9 |
10 | # kitti mots
def get_data(path='./data/kitti/instances/',
             folder=range(0, 21),
             n_hor_windows=6,
             img_size=64):
    """Load instance-mask PNGs and tile every frame into square windows.

    Args:
        path: directory holding one sub-folder per sequence, named with the
            zero-padded 4-digit sequence number.
        folder: iterable of sequence numbers to load.
        n_hor_windows: number of horizontally shifted windows cut per frame.
        img_size: height and width of every window.

    Returns:
        A list with one ``uint32`` array per sequence, each of shape
        ``(n_frames, n_hor_windows, img_size, img_size)``. A sequence without
        PNG files yields an array with zero frames.
    """
    all_imgs = []
    # get each sequence
    for j in folder:
        img_folder = f'{j:04d}'
        # os.path.join keeps the pattern valid when `path` has no trailing slash.
        files = sorted(glob.glob(os.path.join(path, img_folder, '*.png')))
        imgs = np.zeros((len(files), n_hor_windows, img_size, img_size), dtype=np.uint32)
        # load sequence_i
        for i, f in enumerate(files):
            img = np.array(Image.open(f))
            if i == 0:
                print('folder', img_folder, 'first image shape ', img.shape)
            shape = img.shape
            # Width after resizing the height to img_size while keeping the aspect ratio.
            x_size = int(shape[1] // (shape[0] / img_size))
            img_t = torch.tensor(img.astype(np.float32))
            # interpolate() is called with its default mode, on a (1, 1, H, W) batch.
            img = torch.nn.functional.interpolate(img_t[None, None], size=(img_size, x_size)).numpy().astype(np.uint32)[0, 0]
            # tile into windows; a single window starts at the left edge
            # (the unguarded division raised ZeroDivisionError for n_hor_windows == 1).
            if n_hor_windows > 1:
                hor_stride = (x_size - img_size) // (n_hor_windows - 1)
            else:
                hor_stride = 0
            for k in range(n_hor_windows):
                offset = k * hor_stride
                imgs[i, k, :, :] = img[:, offset:offset + img_size]
        all_imgs.append(imgs)
    return all_imgs
37 |
def get_individual_sequences(all_imgs, n_hor_windows=6, mask_threshold=30):
    """Split windowed instance-id videos into per-instance boolean mask sequences.

    Args:
        all_imgs: list of integer arrays indexed as ``[frame, window, y, x]``
            whose values are instance ids.
        n_hor_windows: number of windows to read from axis 1.
        mask_threshold: frames whose mask has fewer pixels than this break
            the current sequence.

    Returns:
        A list of boolean arrays, each a stack of consecutive frames of one
        instance inside one window.
    """
    sequences = []
    for f_id, imgs_windows in enumerate(all_imgs):  # all videos
        print('sequence orig', f_id)
        for window_i in range(n_hor_windows):  # one window
            imgs = imgs_windows[:, window_i]
            ids = np.where(np.bincount(imgs.ravel()) != 0)[0][1:-1]  # first is bg, last is non-category
            for id_i in ids:  # indiv
                # ids in the 1000..1999 range are skipped
                if id_i // 1000 == 1:
                    continue
                # np.bool was removed in NumPy 1.24; the builtin is the same dtype.
                imgs_id_i = np.zeros(imgs.shape, dtype=bool)
                imgs_id_i[imgs == id_i] = 1
                t_inds = np.where(imgs_id_i != 0)[0]
                t_inds = np.arange(t_inds[0], t_inds[-1] + 1)  # make dense
                sequence_id_i = []
                for t_ind in t_inds:
                    frame = imgs_id_i[t_ind]
                    if np.sum(frame) < mask_threshold:  # mask too small
                        # NOTE(review): a run interrupted by a hole needs 3 frames
                        # to be kept, while the final run below needs only 2.
                        if len(sequence_id_i) > 2:
                            sequences.append(np.stack(sequence_id_i))  # add to sequences
                        sequence_id_i = []  # hole in sequence, start new one
                        continue
                    else:
                        sequence_id_i.append(frame)  # add to sequence
                if len(sequence_id_i) > 1:
                    sequences.append(np.stack(sequence_id_i))  # add to sequences
    return sequences
65 |
66 | # get center of mass, and area
def get_latents(sequence):
    """Compute ``(y, x, area)`` for every mask of every sequence.

    Args:
        sequence: iterable of mask sequences; each mask is passed to
            ``center_of_mass`` and summed to obtain its area.

    Returns:
        A list with one ``float32`` array of shape ``(len(seq), 3)`` per
        input sequence; columns are y position, x position and area.
    """
    def _describe(mask):
        centre = center_of_mass(mask)
        return np.array([centre[0], centre[1], np.sum(mask)])  # y pos, x pos, area

    all_latents = []
    for masks in sequence:
        table = np.zeros((len(masks), 3), dtype=np.float32)
        # Rows of `table` are views, so filling them fills the table in place.
        for row, mask in zip(table, masks):
            row[:] = _describe(mask)
        all_latents.append(table)
    return all_latents
76 |
def main(args):
    """Build the KITTI mask dataset and the companion analysis dictionary.

    Reads ``args.n_hor_windows``, ``args.img_size`` and ``args.rotate``.
    Writes ``./data/kitti_peds_v2.pickle`` (all mask sequences plus their
    latents) and ``./data/kitti_dict_p_v2[_rotate].pkl`` (one entry per pair
    of consecutive frames).
    """
    # raw data from https://www.vision.rwth-aachen.de/page/mots
    tiling = dict(n_hor_windows=args.n_hor_windows, img_size=args.img_size)
    car_imgs = get_data(path='./data/kitti/instances/', folder=range(0, 21), **tiling)  # mostly cars
    ped_imgs = get_data(path='./data/kitti/mots/instances/', folder=[2, 5, 9, 11], **tiling)  # pedestrians
    print('number folders mostly cars', len(car_imgs))
    print('number folders pedestrians', len(ped_imgs))
    car_sequences = get_individual_sequences(car_imgs, n_hor_windows=args.n_hor_windows)
    print()
    print('pedestrians')
    ped_sequences = get_individual_sequences(ped_imgs, n_hor_windows=args.n_hor_windows)
    all_sequences = car_sequences + ped_sequences
    all_latents = get_latents(all_sequences)

    # save data (v0, v1 are only internal and were not released)
    with open('./data/kitti_peds_v2.pickle', 'wb') as f:
        pickle.dump({'pedestrians': all_sequences, 'pedestrians_latents': all_latents}, f)

    # this is to do the data analysis
    columns = ('id', 'category_id', 'category', 'x', 'x_diff', 'y', 'y_diff',
               'area', 'area_diff', 'masks')
    dd = {column: [] for column in columns}
    rotate = args.rotate
    for id_i, (seq, lat) in enumerate(zip(all_sequences, all_latents)):
        # The stored latents are iterated in lockstep but superseded below.
        for start_img, next_img, _, _ in zip(seq[:-1], seq[1:], lat[:-1], lat[1:]):
            if rotate:
                start_img = ndi.rotate(start_img, 45)
                next_img = ndi.rotate(next_img, 45)

            # NOTE(review): assumed to run for every pair, not only when
            # rotating - confirm against the original indentation.
            start_com = center_of_mass(start_img)
            next_com = center_of_mass(next_img)
            start_area = np.sum(start_img)
            next_area = np.sum(next_img)
            start_latent = [start_com[0], start_com[1], start_area]  # y pos, x pos, area
            next_latent = [next_com[0], next_com[1], next_area]  # y pos, x pos, area

            dd['id'].append(id_i)
            # every pair is labelled as a pedestrian, including those from the car folders
            dd['category_id'].append(1)
            dd['category'].append('pedestrian')

            dd['x'].append([start_latent[1], next_latent[1]])
            dd['x_diff'].append(next_latent[1] - start_latent[1])

            dd['y'].append([start_latent[0], next_latent[0]])
            dd['y_diff'].append(next_latent[0] - start_latent[0])

            dd['area'].append([start_area, next_area])
            dd['area_diff'].append(next_area - start_area)

            dd['masks'].append([start_img.astype(np.uint8), next_img.astype(np.uint8)])

    prefix = '_rotate' if rotate else ''
    with open(f'./data/kitti_dict_p_v2{prefix}.pkl', 'wb') as f:
        pickle.dump(dd, f)
137 |
138 |
# Script entry point: parse the command-line flags and build the dataset.
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    # side length of the square windows
    parser.add_argument('--img-size', type=int, default=64)
    # number of horizontal windows cut from every frame
    parser.add_argument('--n-hor-windows', type=int, default=6)
    # when set, masks are rotated by 45 degrees for the analysis dictionary
    parser.add_argument('--rotate', action='store_true')
    args = parser.parse_args()
    main(args)
147 |
148 |
149 |
--------------------------------------------------------------------------------
/model/ilcm_vae.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | from .nn import MLP
7 |
8 |
class ILCM_VAE(torch.nn.Module):
    """VAE whose prior over the latents is produced by ``latent_model``.

    The encoder and decoder are MLPs (only the ``"tabular"`` variants are
    implemented). The latent model supplies the sampled masks, the
    reparameterization, the prior network and the KL / log-likelihood terms.

    Args:
        latent_model: object providing ``m``, ``g``, ``gc``, ``gc_disc``,
            ``reparameterize``, ``network``, ``compute_kl``,
            ``log_likelihood`` and the masking helpers used below.
        image_shape: shape of one observation; must be a 1-tuple for the
            tabular encoder/decoder.
        cont_c_dim: stored as ``self.cont_c_dim``.
        disc_c_dim: dimension of the discrete auxiliary variable; the
            discrete path in ``elbo`` is active when it is > 0.
        disc_n_values: stored as ``self.disc_c_n_values``.
        opt: options; reads ``z_max_dim``, ``n_lag``, ``beta``,
            ``learn_decoder_var``, ``full_seq``, ``encoder``, ``decoder``,
            ``encoder_depth_multiplier``, ``decoder_depth_multiplier``,
            ``bn_enc_dec``, ``freeze_dummies`` and ``init_decoder_var``.
    """

    def __init__(self, latent_model, image_shape, cont_c_dim, disc_c_dim, disc_n_values, opt):
        super().__init__()
        self.latent_model = latent_model
        self.z_dim = opt.z_max_dim
        self.n_lag = opt.n_lag
        self.cont_c_dim = cont_c_dim
        self.disc_c_dim = disc_c_dim
        self.disc_c_n_values = disc_n_values
        self.beta = opt.beta
        self.learn_var = opt.learn_decoder_var
        self.opt = opt

        if self.opt.full_seq:
            assert self.n_lag == 1, "--full_seq is supported only for --n_lag 1"

        # choosing encoder (outputs 2 * z_dim values per observation)
        if opt.encoder == "tabular":
            assert len(image_shape) == 1, f"The encoder {opt.encoder} works only on tabular data."
            self.encoder = MLP(image_shape[0], 2 * self.z_dim, 512, 3 * opt.encoder_depth_multiplier, spectral_norm=False, batch_norm=opt.bn_enc_dec)
        else:
            raise NotImplementedError(f"The encoder {opt.encoder} is not implemented.")

        # choosing decoder
        if opt.decoder == "tabular":
            assert len(image_shape) == 1, f"The decoder {opt.decoder} works only on tabular data."
            self.decoder = MLP(self.z_dim, image_shape[0], 512, 3 * opt.decoder_depth_multiplier, spectral_norm=False, batch_norm=opt.bn_enc_dec)
        else:
            raise NotImplementedError(f"The decoder {opt.decoder} is not implemented.")

        # dummy parameter which replaces masked out z_i's in the decoder input
        self.eta = torch.nn.Parameter(torch.zeros((self.z_dim,)))
        if opt.freeze_dummies:
            self.eta.requires_grad = False

        # 3-D shapes are reordered so that the last axis comes first.
        if len(image_shape) == 3:
            self.image_shape = image_shape[2:3] + image_shape[0:2]
        else:
            self.image_shape = image_shape

        # that's a bit messy, keep it like this to keep previous default behavior.
        if self.learn_var:
            if self.opt.init_decoder_var is None:
                self.decoder_logvar = torch.nn.Parameter(-10 * torch.ones((1,)))
            else:
                self.decoder_logvar = torch.nn.Parameter(math.log(opt.init_decoder_var) * torch.ones((1,)))
        else:
            if self.opt.init_decoder_var is None:
                self.decoder_logvar = torch.nn.Parameter(torch.zeros((1,)))
            else:
                self.decoder_logvar = torch.nn.Parameter(math.log(opt.init_decoder_var) * torch.ones((1,)))
            # fixed decoder variance: keep the parameter but never train it
            self.decoder_logvar.requires_grad = False

    def encode(self, x):
        """Return the raw encoder output for the batch ``x``."""
        return self.encoder(x)

    def decode(self, input, logit=False):
        """Decode latents into observations.

        For 1-D ``image_shape`` the decoder output is returned as is. For 3-D
        ``image_shape`` a sigmoid is applied unless ``logit`` is true. Any
        other ``image_shape`` length falls through and returns ``None``.
        """
        if len(self.image_shape) == 1:
            return self.decoder(input)
        elif len(self.image_shape) == 3:
            if logit:
                return self.decoder(input)
            else:
                return torch.sigmoid(self.decoder(input))

    def _mask_z_before_decoding(self, z, m):
        """Replace masked-out latent blocks by the learned dummy ``eta``.

        Args:
            z: latents of shape ``(b, t, z_max_dim)``.
            m: mask whose second dimension is the number of blocks; it is
                broadcast over time and over the entries of each block.

        Returns:
            Tensor of shape ``(b * t, z_max_dim)``.
        """
        b, t = z.shape[0:2]
        num_blocks = m.shape[1]
        m = m.view(b, 1, num_blocks, 1)
        z = z.view(b, t, num_blocks, -1)
        eta = self.eta.view(1, 1, num_blocks, -1)
        z_masked_eta = m * z + (1 - m) * eta
        return z_masked_eta.view(b * t, -1)

    def elbo(self, obs, cont_c, disc_c):
        """Compute the per-example ELBO of a batch of sequences.

        Args:
            obs: observations whose first two dimensions are batch and time.
            cont_c: continuous auxiliary variables (used when
                ``latent_model.cont_c_dim > 0``).
            disc_c: discrete auxiliary variables (used when
                ``self.disc_c_dim > 0``).

        Returns:
            ``(elbo, rec_loss_mean, kl_mean, kl)`` where ``elbo`` has one entry
            per example, the two means are Python floats and ``kl`` is the
            unreduced KL returned by the latent model.
        """
        b, t = obs.shape[0:2]

        q_params = self.encode(obs.view((b * t,) + self.image_shape))
        z = self.latent_model.reparameterize(q_params)
        m = self.latent_model.m(b)  # sample a mask

        ## --- Reconstruction --- ##
        z_masked_eta = self._mask_z_before_decoding(z.view(b, t, -1), m)
        # including the reconstruction term not only for x_t, but also for x_t-1, ..., x_t-k.

        reconstructions = self.decode(z_masked_eta)
        std = torch.exp(0.5 * self.decoder_logvar) + 1e-4
        # SL: This choice of reduction is picked to keep the relative importance of each terms in line with the original ELBO
        rec_loss = - torch.distributions.normal.Normal(reconstructions.view(b, t, -1), std).log_prob(obs.view(b, t, -1)).mean(dim=(1, 2))

        ## --- KL divergence --- #
        # mask z and c
        if self.latent_model.n_lag > 0:
            g = self.latent_model.g(b)
            z_lag = z.view(b, t, -1)[:, :-1]
            z_masked_gamma_lag = self.latent_model._mask_z_lag(z_lag, m, g)  # (bs, num_blocks, n_lag, z_dim)
            z_tm1 = z_lag[:, -1].view(b, -1, self.latent_model.z_block_size)
        else:
            z_masked_gamma_lag = None
            z_tm1 = None

        if self.latent_model.cont_c_dim > 0:
            gc = self.latent_model.gc(b)
            masked_cont_c = self.latent_model._mask_cont_c(cont_c, gc)
        else:
            masked_cont_c = None

        if self.disc_c_dim > 0:
            gc_disc = self.latent_model.gc_disc(b)  # (batch_size, num_z_blocks, disc_c_dim) TODO: disc_c_dim == 1 for now...
            masked_disc_c_one_hot = self.latent_model._mask_disc_c_one_hot(disc_c, gc_disc)  # (bs, num_blocks, cont_c_dim)
        else:
            masked_disc_c_one_hot = None

        p_params = self.latent_model.network(z_masked_gamma_lag, masked_cont_c, masked_disc_c_one_hot)
        q_params = q_params.view(b, t, -1)
        kl = self.latent_model.compute_kl(p_params, q_params[:, -1], z_tm1)

        if self.opt.full_seq:
            # we divide by two to preserve the relative importance between rec_loss and kl in line with original elbo
            init_params = self.latent_model.init_p_params.expand(b, -1, -1)
            kl = 0.5 * (kl + self.latent_model.compute_kl(init_params, q_params[:, 0], None, init=True))

        # np.product was removed in NumPy 2.0; np.prod is the supported spelling.
        kl_reduced = torch.sum(kl, 1) / int(np.prod(self.image_shape))

        elbo = -rec_loss - self.beta * kl_reduced
        return elbo, rec_loss.mean().item(), kl_reduced.mean().item(), kl

    def log_likelihood(self, gt_z, cont_c, disc_c):
        """Log-likelihood of the last latent in ``gt_z`` under the latent model.

        Args:
            gt_z: latents whose first two dimensions are batch and time; all
                but the last step are used as the lagged input.
            cont_c: continuous auxiliary variables, masked when
                ``latent_model.cont_c_dim > 0``.
            disc_c: discrete auxiliary variables, forwarded unchanged.

        Returns:
            Whatever ``latent_model.log_likelihood`` returns for the batch.
        """
        b, t = gt_z.shape[0:2]

        m = self.latent_model.m(b)  # sample a mask

        # mask z and c
        if self.latent_model.n_lag > 0:
            g = self.latent_model.g(b)
            z_lag = gt_z[:, :-1]
            z_masked_gamma_lag = self.latent_model._mask_z_lag(z_lag, m, g)  # (bs, num_blocks, n_lag, z_dim)
            z_tm1 = z_lag[:, -1].view(b, -1, self.latent_model.z_block_size)
        else:
            z_masked_gamma_lag = None
            z_tm1 = None

        if self.latent_model.cont_c_dim > 0:
            gc = self.latent_model.gc(b)
            masked_cont_c = self.latent_model._mask_cont_c(cont_c, gc)
        else:
            masked_cont_c = None

        # NOTE(review): unlike elbo(), disc_c is passed without the gc_disc
        # one-hot masking - confirm this is intended.
        p_params = self.latent_model.network(z_masked_gamma_lag, masked_cont_c, disc_c)
        ll = self.latent_model.log_likelihood(p_params, gt_z[:, -1], z_tm1)

        return ll
164 |
165 |
--------------------------------------------------------------------------------
/scripts/compute_udr_npy.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import json
4 | import math
5 | import os
6 | import random
7 | import shutil
8 | import sys
9 | import time
10 | import pathlib
11 |
12 | import numpy as np
13 | import pandas as pd
14 |
15 | #sys.path.insert(0, str(pathlib.Path(__file__).parent.parent))
16 | sys.path.insert(0, os.path.abspath('..'))
17 | from metrics import mean_corr_coef_np
18 |
def load_hparams(folder_path):
    """Load ``hparams.json`` from ``folder_path`` as an attribute-style object.

    Returns:
        A ``Bunch`` instance exposing every top-level JSON key as an attribute.
    """
    class Bunch:
        def __init__(self, opt):
            self.__dict__.update(opt)

    hparams_file = os.path.join(folder_path, 'hparams.json')
    with open(hparams_file, 'r') as infile:
        return Bunch(json.load(infile))
27 |
def compute_udr_mcc_from_npy(inferred_model_reps):
    """Aggregate pairwise MCC scores between the representations of several runs.

    For every ordered pair of distinct runs ``(i, j)`` the first element of
    ``mean_corr_coef_np(reps[i], reps[j], method='pearson', indices=None)``
    is stored, and the off-diagonal entries are then summarized.

    Args:
        inferred_model_reps: sequence with one representation array per run.

    Returns:
        ``{'median': ..., 'mean': ...}`` over all ordered pairs of distinct
        runs. With fewer than two runs there is no pair and both are NaN.
    """
    num_models = len(inferred_model_reps)
    mcc_all = np.zeros((num_models, num_models))

    for i, reps_i in enumerate(inferred_model_reps):
        for j, reps_j in enumerate(inferred_model_reps):
            if i != j:
                mcc_all[i, j] = mean_corr_coef_np(reps_i, reps_j,
                                                  method='pearson', indices=None)[0]

    # Drop the diagonal: a run is never compared with itself.
    distinct_pairs = ~np.eye(num_models, dtype=bool)
    pairwise_scores = mcc_all[distinct_pairs]
    return {'median': np.median(pairwise_scores), 'mean': np.mean(pairwise_scores)}
78 |
def create_mode_entry(all_logs_pd):
    """Set ``mode`` to ``"tcvae"`` on every row whose ``tcvae`` column equals True.

    The frame is modified in place and also returned.
    """
    # for tcvae
    is_tcvae = all_logs_pd["tcvae"] == True
    all_logs_pd.loc[is_tcvae, 'mode'] = "tcvae"

    # TODO: create for other methods if necessary.

    return all_logs_pd
86 |
def main(args=None):
    """Add UDR columns to an ``all_logs.npy`` table and save it next to the input.

    Runs are grouped by dataset, mode and hyperparameter values. Every group
    with more than one run gets the mean/median pairwise MCC of its
    ``z_hat_final.npy`` files; a group with a single run gets -1. Rows that
    match no dataset/mode combination keep missing values. The result is
    written to ``<all_logs_file with ".npy" replaced by "_udr.npy">``.

    Args:
        args: optional list of command-line tokens; ``None`` reads
            ``sys.argv`` (argparse default).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--all_logs_file", type=str,
                        help="Absolute path to all_logs.npy files")

    #GT_GRAPH_NAMES = ["graph_temporal_3_easy", "graph_temporal_3_hard", "graph_action_3_easy", "graph_action_3_hard"]
    GT_GRAPH_NAMES = ["toy-nn/temporal_sparsity_trivial", "toy-nn/temporal_sparsity_non_trivial", "toy-nn/action_sparsity_trivial", "toy-nn/action_sparsity_non_trivial"]
    # Other modes seen in the logs: 'pcl', 'slowvae', 'tcvae', 'ivae', 'random_vae', 'supervised_vae'
    MODES = ['vae']
    # Hyperparameters defining a group of seeds, per mode. Candidates for the
    # other modes: tcvae -> ["beta"], slowvae -> ['gamma', 'rate_prior'], rest -> [].
    HPARAM_NAMES = {"vae": ['gc_constraint', 'g_constraint']}

    # Forward `args` so the function can be driven programmatically;
    # it used to be accepted and then ignored.
    opt = parser.parse_args(args)

    # NOTE: allow_pickle=True executes pickled code; only load trusted files.
    all_logs = np.load(opt.all_logs_file, allow_pickle=True).tolist()
    all_logs_pd = pd.DataFrame(all_logs)
    #all_logs_pd = create_mode_entry(all_logs_pd)

    # to be filled with udr values
    all_logs_pd_udr = all_logs_pd.copy()

    for gt_graph_name in GT_GRAPH_NAMES:
        for mode in MODES:
            print("########## gt_graph_name:", gt_graph_name, "mode:", mode, "##########")
            condition_data_mode = (all_logs_pd['dataset'] == gt_graph_name) & (all_logs_pd['mode'] == mode)
            logs = all_logs_pd[condition_data_mode]

            if len(logs) == 0:
                print("No logs found.")
                continue

            hparam_names = HPARAM_NAMES[mode]
            if len(hparam_names) != 0:
                hparams_values = logs[hparam_names].drop_duplicates()

                for i in range(len(hparams_values)):
                    # selecting only the runs with this specific hparam value
                    condition_hp = (hparams_values.iloc[i] == logs[hparam_names]).all(axis=1)
                    logs_specific_hp = logs[condition_hp]

                    # condition_hp only covers the rows of `logs`; expand it to
                    # the full table explicitly instead of relying on implicit
                    # index alignment inside `&`.
                    rows = condition_data_mode & condition_hp.reindex(all_logs_pd.index, fill_value=False)

                    # logging number of seeds used to compute UDR
                    all_logs_pd_udr.loc[rows, "num_seeds"] = len(logs_specific_hp)

                    # cannot compute UDR when there is only one seed
                    if len(logs_specific_hp) == 1:
                        print(f"Not computing UDR for {hparams_values.iloc[i].to_dict()} since only one seed.")
                        all_logs_pd_udr.loc[rows, "udr_mean"] = -1
                        all_logs_pd_udr.loc[rows, "udr_median"] = -1
                        continue

                    # load their z_hat_final.npy
                    z_hat_list = [np.load(os.path.join(output_dir, "z_hat_final.npy"))
                                  for output_dir in logs_specific_hp["output_dir"]]

                    print(f"Computing UDR for {hparams_values.iloc[i].to_dict()} on {len(logs_specific_hp)} seeds.")
                    udr = compute_udr_mcc_from_npy(z_hat_list)

                    # Add udr values to table
                    all_logs_pd_udr.loc[rows, "udr_mean"] = udr["mean"]
                    all_logs_pd_udr.loc[rows, "udr_median"] = udr["median"]

            else:
                print(f"No hyperparameter search. Total of {len(logs)} seeds.")
                # number of sucessful seeds
                all_logs_pd_udr.loc[condition_data_mode, "num_seeds"] = len(logs)
                # if the method has no hparameter, just set UDR score to -1
                all_logs_pd_udr.loc[condition_data_mode, "udr_mean"] = -1
                all_logs_pd_udr.loc[condition_data_mode, "udr_median"] = -1

    np.save(opt.all_logs_file.replace(".npy", "_udr.npy"), all_logs_pd_udr.to_dict('records'))
163 |
164 |
# Run the UDR computation when this file is executed as a script.
if __name__ == "__main__":
    main()
167 |
--------------------------------------------------------------------------------
/scripts/compute_udr_npy_rand_graphs.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import json
4 | import math
5 | import os
6 | import random
7 | import shutil
8 | import sys
9 | import time
10 | import pathlib
11 |
12 | import numpy as np
13 | import pandas as pd
14 |
15 | #sys.path.insert(0, str(pathlib.Path(__file__).parent.parent))
16 | sys.path.insert(0, os.path.abspath('..'))
17 | from metrics import mean_corr_coef_np
18 |
def load_hparams(folder_path):
    """Load ``hparams.json`` from ``folder_path`` as an attribute-style object.

    Returns:
        A ``Bunch`` instance exposing every top-level JSON key as an attribute.
    """
    class Bunch:
        def __init__(self, opt):
            self.__dict__.update(opt)

    hparams_file = os.path.join(folder_path, 'hparams.json')
    with open(hparams_file, 'r') as infile:
        return Bunch(json.load(infile))
27 |
def compute_udr_mcc_from_npy(inferred_model_reps):
    """Aggregate pairwise MCC scores between the representations of several runs.

    For every ordered pair of distinct runs ``(i, j)`` the first element of
    ``mean_corr_coef_np(reps[i], reps[j], method='pearson', indices=None)``
    is stored, and the off-diagonal entries are then summarized.

    Args:
        inferred_model_reps: sequence with one representation array per run.

    Returns:
        ``{'median': ..., 'mean': ...}`` over all ordered pairs of distinct
        runs. With fewer than two runs there is no pair and both are NaN.
    """
    num_models = len(inferred_model_reps)
    mcc_all = np.zeros((num_models, num_models))

    for i, reps_i in enumerate(inferred_model_reps):
        for j, reps_j in enumerate(inferred_model_reps):
            if i != j:
                mcc_all[i, j] = mean_corr_coef_np(reps_i, reps_j,
                                                  method='pearson', indices=None)[0]

    # Drop the diagonal: a run is never compared with itself.
    distinct_pairs = ~np.eye(num_models, dtype=bool)
    pairwise_scores = mcc_all[distinct_pairs]
    return {'median': np.median(pairwise_scores), 'mean': np.mean(pairwise_scores)}
78 |
def create_mode_entry(all_logs_pd):
    """Set ``mode`` to ``"tcvae"`` on every row whose ``tcvae`` column equals True.

    The frame is modified in place and also returned.
    """
    # for tcvae
    is_tcvae = all_logs_pd["tcvae"] == True
    all_logs_pd.loc[is_tcvae, 'mode'] = "tcvae"

    # TODO: create for other methods if necessary.

    return all_logs_pd
86 |
def main(args=None):
    """Add UDR columns to an ``all_logs.npy`` table of random-graph runs.

    Runs are grouped by dataset, graph density (``rand_g_density``), mode and
    hyperparameter values. Every group with more than one run gets the
    mean/median pairwise MCC of its ``z_hat_final.npy`` files; a group with a
    single run gets -1. Rows that match no combination keep missing values.
    The result is written to
    ``<all_logs_file with ".npy" replaced by "_udr.npy">``.

    Args:
        args: optional list of command-line tokens; ``None`` reads
            ``sys.argv`` (argparse default).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--all_logs_file", type=str,
                        help="Absolute path to all_logs.npy files")

    GRAPH_DENSITIES = [0.25, 0.5, 0.75]
    DATASET_NAMES = ["toy-nn/action_sparsity_non_trivial", "toy-nn/temporal_sparsity_non_trivial"]
    # Other modes seen in the logs: 'pcl', 'slowvae', 'tcvae', 'ivae', 'random_vae', 'supervised_vae'
    MODES = ['vae']
    # Hyperparameters defining a group of seeds, per mode. Candidates for the
    # other modes: tcvae -> ["beta"], slowvae -> ['gamma', 'rate_prior'], rest -> [].
    HPARAM_NAMES = {"vae": ['gc_reg_coeff', 'g_reg_coeff']}

    # Forward `args` so the function can be driven programmatically;
    # it used to be accepted and then ignored.
    opt = parser.parse_args(args)

    # NOTE: allow_pickle=True executes pickled code; only load trusted files.
    all_logs = np.load(opt.all_logs_file, allow_pickle=True).tolist()
    all_logs_pd = pd.DataFrame(all_logs)
    #all_logs_pd = create_mode_entry(all_logs_pd)

    # to be filled with udr values
    all_logs_pd_udr = all_logs_pd.copy()

    for dataset in DATASET_NAMES:
        for graph_density in GRAPH_DENSITIES:
            for mode in MODES:
                print("########## dataset:", dataset, "graph_density", graph_density, "mode:", mode, "##########")
                condition_data_mode = (all_logs_pd['dataset'] == dataset) & (all_logs_pd['rand_g_density'] == graph_density) & (all_logs_pd['mode'] == mode)
                logs = all_logs_pd[condition_data_mode]

                if len(logs) == 0:
                    print("No logs found.")
                    continue

                hparam_names = HPARAM_NAMES[mode]
                if len(hparam_names) != 0:
                    hparams_values = logs[hparam_names].drop_duplicates()

                    for i in range(len(hparams_values)):
                        # selecting only the runs with this specific hparam value
                        condition_hp = (hparams_values.iloc[i] == logs[hparam_names]).all(axis=1)
                        logs_specific_hp = logs[condition_hp]

                        # condition_hp only covers the rows of `logs`; expand it
                        # to the full table explicitly instead of relying on
                        # implicit index alignment inside `&`.
                        rows = condition_data_mode & condition_hp.reindex(all_logs_pd.index, fill_value=False)

                        # logging number of seeds used to compute UDR
                        all_logs_pd_udr.loc[rows, "num_seeds"] = len(logs_specific_hp)

                        # cannot compute UDR when there is only one seed
                        if len(logs_specific_hp) == 1:
                            print(f"Not computing UDR for {hparams_values.iloc[i].to_dict()} since only one seed.")
                            all_logs_pd_udr.loc[rows, "udr_mean"] = -1
                            all_logs_pd_udr.loc[rows, "udr_median"] = -1
                            continue

                        # load their z_hat_final.npy
                        z_hat_list = [np.load(os.path.join(output_dir, "z_hat_final.npy"))
                                      for output_dir in logs_specific_hp["output_dir"]]

                        print(f"Computing UDR for {hparams_values.iloc[i].to_dict()} on {len(logs_specific_hp)} seeds.")
                        udr = compute_udr_mcc_from_npy(z_hat_list)

                        # Add udr values to table
                        all_logs_pd_udr.loc[rows, "udr_mean"] = udr["mean"]
                        all_logs_pd_udr.loc[rows, "udr_median"] = udr["median"]

                else:
                    print(f"No hyperparameter search. Total of {len(logs)} seeds.")
                    # number of sucessful seeds
                    all_logs_pd_udr.loc[condition_data_mode, "num_seeds"] = len(logs)
                    # if the method has no hparameter, just set UDR score to -1
                    all_logs_pd_udr.loc[condition_data_mode, "udr_mean"] = -1
                    all_logs_pd_udr.loc[condition_data_mode, "udr_median"] = -1

    np.save(opt.all_logs_file.replace(".npy", "_udr.npy"), all_logs_pd_udr.to_dict('records'))
164 |
165 |
# Run the UDR computation when this file is executed as a script.
if __name__ == "__main__":
    main()
168 |
--------------------------------------------------------------------------------
/baseline_models/slowvae_pcl/scripts/data_analysis_utils.py:
--------------------------------------------------------------------------------
1 | import csv
2 | from collections import defaultdict
3 | import ast
4 | import pickle
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | import pandas as pd
8 | import scipy.stats
9 | import collections
10 | import warnings
11 | from sklearn.feature_selection import mutual_info_regression
12 | import random
# For the YouTube data: 40 category names, written as one space-separated
# string and split into a list.
name_list = 'person giant_panda lizard parrot skateboard sedan ape dog snake monkey hand rabbit duck cat cow fish train horse turtle bear motorbike giraffe leopard fox deer owl surfboard airplane truck zebra tiger elephant snowboard boat shark mouse frog eagle earless_seal tennis_racket'.split(' ')
15 |
def load_csv(csv_file, sequence=2):
    """Collect object tracks that span ``sequence`` consecutive frame columns.

    The first CSV row is skipped as a header. In every following row,
    column 0 is parsed as the integer id and column 1 as the integer
    category id. Each remaining non-empty cell is evaluated with
    ``ast.literal_eval`` and must yield a list with one entry per object
    slot; a truthy entry is read as ``(y, x, area)`` (items 0, 1 and 2).

    A sliding window of ``sequence`` cells is moved over the frame columns.
    For every slot that is present in *all* cells of a full window, one
    record is appended to the result.

    Parameters
    ----------
    csv_file : iterable of str
        Anything accepted by ``csv.reader`` (e.g. an open text file).
    sequence : int
        Number of consecutive frames a slot must appear in.

    Returns
    -------
    collections.defaultdict(list)
        Parallel lists under the keys ``'id'``, ``'category_id'``,
        ``'area'``, ``'x'``, ``'y'`` (each of the last three holds a list of
        ``sequence`` values per record) and the frame-to-frame differences
        ``'area_diff'``, ``'x_diff'``, ``'y_diff'`` (with a numeric suffix
        ``2``, ``3``, ... for the later steps when ``sequence > 2``).
    """
    csv_reader = csv.reader(csv_file, delimiter=',')
    next(csv_reader)  # skip the header row
    data = defaultdict(list)
    for row in csv_reader:
        for start in range(2, len(row)):
            window = row[start:start + sequence]
            if not all(window):
                # at least one frame in the window is empty
                continue

            # Group the (y, x, area) triples by slot index in a single pass
            # instead of re-scanning flat lists for every slot.
            tracks = defaultdict(list)
            n_slots = 0
            for column in window:
                val_list = ast.literal_eval(column)
                n_slots = len(val_list)  # slots of the last frame drive the loop below
                for slot, val in enumerate(val_list):
                    if val:
                        tracks[slot].append((val[0], val[1], val[2]))

            for slot in range(n_slots):
                track = tracks.get(slot)
                # A slot can occur at most once per frame, so a complete
                # track has exactly `sequence` entries.
                if track is None or len(track) != sequence:
                    continue
                ys = [entry[0] for entry in track]
                xs = [entry[1] for entry in track]
                areas = [entry[2] for entry in track]
                data['id'].append(int(row[0]))
                data['category_id'].append(int(row[1]))
                data['area'].append(areas)
                data['x'].append(xs)
                data['y'].append(ys)
                for k in range(1, sequence):
                    suffix = k if k > 1 else ''
                    data['area_diff{}'.format(suffix)].append(areas[k] - areas[k - 1])
                    data['x_diff{}'.format(suffix)].append(xs[k] - xs[k - 1])
                    data['y_diff{}'.format(suffix)].append(ys[k] - ys[k - 1])
    return data
47 |
def load_data(path):
    """Unpickle and return the object stored in ``path + '.pkl'``.

    NOTE: ``pickle.load`` can execute arbitrary code; only use this on
    files from a trusted source.
    """
    pickle_path = path + '.pkl'
    with open(pickle_path, 'rb') as handle:
        return pickle.load(handle)
52 |
def plot_bar(data, key):
    """Draw a bar chart of how often each non-negative integer occurs in ``data[key]``."""
    counts = np.bincount(data[key])
    positions = range(len(counts))
    _ = plt.bar(positions, counts)
56 |
def plot_type(data, x, type_='all', semilogy=False):
    """Plot a 100-bin histogram of ``x``, either pooled or per category.

    With ``type_ == 'all'`` a single histogram is drawn (with a logarithmic
    y axis when ``semilogy`` is true). Otherwise one histogram per entry of
    the module-level ``name_list`` is drawn on a grid of subplots, where
    ``x[j]`` is assigned to the category ``data['category_id'][j]``;
    ``semilogy`` is not used in that branch.
    """
    if type_ == 'all':
        _ = plt.hist(x, bins=100)
        if semilogy:
            _ = plt.semilogy()
    else:
        n_rows = len(name_list) // 8
        n_cols = len(name_list) // 5
        fig, axes = plt.subplots(n_rows, n_cols,
                                 figsize=(n_cols * 10, n_rows * 10))
        data_per_cat = collections.defaultdict(list)
        for j in range(len(data['id'])):
            category = data['category_id'][j]
            data_per_cat[category].append(x[j])
        for i, name in enumerate(name_list):
            # category ids are 1-based, name_list is 0-based
            _ = axes[i % n_rows, i % n_cols].hist(data_per_cat[i + 1], bins=100)
            _ = axes[i % n_rows, i % n_cols].set_title(name)
    # NOTE(review): assumed to run for both branches — confirm against the
    # original indentation.
    plt.show()
72 |
73 |
def plot_val(data, key, low=0., high=1., type_='all'):
    """Plot the first column of ``data[key]`` restricted to a quantile range.

    Prints the minimum/maximum before and after keeping only the values
    between the ``low`` and ``high`` quantiles, then forwards the filtered
    series to ``plot_type``.
    """
    first_column = np.array(data[key])[:, 0]
    x = pd.Series(first_column)
    print('orig_min', x.min())
    print('orig_max', x.max())
    lower_bound = x.quantile(low)
    upper_bound = x.quantile(high)
    x = x[x.between(lower_bound, upper_bound)]
    print('{}_min'.format(low), x.min())
    print('{}_max'.format(high), x.max())
    plot_type(data, x, type_)
82 |
def plot_diff(data, key, semilogy=True, type_='all'):
    """Plot the histogram(s) of ``data[key + '_diff']`` via ``plot_type``."""
    diff_values = data[key + '_diff']
    plot_type(data, diff_values, type_, semilogy)
85 |
def visualize_mask(mask):
    """Show ``mask`` as an image on the current matplotlib axes."""
    plt.imshow(mask)
88 |
def generate_dataframe(data, type_='all', mi=False, mi_samples=20000):
    """Summarise the ``*_diff`` columns of ``data`` in a pandas DataFrame.

    One row is produced for the pooled data (category ``'all'``) and, when
    ``type_ != 'all'``, one additional row per category id from 1 up to
    ``max(data['category_id'])`` (named through the module-level
    ``name_list``). For each of ``area``, ``x`` and ``y`` the row holds the
    kurtosis and, for every distribution in ``distributions``, the fitted
    parameters, the summed log-pdf (``ll_*``) and element ``[1]`` of the
    ``scipy.stats.kstest`` result (``ks_*``). Pairwise Pearson results are
    added, and mutual-information estimates when ``mi`` is true. All numbers
    are stored as formatted strings.

    Parameters
    ----------
    data : mapping
        Must provide ``'id'``, ``'category_id'``, ``'area_diff'``,
        ``'x_diff'`` and ``'y_diff'``.
    type_ : str
        ``'all'`` for the pooled row only; anything else adds per-category rows.
    mi : bool
        Whether to add the ``mi_*`` columns.
    mi_samples : int
        Upper bound on the number of samples used for the ``mi_*`` columns.

    Returns
    -------
    pandas.DataFrame
        The statistics, sorted by ``'N'`` in ascending order.
    """
    stats = collections.defaultdict(list)
    # NOTE: installs an "ignore" filter for all warnings; it is not undone
    # when the function returns.
    warnings.filterwarnings(action='ignore')
    distributions = [scipy.stats.gennorm, scipy.stats.norm, scipy.stats.laplace]
    # i == 0 is the pooled row; i >= 1 are category ids (only when type_ != 'all')
    for i in range((0 if type_=='all' else np.max(data['category_id']))+1):
        if i == 0:
            stats['category'].append('all')
            stats['N'].append(len(data['id']))
            area_val = data['area_diff']
            x_val = data['x_diff']
            y_val = data['y_diff']
        else:
            # category ids are 1-based while name_list is 0-based
            stats['category'].append(name_list[i-1])
            stats['N'].append(np.count_nonzero(np.array(data['category_id']).astype(int) == i))
            area_val = [data['area_diff'][j] for j in range(len(data['area_diff'])) if i == data['category_id'][j]]
            x_val = [data['x_diff'][j] for j in range(len(data['x_diff'])) if i == data['category_id'][j]]
            y_val = [data['y_diff'][j] for j in range(len(data['y_diff'])) if i == data['category_id'][j]]
        vals = {'area': area_val, 'x': x_val, 'y': y_val}
        for x in vals.keys():
            stats['kurtosis_' + x].append("%.2f" % scipy.stats.kurtosis(vals[x]))
            for distribution in distributions:
                params = distribution.fit(vals[x])
                # the comprehension variable `x` is local to the comprehension
                # and does not overwrite the loop variable `x`
                stats['{}_{}'.format(distribution.name, x)].append(['%.2e' % x for x in params])
                # split the fitted tuple: leading entries, then loc, then scale
                arg = params[:-2]
                loc = params[-2]
                scale = params[-1]
                stats['ll_{}_{}'.format(distribution.name, x)].append("%.2e" % distribution.logpdf(vals[x], *params).sum())
                # the lambda is called inside kstest during this iteration,
                # so it sees the current distribution/loc/scale/arg
                stats['ks_{}_{}'.format(distribution.name, x)].append("%.2e" % scipy.stats.kstest(vals[x],
                                                                                                  lambda x: distribution.cdf(x,
                                                                                                                             loc=loc,
                                                                                                                             scale=scale,
                                                                                                                             *arg))[1])
        stats['pearson_area_x'].append(["%.2f" % x for x in scipy.stats.pearsonr(vals['area'], vals['x'])])
        stats['pearson_area_y'].append(["%.2f" % x for x in scipy.stats.pearsonr(vals['area'], vals['y'])])
        stats['pearson_x_y'].append(["%.2f" % x for x in scipy.stats.pearsonr(vals['x'], vals['y'])])
        if mi:
            # subsample (without replacement) at most mi_samples rows
            indices = random.sample(range(len(vals['area'])), min(mi_samples,stats['N'][-1]))
            stats['mi_area_x'].append("%.2f" % mutual_info_regression(np.array(vals['area']).reshape(-1,1)[indices],
                                                                      np.array(vals['x'])[indices]))
            stats['mi_area_y'].append("%.2f" % mutual_info_regression(np.array(vals['area']).reshape(-1,1)[indices],
                                                                      np.array(vals['y'])[indices]))
            stats['mi_x_y'].append("%.2f" % mutual_info_regression(np.array(vals['x']).reshape(-1,1)[indices],
                                                                   np.array(vals['y'])[indices]))
    return pd.DataFrame.from_dict(stats).sort_values('N', ascending=True)
133 |
134 |
def find_best(df, criterion='ll'):
    """Pick, per row, the column with the highest ``criterion`` value.

    For each of ``'area'``, ``'x'`` and ``'y'``, the columns whose name
    contains both that substring and ``criterion`` are cast to float and the
    name of the row-wise maximum is reported.

    Returns a DataFrame with ``'category'``, ``'N'`` and the three winning
    column names, sorted by ``'N'`` in descending order.
    """
    def best_column(variable):
        # name of the best-scoring matching column for every row
        matching = [col for col in df.columns if (variable in col and criterion in col)]
        return df[matching].astype(np.float64).idxmax(axis=1)

    parts = [df['category'], df['N']]
    for variable in ('area', 'x', 'y'):
        parts.append(best_column(variable))
    best_df = pd.concat(parts, axis=1)
    return best_df.sort_values('N', ascending=False)
143 |
144 |
145 |
--------------------------------------------------------------------------------
/universal_logger/logger.py:
--------------------------------------------------------------------------------
1 | # Author: Alexandre Drouin
2 | # Adapted by: Sebastien Lachapelle
3 |
4 | import json
5 | import numpy as np
6 | import os
7 | import queue
8 |
9 | from time import strftime, time
10 |
11 |
# Optional logging backends. A failed import only disables the backend;
# `except Exception` (instead of a bare `except`) keeps that best-effort
# behaviour without swallowing KeyboardInterrupt / SystemExit.
try:
    import comet_ml
    COMET_AVAIL = True
except Exception:
    COMET_AVAIL = False

try:
    import tensorboardX
    TBX_AVAIL = True
except Exception:
    TBX_AVAIL = False
23 |
24 |
25 | def _check_randomstate(random_state):
26 | if isinstance(random_state, int):
27 | random_state = np.random.RandomState(random_state)
28 | if not isinstance(random_state, np.random.RandomState):
29 | raise ValueError("Random state must be numpy RandomState or int.")
30 | return random_state
31 |
32 |
33 | def _prefix_stage(metrics, stage):
34 | if stage is not None:
35 | metrics = {"%s/%s" % (stage, k): v for k, v in metrics.items()}
36 | return metrics
37 |
38 |
class CometLogger(object):
    """Forward metrics and figures to a comet experiment object."""

    def __init__(self, experiment):
        # experiment: object exposing log_figure, log_image and log_metrics
        self.experiment = experiment

    def log_figure(self, stage, step, name, figure):
        """Log ``figure`` under ``name``; ``stage`` is accepted but not used.

        ``experiment.log_figure`` is tried first; if it raises, the object is
        sent through ``experiment.log_image`` instead.
        """
        try:
            self.experiment.log_figure(figure_name=name, figure=figure, step=step, overwrite=False)
        except Exception:
            # Not a bare `except`, so Ctrl-C / SystemExit are not rerouted
            # into a second upload attempt.
            self.experiment.log_image(figure, name=name, step=step, overwrite=False)

    def log_metrics(self, stage, step, metrics):
        """Log ``metrics`` with their keys prefixed by ``stage`` (if given)."""
        self.experiment.log_metrics(step=step, dic=_prefix_stage(metrics, stage))
56 |
57 |
class JsonLogger(object):
    """Append metrics to ``<path>/log.ndjson`` and save figures into ``path``.

    Parameters
    ----------
    path : str
        Directory that receives ``log.ndjson`` and the ``.png`` files. It is
        created (with parents) if it does not exist.
    time : bool
        Whether to add a ``"time"`` field (``time.time()``) to every record.
    max_fig_save : int or None
        Maximum number of saved figures kept per figure name; older files are
        deleted. ``None`` keeps every file.
    """

    def __init__(self, path, time=True, max_fig_save=None):
        self.path = path
        # Everything is written *inside* `path`, so `path` itself has to
        # exist, not only its parent. An empty path means the current
        # directory and needs no creation.
        if path:
            os.makedirs(path, exist_ok=True)
        self.time = time
        self.max_fig_save = max_fig_save  # makes sure we don't save too many .png files.
        self.current_figs = {}  # figure name -> queue of saved file names, oldest first

    def log_metrics(self, stage, step, metrics):
        """Append one JSON line holding ``metrics`` plus stage/step(/time)."""
        record = dict(metrics)  # do not mutate the caller's dict
        record["stage"] = stage
        record["step"] = step
        if self.time:
            record["time"] = time()
        with open(os.path.join(self.path, "log.ndjson"), "a") as f:
            f.write(json.dumps(record) + "\n")

    def log_figure(self, stage, step, name, figure):
        """Save ``figure`` as ``<name>_<step>.png``; ``stage`` is not used."""
        if name not in self.current_figs:
            self.current_figs[name] = queue.Queue()
        saved = self.current_figs[name]

        file_name = f"{name}_{step}.png"
        figure.savefig(os.path.join(self.path, file_name))
        saved.put(file_name)

        # removing old figures (only when a limit was configured)
        if self.max_fig_save is not None and saved.qsize() > self.max_fig_save:
            file_to_delete = saved.get()
            os.remove(os.path.join(self.path, file_to_delete))
90 |
91 |
class StdoutLogger(object):
    """Print metrics as one tab-separated line on standard output.

    Parameters
    ----------
    time : bool
        Whether to append the current local time (``strftime('%X %x')``).
    """

    def __init__(self, time=True):
        self.time = time

    def log_metrics(self, stage, step, metrics):
        """Print ``"<stage> -- <step>:"`` followed by ``name: value`` pairs.

        Values are printed with six decimals; if any value cannot be
        formatted as a float, all values of that call are printed with
        ``%s`` instead.
        """
        prefix = "" if stage is None else (stage + " -- ")
        try:
            body = "\t".join("%s: %.6f" % (m, v) for m, v in metrics.items())
        except Exception:
            # Non-scalar or non-numeric value: fall back to plain str().
            # `except Exception` keeps the fallback broad without catching
            # KeyboardInterrupt / SystemExit.
            body = "\t".join("%s: %s" % (m, v) for m, v in metrics.items())
        log = f"{prefix}{step}:\t" + body
        if self.time:
            log += "\t time: " + strftime('%X %x')
        print(log)
105 |
106 |
class TensorboardXLogger(object):
    """Write scalar metrics through a summary writer's ``add_scalar``."""

    def __init__(self, summary_writer):
        self.summary_writer = summary_writer

    def log_metrics(self, stage, step, metrics):
        """Emit one ``add_scalar(tag, value, step)`` call per metric.

        Tags are the metric names, prefixed with ``"<stage>/"`` when a stage
        is given.
        """
        prefixed = _prefix_stage(metrics, stage)
        for tag in prefixed:
            self.summary_writer.add_scalar(tag, prefixed[tag], step)
114 |
115 |
116 | # XXX: I'll eventually move this class to the shared code-base
class UniversalLogger(object):
    """
    Fan out metrics and figures to several logging backends at once.

    Parameters:
    -----------
    comet: comet_ml.Experiment
        A comet experiment
    json: str
        Path handed to the JSON logger
    stdout: bool
        Whether or not to print metrics to the standard output
    tensorboardx: tensorboardx.SummaryWriter
        A summary writer for tensorboardx
    time: bool
        Whether to log the time in some loggers (supported: json, stdout)
    throttle: int
        The minimum time between logs in seconds (None disables throttling)
    max_fig_save: int
        Forwarded to the JSON logger to limit the number of saved figures

    """
    def __init__(self, comet=None, json=None, stdout=False, tensorboardx=None, time=True, throttle=None, max_fig_save=None):
        super().__init__()
        backends = []
        if comet is not None:
            if not COMET_AVAIL:
                raise RuntimeError("comet_ml is not available on this platform. Please install it.")
            backends.append(CometLogger(experiment=comet))
        if json is not None:
            backends.append(JsonLogger(json, time=time, max_fig_save=max_fig_save))
        if stdout:
            backends.append(StdoutLogger(time=time))
        if tensorboardx is not None:
            if not TBX_AVAIL:
                raise RuntimeError("TensorboardX is not available on this platform. Please install it.")
            backends.append(TensorboardXLogger(tensorboardx))
        # at least one output must have been requested
        assert len(backends) >= 1
        self.loggers = backends

        # Throttling state
        self.throttle = throttle
        self.last_log_time = 0
        self.current_stage = None

    def _check_stage(self, stage):
        # Entering a new stage resets the timer, which forces the next log
        if self.current_stage == stage:
            return
        self.last_log_time = 0
        self.current_stage = stage

    def _check_throttle(self):
        # True when logging is currently allowed
        if self.throttle is None:
            return True
        elapsed = time() - self.last_log_time
        return elapsed > self.throttle

    def log_metrics(self, step, metrics, stage=None, throttle=True):
        """
        Log a set of metrics to every backend that supports metrics

        Parameters:
        -----------
        stage: str
            The current stage of execution (e.g., "train", "test", etc.)
        step: uint
            The current step number (e.g., epoch)
        metrics: dict
            A dictionary with metric names as keys and metric values as values
        throttle: bool
            Whether to respect log throttling or not (default is True)

        """
        self._check_stage(stage)
        if throttle and not self._check_throttle():
            return
        for backend in self.loggers:
            if not hasattr(backend, "log_metrics"):
                continue
            # every backend receives its own copy of the metrics
            backend.log_metrics(stage, step, dict(metrics))
        self.last_log_time = time()

    def log_figure(self, name, figure, stage=None, step=None, throttle=True):
        """
        Log a matplotlib figure to every backend that supports figures

        Parameters:
        -----------
        name: str
            The name of the figure
        figure: mpl.Figure
            The matplotlib figure
        stage: str
            The current stage of execution (default: None)
        step: uint
            The current step number (default: None)
        throttle: bool
            Whether to respect log throttling or not (default is True)

        """
        self._check_stage(stage)
        if throttle and not self._check_throttle():
            return
        for backend in self.loggers:
            if not hasattr(backend, "log_figure"):
                continue
            backend.log_figure(stage, step, name, figure)
        self.last_log_time = time()
--------------------------------------------------------------------------------
/baseline_models/icebeem/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import pathlib
5 | import pickle
6 | import json
7 |
8 | try:
9 | from comet_ml import Experiment
10 | COMET_AVAIL = True
11 | except:
12 | COMET_AVAIL = False
13 | import numpy as np
14 | import torch
15 |
16 | sys.path.insert(0, str(pathlib.Path(__file__).parent.parent.parent))
17 | from baseline_models.icebeem.models.ivae.ivae_wrapper import IVAE_wrapper
18 | from baseline_models.icebeem.models.icebeem_wrapper import ICEBEEM_wrapper
19 | from train import get_dataset, get_loader
20 | from universal_logger.logger import UniversalLogger
21 | from metrics import evaluate_disentanglement
22 |
def parse_sim():
    """Parse the command-line arguments of the iVAE / ICE-BeeM baseline script.

    Reads ``sys.argv``. ``--output_dir`` and ``--dataset`` are required.

    Returns
    -------
    argparse.Namespace
        The parsed arguments; ``--no_cuda`` is stored as ``cuda`` (True
        unless the flag is given).
    """
    parser = argparse.ArgumentParser(description='')
    parser.add_argument("--output_dir", required=True,
                        help="Directory to output logs and model checkpoints")
    parser.add_argument("--dataset", type=str, required=True,
                        help="Type of the dataset to be used 'toy-MANIFOLD/TRANSITION_MODEL'")
    parser.add_argument('--method', type=str, default='icebeem', choices=['icebeem', 'ivae'],
                        help="Method to employ. Should be 'icebeem' or 'ivae'")
    parser.add_argument("--dataroot", type=str, default="./",
                        help="path to dataset")
    parser.add_argument("--gt_z_dim", type=int, default=5,
                        help="ground truth dimensionality of z (for TRANSITION_MODEL == 'linear_system')")
    parser.add_argument("--gt_x_dim", type=int, default=10,
                        help="ground truth dimensionality of x (for MANIFOLD == 'nn')")
    # type=float is kept so that values such as "1e6" remain accepted
    parser.add_argument("--num_samples", type=float, default=int(1e6),
                        help="Number of samples")
    parser.add_argument("--rand_g_density", type=float, default=None,
                        help="Probability of sampling an edge. When None, the graph is set to a default (or to gt_graph_name).")
    parser.add_argument("--gt_graph_name", type=str, default=None,
                        help="Name of the ground-truth graph to use in synthetic data.")
    parser.add_argument("--architecture", type=str, default='ilcm_tabular', choices=['ilcm_tabular'],
                        help="encoder/decoder architecture.")
    parser.add_argument("--learn_decoder_var", action='store_true',
                        help="Whether to learn the variance of the output decoder")
    parser.add_argument("--train_prop", type=float, default=None,
                        help="proportion of all samples used in training set")
    parser.add_argument("--valid_prop", type=float, default=0.10,
                        help="proportion of all samples used in validation set")
    parser.add_argument("--test_prop", type=float, default=0.10,
                        help="proportion of all samples used in test set")
    parser.add_argument("--batch_size", type=int, default=1024,
                        help="batch size used during training")
    parser.add_argument("--eval_batch_size", type=int, default=1024,
                        help="batch size used during evaluation")
    parser.add_argument("--time_limit", type=float, default=None,
                        help="After this amount of time, terminate training. (in hours)")
    parser.add_argument("--max_iter", type=int, default=int(1e6),
                        help="Maximal amount of iterations")
    parser.add_argument("--ivae_lr", type=float, default=1e-4,
                        help="Learning rate used to train the iVAE model")
    parser.add_argument("--seed", type=int, default=0,
                        help="manual seed")
    parser.add_argument('--no_print', action="store_true",
                        help='do not print')
    parser.add_argument('--comet_key', type=str, default=None,
                        help="comet api-key")
    parser.add_argument('--comet_tag', type=str, default=None,
                        help="comet tag, to ease comparison")
    parser.add_argument('--comet_workspace', type=str, default=None,
                        help="comet workspace")
    parser.add_argument('--comet_project_name', type=str, default=None,
                        help="comet project_name")
    parser.add_argument("--no_cuda", action="store_false", dest="cuda",
                        help="Disables cuda")

    return parser.parse_args()
88 |
def main(args):
    """Train the iVAE baseline, evaluate it and save the results.

    Writes ``hparams.json``, the logger output, ``z_hat_final.npy`` and
    ``z_gt_final.npy`` into ``args.output_dir``.

    Parameters
    ----------
    args : argparse.Namespace
        Arguments as produced by ``parse_sim``. Several attributes
        (``mode``, ``n_lag``, ``no_norm``, ``n_workers``) are set on it here.

    Raises
    ------
    ValueError
        If ``args.method`` is anything other than ``'ivae'``.
    """
    print('WARNING: this code do not support discrete auxiliary variable. See warning in mean_corr_coef function in metrics.py')
    print('Running {} experiments using {}'.format(args.dataset, args.method))

    device = torch.device('cuda:0' if args.cuda else 'cpu')

    if args.method.lower() == 'ivae':
        args.mode = "ivae"
    else:
        raise ValueError('Unsupported method {}'.format(args.method))

    ## ---- Save hparams ---- ##
    # The output directory may not exist yet; create it before the first
    # file is written into it.
    os.makedirs(args.output_dir, exist_ok=True)
    kwargs = vars(args)
    with open(os.path.join(args.output_dir, "hparams.json"), "w") as fp:
        json.dump(kwargs, fp, sort_keys=True, indent=4)

    ## ---- Data ---- ##
    # attributes read by get_dataset / get_loader
    args.n_lag = 0
    args.no_norm = False
    args.n_workers = 0  # can't put it to 4 since we get weird error msg...
    image_shape, cont_c_dim, disc_c_dim, disc_c_n_values, train_dataset, valid_dataset, test_dataset = get_dataset(args)
    _, _, test_loader = get_loader(args, train_dataset, valid_dataset, test_dataset)
    x, y, s = train_dataset.x.cpu().numpy(), train_dataset.c.cpu().numpy(), train_dataset.z.cpu().numpy()

    ## ---- Logger ---- ##
    if COMET_AVAIL and args.comet_key is not None:
        comet_exp = Experiment(api_key=args.comet_key, project_name=args.comet_project_name,
                               workspace=args.comet_workspace, auto_metric_logging=False, auto_param_logging=False)
        comet_exp.log_parameters(vars(args))
        if args.comet_tag is not None:
            comet_exp.add_tag(args.comet_tag)
    else:
        comet_exp = None
    logger = UniversalLogger(comet=comet_exp,
                             stdout=(not args.no_print),
                             json=args.output_dir, throttle=None)

    ## ---- Running ---- ##
    # the argument `architecture="ilcm_tabular"` will choose the same encoder/decoder as in ilcm for synthetic experiments.
    model = IVAE_wrapper(x, y, args.gt_z_dim, ckpt_folder=args.output_dir, batch_size=args.batch_size, max_iter=args.max_iter,
                         seed=args.seed, n_layers=5, hidden_dim=512, lr=args.ivae_lr,
                         architecture=args.architecture, logger=logger, time_limit=args.time_limit, learn_decoder_var=args.learn_decoder_var)

    ## ---- Evaluate performance ---- ##
    mcc, consistent_r, r, cc, C_hat, C_pattern, perm_mat, z, z_hat, transposed_consistent_r = evaluate_disentanglement(model, test_loader, device, args)

    ## ---- Save ---- ##
    # save scores
    metrics = {"mean_corr_coef_final": mcc,
               "consistent_r_final": consistent_r,
               "r_final": r,
               "transposed_consistent_r_final": transposed_consistent_r}
    logger.log_metrics(step=0, metrics=metrics)

    # save both ground_truth and learned latents
    np.save(os.path.join(args.output_dir, "z_hat_final.npy"), z_hat)
    np.save(os.path.join(args.output_dir, "z_gt_final.npy"), z)
146 |
147 | if __name__ == '__main__':
148 | args = parse_sim()
149 | main(args)
150 |
--------------------------------------------------------------------------------
/baseline_models/slowvae_pcl/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import shutil
3 | import os, json, sys, traceback, time
4 | import pathlib
5 |
6 | try:
7 | from comet_ml import Experiment
8 | COMET_AVAIL = True
9 | except:
10 | COMET_AVAIL = False
11 | import numpy as np
12 | import torch
13 | import datetime
14 |
15 | sys.path.insert(0, str(pathlib.Path(__file__).parent))
16 | from scripts.solver import Solver
17 |
18 | sys.path.insert(0, str(pathlib.Path(__file__).parent.parent.parent))
19 | from train import get_dataset, get_loader
20 | from universal_logger.logger import UniversalLogger
21 | from metrics import evaluate_disentanglement
22 |
23 | torch.backends.cudnn.enabled = True
24 | torch.backends.cudnn.benchmark = True
25 |
26 |
def main(args, writer=None):
    """Train a SlowVAE / PCL baseline, evaluate it and save the results.

    Seeds torch and numpy from ``args.seed``, builds the datasets and
    loaders, trains through ``Solver`` and writes ``hparams.json``,
    ``args``, ``z_hat_final.npy`` and ``z_gt_final.npy`` into
    ``args.output_dir``. Several attributes (``no_norm``, ``n_lag``,
    ``num_workers``, ``num_channel``, ``mode``) are set on ``args`` here.

    ``writer`` is only forwarded to ``Solver.train``. Nothing is returned.
    """
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    device = torch.device('cuda:0' if args.cuda else 'cpu')

    ## ---- Data ---- ##
    args.no_norm = False
    args.n_lag = 0  # get dataset expects this argument, but no effect.
    args.num_workers = args.n_workers
    image_shape, cont_c_dim, disc_c_dim, disc_c_n_values, train_dataset, valid_dataset, test_dataset, = get_dataset(args)
    train_loader, valid_loader, test_loader = get_loader(args, train_dataset, valid_dataset, test_dataset)
    data_loader = train_loader
    # a 3-element image_shape is treated as an image whose last entry is the channel count
    if len(image_shape) == 3:
        args.num_channel = image_shape[-1]
    else:
        args.num_channel = None

    ## ---- Logging ---- ##
    # comet is only used when the package and all three credentials are available
    if COMET_AVAIL and args.comet_key is not None and args.comet_workspace is not None and args.comet_project_name is not None:
        comet_exp = Experiment(api_key=args.comet_key, project_name=args.comet_project_name,
                               workspace=args.comet_workspace, auto_metric_logging=False, auto_param_logging=False)
        comet_exp.log_parameters(vars(args))
        if args.comet_tag is not None:
            comet_exp.add_tag(args.comet_tag)
    else:
        comet_exp = None
    logger = UniversalLogger(comet=comet_exp,
                             stdout=(not args.no_print),
                             json=args.output_dir, throttle=None)

    t0 = time.time()

    # saving hp
    ## ---- Save hparams ---- ##
    if args.pcl:
        args.mode = "pcl"
    else:
        args.mode = 'slowvae'
    # the same namespace is dumped twice: pretty-printed and compact
    kwargs = vars(args)
    with open(os.path.join(args.output_dir, "hparams.json"), "w") as fp:
        json.dump(kwargs, fp, sort_keys=True, indent=4)
    with open(os.path.join(args.output_dir, "args"), "w") as f:
        json.dump(args.__dict__, f)

    net = Solver(args, image_shape, data_loader=data_loader, logger=logger, z_dim=train_dataset.z_dim)
    failure = net.train(writer)
    if failure:
        print('failed in %.2fs' % (time.time() - t0))
        #shutil.rmtree(args.output_dir)
    else:
        print('done in %.2fs' % (time.time() - t0))

    # NOTE(review): the evaluation below is assumed to run regardless of
    # `failure` (i.e. not nested in the `else` above) — confirm.
    ## ---- Evaluate performance ---- #
    mcc, consistent_r, r, cc, C_hat, C_pattern, perm_mat, z, z_hat, transposed_consistent_r = evaluate_disentanglement(net.net, test_loader, device, args)

    ## ---- Save ---- ##
    # save scores
    metrics = {"mean_corr_coef_final": mcc,
               "consistent_r_final": consistent_r,
               "r_final": r,
               "transposed_consistent_r_final": transposed_consistent_r}
    logger.log_metrics(step=0, metrics=metrics)

    # save both ground_truth and learned latents
    np.save(os.path.join(args.output_dir, "z_hat_final.npy"), z_hat)
    np.save(os.path.join(args.output_dir, "z_gt_final.npy"), z)
95 |
96 |
97 | ### For Random Search ###
def randint(low, high):
    """Draw one integer from ``[low, high)`` with numpy's global RNG.

    Returns a built-in ``int``. The builtin is used for the conversion
    because the ``np.int`` alias no longer exists in current NumPy releases.
    """
    return int(np.random.randint(low, high, 1)[0])
100 |
def uniform(low, high):
    """Draw one float uniformly from ``[low, high)`` with numpy's global RNG."""
    draws = np.random.uniform(low, high, size=1)
    return draws[0]
103 |
def loguniform(low, high):
    """Draw one float whose logarithm is uniform on ``[log(low), log(high))``."""
    log_low = np.log(low)
    log_high = np.log(high)
    log_draw = np.random.uniform(log_low, log_high, 1)
    return np.exp(log_draw)[0]
106 |
if __name__ == "__main__":
    # Command-line entry point: build the argument parser and run `main`.
    parser = argparse.ArgumentParser(description='slowVAE')
    parser.add_argument('--pcl', action='store_true')
    parser.add_argument('--r_func', type=str, default='default', choices=['default', 'mlp'],
                        help='Type of regression function used for PCL')
    parser.add_argument("--output_dir", required=True,
                        help="Directory to output logs and model checkpoints")
    parser.add_argument("--dataset", type=str, required=True,
                        help="Type of the dataset to be used 'toy-MANIFOLD/TRANSITION_MODEL'")
    parser.add_argument("--dataroot", type=str, default="./",
                        help="path to dataset")
    parser.add_argument("--gt_z_dim", type=int, default=10,
                        help="ground truth dimensionality of z (for TRANSITION_MODEL == 'linear_system')")
    parser.add_argument("--gt_x_dim", type=int, default=20,
                        help="ground truth dimensionality of x (for MANIFOLD == 'nn')")
    # type=float is kept so that values such as "1e6" remain accepted
    parser.add_argument("--num_samples", type=float, default=int(1e6),
                        help="number of trajectories in toy dataset")
    parser.add_argument("--architecture", type=str, default='ilcm_tabular', choices=['ilcm_tabular', 'standard_conv'],
                        help="VAE encoder/decoder architecture.")
    parser.add_argument("--train_prop", type=float, default=None,
                        help="proportion of all samples used in training set")
    parser.add_argument("--valid_prop", type=float, default=0.10,
                        help="proportion of all samples used in validation set")
    parser.add_argument("--test_prop", type=float, default=0.10,
                        help="proportion of all samples used in test set")
    parser.add_argument("--n_workers", type=int, default=4)
    parser.add_argument("--time_limit", type=float, default=None,
                        help="After this amount of time, terminate training. (in hours)")
    parser.add_argument("--max_iter", type=int, default=int(1e6),
                        help="Maximal amount of iterations")
    parser.add_argument("--seed", type=int, default=0,
                        help="manual seed")
    parser.add_argument('--no_print', action="store_true",
                        help='do not print')
    parser.add_argument('--comet_key', type=str, default=None,
                        help="comet api-key")
    parser.add_argument('--comet_tag', type=str, default=None,
                        help="comet tag, to ease comparison")
    parser.add_argument('--comet_workspace', type=str, default=None,
                        help="comet workspace")
    parser.add_argument('--comet_project_name', type=str, default=None,
                        help="comet project_name")
    parser.add_argument("--rand_g_density", type=float, default=None,
                        help="Probability of sampling an edge. When None, the graph is set to a default (or to gt_graph_name).")
    parser.add_argument("--gt_graph_name", type=str, default=None,
                        help="Name of the ground-truth graph to use in synthetic data.")
    parser.add_argument("--add_noise", type=float, default=0.0,
                        help="Add normal noise sigma = add_noise on images (only training data)")
    parser.add_argument("--no_cuda", action="store_false", dest="cuda",
                        help="Disables cuda")
    parser.add_argument("--batch_size", type=int, default=1024,
                        help="batch size used during training")
    parser.add_argument("--eval_batch_size", type=int, default=1024,
                        help="batch size used during evaluation")
    parser.add_argument('--beta', default=1, type=float,
                        help='weight for kl to normal')
    parser.add_argument('--gamma', default=10, type=float,
                        help='weight for kl to laplace')
    parser.add_argument('--rate_prior', default=6, type=float,
                        help='rate (or inverse scale) for prior laplace (larger -> sparser).')
    parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
    parser.add_argument('--beta1', default=0.9, type=float,
                        help='Adam optimizer beta1')
    parser.add_argument('--beta2', default=0.999, type=float,
                        help='Adam optimizer beta2')
    parser.add_argument('--ckpt-name', default='last', type=str,
                        help='load previous checkpoint. insert checkpoint filename')
    parser.add_argument('--log_step', default=100, type=int,
                        help='number of iterations after which data is logged')
    parser.add_argument('--save_step', default=10000, type=int,
                        help='number of iterations after which a checkpoint is saved')
    args = parser.parse_args()

    # main() has no return value, so its result is not bound to anything
    main(args)
181 |
--------------------------------------------------------------------------------
/baseline_models/beta-tcvae/elbo_decomposition.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | from numbers import Number
4 | from tqdm import tqdm
5 | import torch
6 | from torch.autograd import Variable
7 |
8 | import lib.dist as dist
9 | import lib.flows as flows
10 |
11 |
def estimate_entropies(qz_samples, qz_params, q_dist):
    """Computes the term:
        E_{p(x)} E_{q(z|x)} [-log q(z)]
    and
        E_{p(x)} E_{q(z_j|x)} [-log q(z_j)]
    where q(z) = 1/N sum_n=1^N q(z|x_n).
    Assumes samples are from q(z|x) for *all* x in the dataset.
    Assumes that q(z|x) is factorial ie. q(z|x) = prod_j q(z_j|x).

    Computes numerically stable NLL:
        - log q(z) = log N - logsumexp_n=1^N log q(z|x_n)

    At most 10000 randomly chosen samples are used. Index and accumulator
    tensors are moved to CUDA only when ``qz_samples`` is a CUDA tensor, so
    the function follows the device of its input.

    Inputs:
    -------
        qz_samples (K, S) Variable
        qz_params  (N, K, nparams) Variable
        q_dist     object providing ``nparams`` and ``log_density``

    Returns:
    --------
        marginal_entropies (K,) tensor
        joint_entropy      (1,) tensor
    """
    on_cuda = qz_samples.is_cuda

    # Only take a sample subset of the samples
    subset = torch.randperm(qz_samples.size(1))[:10000]
    if on_cuda:
        subset = subset.cuda()
    qz_samples = qz_samples.index_select(1, Variable(subset))

    K, S = qz_samples.size()
    N, _, nparams = qz_params.size()
    assert(nparams == q_dist.nparams)
    assert(K == qz_params.size(1))

    marginal_entropies = torch.zeros(K)
    joint_entropy = torch.zeros(1)
    if on_cuda:
        marginal_entropies = marginal_entropies.cuda()
        joint_entropy = joint_entropy.cuda()

    pbar = tqdm(total=S)
    k = 0
    while k < S:
        # process the samples in chunks of (at most) 10 columns
        batch_size = min(10, S - k)
        logqz_i = q_dist.log_density(
            qz_samples.view(1, K, S).expand(N, K, S)[:, :, k:k + batch_size],
            qz_params.view(N, K, 1, nparams).expand(N, K, S, nparams)[:, :, k:k + batch_size])
        k += batch_size

        # computes - log q(z_i) summed over minibatch
        marginal_entropies += (math.log(N) - logsumexp(logqz_i, dim=0, keepdim=False).data).sum(1)
        # computes - log q(z) summed over minibatch
        logqz = logqz_i.sum(1)  # (N, S)
        joint_entropy += (math.log(N) - logsumexp(logqz, dim=0, keepdim=False).data).sum(0)
        pbar.update(batch_size)
    pbar.close()

    marginal_entropies /= S
    joint_entropy /= S

    return marginal_entropies, joint_entropy
62 |
63 |
def logsumexp(value, dim=None, keepdim=False):
    """Numerically stable implementation of the operation

    value.exp().sum(dim, keepdim).log()

    The maximum is subtracted before exponentiating and added back
    afterwards. With ``dim=None`` the reduction runs over all elements.
    """
    if dim is None:
        peak = torch.max(value)
        total = torch.sum(torch.exp(value - peak))
        if isinstance(total, Number):
            # the reduction produced a plain Python number
            return peak + math.log(total)
        return peak + torch.log(total)

    peak, _ = torch.max(value, dim=dim, keepdim=True)
    shifted = value - peak
    if keepdim is False:
        # drop the reduced dimension so that `peak` matches the sum below
        peak = peak.squeeze(dim)
    reduced = torch.sum(torch.exp(shifted), dim=dim, keepdim=keepdim)
    return peak + torch.log(reduced)
83 |
84 |
def analytical_NLL(qz_params, q_dist, prior_dist, qz_samples=None):
    """Computes the quantities
        1/N sum_n=1^N E_{q(z|x)} [ - log q(z|x) ]
    and
        1/N sum_n=1^N E_{q(z_j|x)} [ - log p(z_j) ]

    Both are obtained from the ``NLL`` methods of the given distributions
    and averaged over the first dimension. The prior parameters are all
    zeros with the shape and type of ``qz_params``. ``qz_samples`` is
    accepted but not used.

    Inputs:
    -------
        qz_params  (N, K, nparams) Variable

    Returns:
    --------
        nlogqz_condx (K,) Variable
        nlogpz (K,) Variable
    """
    zero = torch.zeros(1).type_as(qz_params.data)
    pz_params = Variable(zero.expand(qz_params.size()), volatile=True)

    per_sample_nlogqz = q_dist.NLL(qz_params)
    nlogqz_condx = per_sample_nlogqz.mean(0)
    per_sample_nlogpz = prior_dist.NLL(pz_params, qz_params)
    nlogpz = per_sample_nlogpz.mean(0)
    return nlogqz_condx, nlogpz
105 |
106 |
def elbo_decomposition(vae, dataset_loader):
    """Estimate the terms of the ELBO decomposition of ``vae`` over a dataset.

    Encodes every batch from ``dataset_loader``, accumulates the
    reconstruction log-density, samples from the encoder distributions,
    estimates marginal/joint entropies with ``estimate_entropies`` and
    combines them into the dependence, information and dimension-wise KL
    terms. The individual terms are printed before returning.

    Args:
        vae: model exposing ``z_dim``, ``encoder``, ``decoder``, ``q_dist``,
            ``prior_dist``, ``x_dist`` and ``_get_prior_params``.
        dataset_loader: iterable of input batches with a ``dataset`` attribute.
            Batches are reshaped to ``(batch, -1, 64, 64)`` and moved to CUDA.

    Returns:
        Tuple ``(logpx, dependence, information, dimwise_kl,
        analytical_cond_kl, marginal_entropies, joint_entropy)``.
    """
    num_points = len(dataset_loader.dataset)   # number of data samples
    latent_dim = vae.z_dim                     # number of latent variables
    num_draws = 1                              # latent samples per data point
    nparams = vae.q_dist.nparams

    print('Computing q(z|x) distributions.')
    # Buffer for the parameters of every marginal q(z_j|x_n).
    qz_params = torch.Tensor(num_points, latent_dim, nparams)
    filled = 0
    logpx = 0
    for xs in dataset_loader:
        bsz = xs.size(0)
        xs = Variable(xs.view(bsz, -1, 64, 64).cuda(), volatile=True)
        z_params = vae.encoder.forward(xs).view(bsz, latent_dim, nparams)
        qz_params[filled:filled + bsz] = z_params.data
        filled += bsz

        # Accumulate the reconstruction term for this batch.
        for _ in range(num_draws):
            z = vae.q_dist.sample(params=z_params)
            x_params = vae.decoder.forward(z)
            batch_logpx = vae.x_dist.log_density(xs, params=x_params)
            logpx += batch_logpx.view(bsz, -1).data.sum()
    # Reconstruction term, averaged over data points and draws.
    logpx = logpx / (num_points * num_draws)

    qz_params = Variable(qz_params.cuda(), volatile=True)

    print('Sampling from q(z).')
    # Draw num_draws samples from each marginal q(z_j|x_n).
    qz_params_expanded = qz_params.view(num_points, latent_dim, 1, nparams).expand(
        num_points, latent_dim, num_draws, nparams)
    qz_samples = vae.q_dist.sample(params=qz_params_expanded)
    qz_samples = qz_samples.transpose(0, 1).contiguous().view(latent_dim, num_points * num_draws)

    print('Estimating entropies.')
    marginal_entropies, joint_entropy = estimate_entropies(qz_samples, qz_params, vae.q_dist)

    if hasattr(vae.q_dist, 'NLL'):
        nlogqz_condx = vae.q_dist.NLL(qz_params).mean(0)
    else:
        # NOTE(review): viewing the (K, N, S, nparams) tensor as (K, N * S)
        # looks like it can only succeed when nparams == 1 -- confirm.
        flat_params = qz_params_expanded.transpose(0, 1).contiguous().view(
            latent_dim, num_points * num_draws)
        nlogqz_condx = - vae.q_dist.log_density(qz_samples, flat_params).mean(1)

    if hasattr(vae.prior_dist, 'NLL'):
        pz_params = vae._get_prior_params(num_points * latent_dim).contiguous().view(
            num_points, latent_dim, -1)
        nlogpz = vae.prior_dist.NLL(pz_params, qz_params).mean(0)
    else:
        nlogpz = - vae.prior_dist.log_density(qz_samples.transpose(0, 1)).mean(0)

    nlogqz_condx = nlogqz_condx.data
    nlogpz = nlogpz.data

    # Independence term
    # KL(q(z)||prod_j q(z_j)) = log q(z) - sum_j log q(z_j)
    dependence = (- joint_entropy + marginal_entropies.sum())[0]

    # Information term
    # KL(q(z|x)||q(z)) = log q(z|x) - log q(z)
    information = (- nlogqz_condx.sum() + joint_entropy)[0]

    # Dimension-wise KL term
    # sum_j KL(q(z_j)||p(z_j)) = sum_j (log q(z_j) - log p(z_j))
    dimwise_kl = (- marginal_entropies + nlogpz).sum()

    # Sum of the terms, computed analytically
    # KL(q(z|x)||p(z)) = log q(z|x) - log p(z)
    analytical_cond_kl = (- nlogqz_condx + nlogpz).sum()

    print('Dependence: {}'.format(dependence))
    print('Information: {}'.format(information))
    print('Dimension-wise KL: {}'.format(dimwise_kl))
    print('Analytical E_p(x)[ KL(q(z|x)||p(z)) ]: {}'.format(analytical_cond_kl))
    print('Estimated ELBO: {}'.format(logpx - analytical_cond_kl))

    return logpx, dependence, information, dimwise_kl, analytical_cond_kl, marginal_entropies, joint_entropy
183 |
184 |
if __name__ == '__main__':
    # Command-line entry point: load a checkpoint, run the ELBO decomposition
    # and save every returned term to <save>/elbo_decomposition.pth.
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('-checkpt', required=True)
    parser.add_argument('-save', type=str, default='.')
    parser.add_argument('-gpu', type=int, default=0)
    args = parser.parse_args()

    def load_model_and_dataset(checkpt_filename):
        """Rebuild the VAE and its data loader from a saved checkpoint.

        Args:
            checkpt_filename: path of a checkpoint holding the keys
                ``'args'`` and ``'state_dict'``.

        Returns:
            ``(vae, loader)`` with the model switched to eval mode.

        Raises:
            ValueError: if the checkpoint's ``dist`` setting is not one of
                ``'normal'``, ``'laplace'`` or ``'flow'``.
        """
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpt = torch.load(checkpt_filename)
        # Named ckpt_args so it does not shadow the command-line ``args``.
        ckpt_args = checkpt['args']
        state_dict = checkpt['state_dict']

        # backwards compatibility
        if not hasattr(ckpt_args, 'conv'):
            ckpt_args.conv = False

        from vae_quant import VAE, setup_data_loaders

        # model
        if ckpt_args.dist == 'normal':
            prior_dist = dist.Normal()
            q_dist = dist.Normal()
        elif ckpt_args.dist == 'laplace':
            prior_dist = dist.Laplace()
            q_dist = dist.Laplace()
        elif ckpt_args.dist == 'flow':
            prior_dist = flows.FactorialNormalizingFlow(dim=ckpt_args.latent_dim, nsteps=32)
            q_dist = dist.Normal()
        else:
            # Previously this fell through and failed later with an unbound
            # local name; fail early with an explicit message instead.
            raise ValueError('Unsupported latent distribution: {!r}'.format(ckpt_args.dist))
        vae = VAE(z_dim=ckpt_args.latent_dim, use_cuda=True, prior_dist=prior_dist, q_dist=q_dist,
                  conv=ckpt_args.conv)
        vae.load_state_dict(state_dict, strict=False)
        vae.eval()

        # dataset loader
        loader = setup_data_loaders(ckpt_args, use_cuda=True)
        return vae, loader

    torch.cuda.set_device(args.gpu)
    vae, dataset_loader = load_model_and_dataset(args.checkpt)
    logpx, dependence, information, dimwise_kl, analytical_cond_kl, marginal_entropies, joint_entropy = \
        elbo_decomposition(vae, dataset_loader)
    torch.save({
        'logpx': logpx,
        'dependence': dependence,
        'information': information,
        'dimwise_kl': dimwise_kl,
        'analytical_cond_kl': analytical_cond_kl,
        'marginal_entropies': marginal_entropies,
        'joint_entropy': joint_entropy
    }, os.path.join(args.save, 'elbo_decomposition.pth'))
235 |
--------------------------------------------------------------------------------
/baseline_models/beta-tcvae/disentanglement_metrics.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import torch
4 | from tqdm import tqdm
5 | from torch.utils.data import DataLoader
6 | from torch.autograd import Variable
7 |
8 | import lib.utils as utils
9 | from metric_helpers.loader import load_model_and_dataset
10 | from metric_helpers.mi_metric import compute_metric_shapes, compute_metric_faces
11 |
12 |
def estimate_entropies(qz_samples, qz_params, q_dist, n_samples=10000, weights=None):
    """Computes the term:
        E_{p(x)} E_{q(z|x)} [-log q(z)]
    and
        E_{p(x)} E_{q(z_j|x)} [-log q(z_j)]
    where q(z) = 1/N sum_n=1^N q(z|x_n).
    Assumes samples are from q(z|x) for *all* x in the dataset.
    Assumes that q(z|x) is factorial ie. q(z|x) = prod_j q(z_j|x).

    Computes numerically stable NLL:
        - log q(z) = log N - logsumexp_n=1^N log q(z|x_n)

    Only a subset of the sample columns is used: a random permutation
    truncated to ``n_samples`` when ``weights`` is None, otherwise
    ``n_samples`` multinomial draws (with replacement) from ``weights``.
    Samples are processed ten at a time and the result is averaged over
    the number of samples used.

    Inputs:
    -------
        qz_samples (K, N) Variable
        qz_params  (N, K, nparams) Variable
        weights (N) Variable

    Returns:
    --------
        entropies (K,) CUDA tensor
    """
    uniform = weights is None

    # Only take a sample subset of the samples
    if uniform:
        chosen = Variable(torch.randperm(qz_samples.size(1))[:n_samples].cuda())
    else:
        chosen = torch.multinomial(weights, n_samples, replacement=True)
    qz_samples = qz_samples.index_select(1, chosen)

    K, S = qz_samples.size()
    N, _, nparams = qz_params.size()
    assert(nparams == q_dist.nparams)
    assert(K == qz_params.size(1))

    # Log mixture weights added to log q(z|x_n) before the logsumexp over n.
    if uniform:
        weights = -math.log(N)
    else:
        weights = torch.log(weights.view(N, 1, 1) / weights.sum())

    entropies = torch.zeros(K).cuda()

    # Broadcast views pairing every sample with every data point's parameters.
    samples_all = qz_samples.view(1, K, S).expand(N, K, S)
    params_all = qz_params.view(N, K, 1, nparams).expand(N, K, S, nparams)

    pbar = tqdm(total=S)
    for start in range(0, S, 10):
        stop = min(start + 10, S)
        logqz_i = q_dist.log_density(
            samples_all[:, :, start:stop],
            params_all[:, :, start:stop])

        # computes - log q(z_i) summed over minibatch
        entropies += - utils.logsumexp(logqz_i + weights, dim=0, keepdim=False).data.sum(1)
        pbar.update(stop - start)
    pbar.close()

    entropies /= S

    return entropies
68 |
69 |
def mutual_info_metric_shapes(vae, shapes_dataset):
    """Compute the mutual-information metric on the shapes dataset.

    Encodes the whole dataset, reshapes the encoder parameters onto the
    factor grid ``(3, 6, 40, 32, 32)``, estimates the marginal entropies of
    the latents and, for each of the four factors indexed by axes 1-4, the
    conditional entropies averaged over that factor's values.

    Args:
        vae: model exposing ``z_dim``, ``q_dist``, ``encoder`` and ``eval``.
        shapes_dataset: dataset wrapped in a non-shuffling ``DataLoader``;
            batches are reshaped to ``(batch, 1, 64, 64)`` and moved to CUDA.

    Returns:
        ``(metric, marginal_entropies, cond_entropies)`` where
        ``cond_entropies`` has shape ``(4, K)``.
    """
    dataset_loader = DataLoader(shapes_dataset, batch_size=1000, num_workers=1, shuffle=False)

    N = len(dataset_loader.dataset)  # number of data samples
    K = vae.z_dim                    # number of latent variables
    nparams = vae.q_dist.nparams
    vae.eval()

    print('Computing q(z|x) distributions.')
    qz_params = torch.Tensor(N, K, nparams)

    cursor = 0
    for xs in dataset_loader:
        bsz = xs.size(0)
        xs = Variable(xs.view(bsz, 1, 64, 64).cuda(), volatile=True)
        qz_params[cursor:cursor + bsz] = vae.encoder.forward(xs).view(bsz, vae.z_dim, nparams).data
        cursor += bsz

    qz_params = Variable(qz_params.view(3, 6, 40, 32, 32, K, nparams).cuda())
    qz_samples = vae.q_dist.sample(params=qz_params)

    print('Estimating marginal entropies.')
    marginal_entropies = estimate_entropies(
        qz_samples.view(N, K).transpose(0, 1),
        qz_params.view(N, K, nparams),
        vae.q_dist)

    marginal_entropies = marginal_entropies.cpu()
    cond_entropies = torch.zeros(4, K)

    # (progress message, grid axis of the factor, number of factor values)
    factors = (
        ('Estimating conditional entropies for scale.', 1, 6),
        ('Estimating conditional entropies for orientation.', 2, 40),
        ('Estimating conditional entropies for pos x.', 3, 32),
        ('Estimating conditional entropies for pox y.', 4, 32),
    )
    for row, (message, axis, n_values) in enumerate(factors):
        print(message)
        for value in range(n_values):
            # Fix the factor to one value and pool over all the others.
            samples_fixed = qz_samples.select(axis, value).contiguous()
            params_fixed = qz_params.select(axis, value).contiguous()

            cond_entropies_i = estimate_entropies(
                samples_fixed.view(N // n_values, K).transpose(0, 1),
                params_fixed.view(N // n_values, K, nparams),
                vae.q_dist)

            cond_entropies[row] += cond_entropies_i.cpu() / n_values

    metric = compute_metric_shapes(marginal_entropies, cond_entropies)
    return metric, marginal_entropies, cond_entropies
151 |
152 |
def mutual_info_metric_faces(vae, shapes_dataset):
    """Compute the mutual-information metric on the faces dataset.

    Encodes the whole dataset, reshapes the encoder parameters onto the
    factor grid ``(50, 21, 11, 11)``, estimates the marginal entropies of
    the latents and, for each of the three factors indexed by axes 1-3, the
    conditional entropies averaged over that factor's values.

    Args:
        vae: model exposing ``z_dim``, ``q_dist``, ``encoder`` and ``eval``.
        shapes_dataset: dataset wrapped in a non-shuffling ``DataLoader``;
            batches are reshaped to ``(batch, 1, 64, 64)`` and moved to CUDA.

    Returns:
        ``(metric, marginal_entropies, cond_entropies)`` where
        ``cond_entropies`` has shape ``(3, K)``.
    """
    dataset_loader = DataLoader(shapes_dataset, batch_size=1000, num_workers=1, shuffle=False)

    N = len(dataset_loader.dataset)  # number of data samples
    K = vae.z_dim                    # number of latent variables
    nparams = vae.q_dist.nparams
    vae.eval()

    print('Computing q(z|x) distributions.')
    qz_params = torch.Tensor(N, K, nparams)

    cursor = 0
    for xs in dataset_loader:
        bsz = xs.size(0)
        xs = Variable(xs.view(bsz, 1, 64, 64).cuda(), volatile=True)
        qz_params[cursor:cursor + bsz] = vae.encoder.forward(xs).view(bsz, vae.z_dim, nparams).data
        cursor += bsz

    qz_params = Variable(qz_params.view(50, 21, 11, 11, K, nparams).cuda())
    qz_samples = vae.q_dist.sample(params=qz_params)

    print('Estimating marginal entropies.')
    marginal_entropies = estimate_entropies(
        qz_samples.view(N, K).transpose(0, 1),
        qz_params.view(N, K, nparams),
        vae.q_dist)

    marginal_entropies = marginal_entropies.cpu()
    cond_entropies = torch.zeros(3, K)

    # (progress message, grid axis of the factor, number of factor values)
    factors = (
        ('Estimating conditional entropies for azimuth.', 1, 21),
        ('Estimating conditional entropies for elevation.', 2, 11),
        ('Estimating conditional entropies for lighting.', 3, 11),
    )
    for row, (message, axis, n_values) in enumerate(factors):
        print(message)
        for value in range(n_values):
            # Fix the factor to one value and pool over all the others.
            samples_fixed = qz_samples.select(axis, value).contiguous()
            params_fixed = qz_params.select(axis, value).contiguous()

            cond_entropies_i = estimate_entropies(
                samples_fixed.view(N // n_values, K).transpose(0, 1),
                params_fixed.view(N // n_values, K, nparams),
                vae.q_dist)

            cond_entropies[row] += cond_entropies_i.cpu() / n_values

    metric = compute_metric_faces(marginal_entropies, cond_entropies)
    return metric, marginal_entropies, cond_entropies
222 |
223 |
if __name__ == '__main__':
    # Command-line entry point: load a checkpoint, compute the MIG metric
    # for its dataset and save the results to <save>/disentanglement_metric.pth.
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpt', required=True)
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--save', type=str, default='.')
    args = parser.parse_args()

    if args.gpu != 0:
        torch.cuda.set_device(args.gpu)
    vae, dataset, cpargs = load_model_and_dataset(args.checkpt)

    # The dataset name comes from the checkpoint file. It used to be passed
    # through eval(), which would execute arbitrary text stored there; an
    # explicit dispatch table selects the same functions without evaluating
    # checkpoint data as code.
    metric_functions = {
        'shapes': mutual_info_metric_shapes,
        'faces': mutual_info_metric_faces,
    }
    if cpargs.dataset not in metric_functions:
        raise ValueError('No mutual information metric for dataset: {!r}'.format(cpargs.dataset))
    metric, marginal_entropies, cond_entropies = metric_functions[cpargs.dataset](vae, dataset)

    torch.save({
        'metric': metric,
        'marginal_entropies': marginal_entropies,
        'cond_entropies': cond_entropies,
    }, os.path.join(args.save, 'disentanglement_metric.pth'))
    print('MIG: {:.2f}'.format(metric))
242 |
--------------------------------------------------------------------------------
/baseline_models/icebeem/models/nets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 |
5 |
class smoothReLU(nn.Module):
    """
    smooth ReLU activation function: x / (1 + exp(-beta * x)),
    i.e. x * sigmoid(beta * x)
    """

    def __init__(self, beta=1):
        """
        Input:
         - beta : slope multiplier applied to the input inside the exponential
        """
        super().__init__()
        # Store the caller's value; it was previously hard-coded to 1, which
        # silently ignored the argument.
        self.beta = beta

    def forward(self, x):
        """Apply the activation element-wise to ``x``."""
        return x / (1 + torch.exp(-self.beta * x))
17 |
18 |
class LeafParam(nn.Module):
    """
    holds a learnable (1, n) parameter, initialised to zeros, and returns it
    broadcast over the batch; the values of the input are never read
    """

    def __init__(self, n):
        super().__init__()
        initial = torch.zeros(1, n)
        self.p = nn.Parameter(initial)

    def forward(self, x):
        # Only the batch size of x is used.
        batch = x.size(0)
        width = self.p.size(1)
        return self.p.expand(batch, width)
30 |
31 |
class PositionalEncoder(nn.Module):
    """
    Expands the input with sin/cos features at a fixed set of frequencies.
    The output concatenates, along dim 1, sin(x * f) for every frequency
    followed by cos(x * f) for every frequency.
    """

    def __init__(self, freqs=(.5, 1, 2, 4, 8)):
        super().__init__()
        self.freqs = freqs

    def forward(self, x):
        features = []
        # All sines first, then all cosines, each in frequency order.
        for trig in (torch.sin, torch.cos):
            for freq in self.freqs:
                features.append(trig(x * freq))
        return torch.cat(features, dim=1)
48 |
49 |
class MLP4(nn.Module):
    """ four linear layers with LeakyReLU(0.2) between them: nin -> nh -> nh -> nh -> nout """

    def __init__(self, nin, nout, nh):
        super().__init__()
        layers = []
        # Three hidden blocks, each a linear map followed by its activation.
        for fan_in in (nin, nh, nh):
            layers.append(nn.Linear(fan_in, nh))
            layers.append(nn.LeakyReLU(0.2))
        # Output projection, no activation.
        layers.append(nn.Linear(nh, nout))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
67 |
68 |
class PosEncMLP(nn.Module):
    """
    An MLP4 preceded by a PositionalEncoder.
    The encoder turns each input dimension into len(freqs)*2 sin/cos
    features, so the MLP4 receives nin * len(freqs) * 2 inputs.
    """

    def __init__(self, nin, nout, nh, freqs=(.5, 1, 2, 4, 8)):
        super().__init__()
        encoder = PositionalEncoder(freqs)
        encoded_width = nin * len(freqs) * 2
        body = MLP4(encoded_width, nout, nh)
        self.net = nn.Sequential(encoder, body)

    def forward(self, x):
        return self.net(x)
85 |
86 |
class MLPlayer(nn.Module):
    """
    basic building block for an MLP: optional batch norm on the input,
    a linear map, then the activation function

    when output_size is not given the layer maps a vector of dimension
    input_size to another vector of dimension input_size
    """

    def __init__(self, input_size, output_size=None, activation_function=nn.functional.relu, use_bn=False):
        super().__init__()
        out_features = input_size if output_size is None else output_size
        self.activation_function = activation_function
        self.linear_layer = nn.Linear(input_size, out_features)
        self.use_bn = use_bn
        # The batch-norm layer is created regardless of use_bn; it is only
        # applied in forward when use_bn is set.
        self.bn_layer = nn.BatchNorm1d(input_size)

    def forward(self, x):
        normed = self.bn_layer(x) if self.use_bn else x
        return self.activation_function(self.linear_layer(normed))
110 |
111 |
class MLP(nn.Module):
    """
    a configurable MLP - more general than MLP4 above, the user specifies
    the width of every hidden layer
    """

    def __init__(self, input_size, hidden_size, n_layers, output_size=None, activation_function=F.relu, use_bn=False):
        """
        Input:
         - input_size : dimension of input data (e.g., 784 for MNIST)
         - hidden_size : list of hidden representations, one entry per layer
         - n_layers : number of hidden layers
         - output_size : dimension of the output, 1 when not given
         - activation_function : activation used inside every MLPlayer
         - use_bn : whether every MLPlayer applies batch norm to its input
        """
        super().__init__()

        if output_size is None:
            output_size = 1  # because we approximating a log density, output should be scalar!

        self.use_bn = use_bn
        self.activation_function = activation_function
        # map from data dim to dimension of hidden units
        self.linear1st = nn.Linear(input_size, hidden_size[0])
        # one MLPlayer per consecutive pair of hidden widths
        hidden_blocks = []
        for idx in range(1, n_layers):
            hidden_blocks.append(
                MLPlayer(hidden_size[idx - 1], hidden_size[idx],
                         activation_function=self.activation_function, use_bn=self.use_bn))
        self.Layers = nn.ModuleList(hidden_blocks)
        # map from dimension of hidden units to dimension of output
        self.linearLast = nn.Linear(hidden_size[-1], output_size)

    def forward(self, x):
        """
        forward pass: first linear map, every MLPlayer in order, then the
        output linear map (no activation is applied directly after linear1st)
        """
        out = self.linear1st(x)
        for block in self.Layers:
            out = block(out)
        return self.linearLast(out)
148 |
149 |
class CleanMLP(nn.Module):
    """
    MLP with n_hidden hidden layers of equal width.

    Each hidden layer is Linear -> activation (-> BatchNorm1d when batch_norm
    is set); a final Linear maps to output_size. With n_hidden == 0 the
    network is a single Linear from input_size to output_size.
    """

    def __init__(self, input_size, hidden_size, n_hidden, output_size, activation='lrelu', batch_norm=False):
        """
        Input:
         - activation : 'lrelu' (LeakyReLU with slope 0.2) or 'relu';
                        anything else raises ValueError
        """
        super().__init__()

        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.n_hidden = n_hidden
        self.activation = activation
        self.batch_norm = batch_norm

        if activation == 'relu':
            act = nn.ReLU()
        elif activation == 'lrelu':
            act = nn.LeakyReLU(0.2, inplace=True)
        else:
            raise ValueError('wrong activation')

        # construct model
        if n_hidden == 0:
            modules = [nn.Linear(input_size, output_size)]
        else:
            modules = []
            # input width of every hidden layer: the first reads the input,
            # the remaining n_hidden - 1 read the previous hidden layer
            fan_ins = [input_size] + [hidden_size] * (n_hidden - 1)
            for fan_in in fan_ins:
                modules += [nn.Linear(fan_in, hidden_size), act] + batch_norm * [nn.BatchNorm1d(hidden_size)]
            modules += [nn.Linear(hidden_size, output_size)]

        self.net = nn.Sequential(*modules)

    def forward(self, x, y=None):
        # y is accepted but not used
        return self.net(x)
183 |
184 |
class SimpleLinear(nn.Linear):
    """
    nn.Linear (bias off by default) that also records its input and output
    sizes as input_size / output_size
    """

    def __init__(self, nin, nout, bias=False):
        super().__init__(nin, nout, bias=bias)
        self.input_size, self.output_size = nin, nout
194 |
195 |
class FullMLP(nn.Module):
    """
    Fully connected network over flattened images.

    Sizes are read from config: the input is image_size ** 2 * channels
    wide, and the output has the same width unless config.model.final_layer
    is set, in which case it is config.model.feature_size wide.
    """

    def __init__(self, config):
        super().__init__()
        self.num_classes = config.model.num_classes
        self.image_size = config.data.image_size
        self.n_channels = config.data.channels
        self.ngf = ngf = config.model.ngf

        self.input_size = self.image_size ** 2 * self.n_channels
        self.output_size = config.model.feature_size if config.model.final_layer else self.input_size

        def lrelu():
            return nn.LeakyReLU(inplace=True, negative_slope=.1)

        stack = [
            nn.Linear(self.input_size, ngf * 8),
            lrelu(),
            nn.Linear(ngf * 8, ngf * 6),
            lrelu(),
            nn.Dropout(p=0.1),
            nn.Linear(ngf * 6, ngf * 4),
            lrelu(),
            nn.Linear(ngf * 4, ngf * 4),
            lrelu(),
            nn.Linear(ngf * 4, self.output_size),
        ]
        self.linear = nn.Sequential(*stack)

    def forward(self, x):
        # Flatten everything except the batch axis before the linear stack.
        flat = x.view(x.shape[0], -1)
        return self.linear(flat)
226 |
227 |
class ConvMLP(nn.Module):
    """
    Convolutional feature extractor followed by a small linear head.

    Sizes are read from config: the input is (channels, image_size,
    image_size), and the output has image_size ** 2 * channels entries
    unless config.model.final_layer is set, in which case it has
    config.model.feature_size entries.
    """

    def __init__(self, config):
        super().__init__()
        self.num_classes = config.model.num_classes
        self.image_size = im = config.data.image_size
        self.n_channels = nc = config.data.channels
        self.ngf = ngf = config.model.ngf

        self.input_size = config.data.image_size ** 2 * config.data.channels
        self.output_size = self.input_size
        if config.model.final_layer:
            self.output_size = config.model.feature_size

        # convolutional bit is [(conv, bn, relu)*2, maxpool]*2, resize_conv
        self.conv = nn.Sequential(
            # input is (nc, im, im)
            nn.Conv2d(nc, ngf // 2, 3, 1, 1),  # (ngf/2, im, im)
            nn.BatchNorm2d(ngf // 2),  # (ngf/2, im, im)
            nn.ReLU(inplace=True),  # (ngf/2, im, im)
            nn.Conv2d(ngf // 2, ngf, 3, 1, 1),  # (ngf, im, im)
            nn.BatchNorm2d(ngf),  # (ngf, im, im)
            nn.ReLU(inplace=True),  # (ngf, im, im)
            nn.MaxPool2d(kernel_size=2, stride=2),  # (ngf, im/2, im/2)
            nn.Conv2d(ngf, ngf * 2, 3, 1, 1),  # (ngf*2, im/2, im/2)
            nn.BatchNorm2d(ngf * 2),  # (ngf*2, im/2, im/2)
            nn.ReLU(inplace=True),  # (ngf*2, im/2, im/2)
            nn.Conv2d(ngf * 2, ngf * 4, 3, 1, 1),  # (ngf*4, im/2, im/2)
            nn.BatchNorm2d(ngf * 4),  # (ngf*4, im/2, im/2)
            nn.ReLU(inplace=True),  # (ngf*4, im/2, im/2)
            nn.MaxPool2d(kernel_size=2, stride=2),  # (ngf*4, im/4, im/4)
            nn.Conv2d(ngf * 4, ngf * 4, im // 4, 1, 0)  # (ngf*4, 1, 1)
        )
        # linear bit is [drop, lin, lrelu, lin]
        self.linear = nn.Sequential(
            nn.Dropout(p=0.1),
            nn.Linear(ngf * 4, ngf * 4),
            nn.LeakyReLU(inplace=True, negative_slope=.1),
            nn.Linear(ngf * 4, self.output_size)
        )

    def forward(self, x):
        """Map a batch of images (batch, nc, im, im) to (batch, output_size)."""
        h = self.conv(x)
        # Flatten the trailing (ngf*4, 1, 1) axes explicitly. A bare squeeze()
        # would also drop the batch axis when the batch holds one sample.
        h = h.view(h.size(0), -1)
        output = self.linear(h)
        return output
274 |
--------------------------------------------------------------------------------